feat(auxiliary): self-healing fallback chain with proactive context-length checks

- Add _context_length_error() for pre-API-call context window validation
- Add _is_context_length_error() to detect context-length API errors
- Enhance _try_configured_fallback_chain() with failed_model/messages/max_tokens
- Fix skip logic: same provider + different model = valid self-healing rung
- Integrate context-length checks into call_llm() and async_call_llm()
- Trigger fallback chain on context_length errors (not just payment/connection)
This commit is contained in:
Anton Palgunov 2026-05-29 16:15:00 +00:00
parent 2517917de3
commit e5d74fa32a
2 changed files with 205 additions and 8 deletions

View File

@ -2831,16 +2831,74 @@ def _try_main_agent_model_fallback(
return client, resolved_model or main_model, label
def _coerce_positive_int(value: Any) -> Optional[int]:
try:
parsed = int(value)
except (TypeError, ValueError):
return None
return parsed if parsed > 0 else None
def _estimate_auxiliary_request_tokens(messages: list, max_tokens: Optional[int] = None) -> int:
"""Rough token estimate for local auxiliary context-window checks."""
try:
from agent.model_metadata import estimate_messages_tokens_rough
input_tokens = estimate_messages_tokens_rough(messages or [])
except Exception:
input_tokens = 0
for msg in messages or []:
content = msg.get("content", "") if isinstance(msg, dict) else str(msg)
input_tokens += max(1, len(str(content)) // 4)
return input_tokens + (_coerce_positive_int(max_tokens) or 0)
def _context_length_error(
*,
task: str,
provider: str,
model: Optional[str],
context_length: Optional[int],
messages: list,
max_tokens: Optional[int],
) -> Optional[ValueError]:
ctx = _coerce_positive_int(context_length)
if not ctx:
return None
estimated = _estimate_auxiliary_request_tokens(messages, max_tokens)
if estimated <= ctx:
return None
return ValueError(
f"Auxiliary {task or 'call'} request needs ~{estimated} tokens, "
f"exceeding configured context_length={ctx} for "
f"{provider or 'auto'}/{model or 'default'}"
)
def _is_context_length_error(exc: Exception) -> bool:
text = str(exc).lower()
return (
"context_length" in text
or "context length" in text
or "context window" in text
or "too many tokens" in text
or "exceeding configured context" in text
or "exceeds the max_model_len" in text
)
def _try_configured_fallback_chain(
task: str,
failed_provider: str,
reason: str = "error",
failed_model: Optional[str] = None,
messages: Optional[list] = None,
max_tokens: Optional[int] = None,
) -> Tuple[Optional[Any], Optional[str], str]:
"""Try user-configured fallback_chain for a specific auxiliary task.
Reads auxiliary.<task>.fallback_chain from config.yaml and tries each
entry in order. Each entry must have at least ``provider``; ``model``,
``base_url``, and ``api_key`` are optional.
``base_url``, ``api_key``, and ``context_length`` are optional.
Returns:
(client, model, provider_label) or (None, None, "") if no fallback.
@ -2853,21 +2911,46 @@ def _try_configured_fallback_chain(
if not chain or not isinstance(chain, list):
return None, None, ""
skip = failed_provider.lower().strip()
skip_provider = failed_provider.lower().strip()
skip_model = str(failed_model or "").lower().strip()
tried = []
for i, entry in enumerate(chain):
if not isinstance(entry, dict):
continue
fb_provider = str(entry.get("provider", "")).strip()
if not fb_provider or fb_provider.lower() == skip:
if not fb_provider:
continue
fb_model = str(entry.get("model", "")).strip() or None
# Skip only the exact failed provider+model pair. Same provider with a
# different model is a valid self-healing rung (e.g. opencode_go
# deepseek-v4-pro -> opencode_go gpt-5.5).
if fb_provider.lower() == skip_provider and (
not skip_model or (fb_model or "").lower() == skip_model
):
continue
fb_base_url = str(entry.get("base_url", "")).strip() or None
fb_api_key = str(entry.get("api_key", "")).strip() or None
label = f"fallback_chain[{i}]({fb_provider})"
if messages is not None:
context_err = _context_length_error(
task=task,
provider=fb_provider,
model=fb_model,
context_length=entry.get("context_length"),
messages=messages,
max_tokens=max_tokens,
)
if context_err is not None:
logger.info(
"Auxiliary %s: skipping %s (%s) because it also exceeds context_length: %s",
task, label, fb_model or "default", context_err,
)
tried.append(label)
continue
try:
fb_client = _resolve_single_provider(
fb_provider, fb_model, fb_base_url, fb_api_key)
@ -2889,7 +2972,6 @@ def _try_configured_fallback_chain(
)
return None, None, ""
def _resolve_single_provider(
provider: str,
model: Optional[str] = None,
@ -4889,6 +4971,17 @@ def call_llm(
# Handle unsupported temperature, max_tokens vs max_completion_tokens retry,
# then payment fallback.
try:
task_context = _get_auxiliary_task_config(task).get("context_length") if task else None
context_err = _context_length_error(
task=task or "call",
provider=resolved_provider,
model=final_model,
context_length=task_context,
messages=messages,
max_tokens=max_tokens,
)
if context_err is not None:
raise context_err
return _validate_llm_response(
client.chat.completions.create(**kwargs), task)
except Exception as first_err:
@ -5072,6 +5165,7 @@ def call_llm(
_is_payment_error(first_err)
or _is_connection_error(first_err)
or _is_rate_limit_error(first_err)
or _is_context_length_error(first_err)
)
# Respect explicit provider choice for transient errors (auth, request
# validation, etc.) but allow fallback when the provider clearly cannot
@ -5082,7 +5176,11 @@ def call_llm(
is_auto = resolved_provider in {"auto", "", None}
# Capacity errors bypass the explicit-provider gate: the provider
# literally cannot serve this request regardless of user intent.
is_capacity_error = _is_payment_error(first_err) or _is_connection_error(first_err)
is_capacity_error = (
_is_payment_error(first_err)
or _is_connection_error(first_err)
or _is_context_length_error(first_err)
)
if should_fallback and (is_auto or is_capacity_error):
if _is_payment_error(first_err):
reason = "payment error"
@ -5095,6 +5193,8 @@ def call_llm(
)
elif _is_rate_limit_error(first_err):
reason = "rate limit"
elif _is_context_length_error(first_err):
reason = "context length"
else:
reason = "connection error"
logger.info("Auxiliary %s: %s on %s (%s), trying fallback",
@ -5112,7 +5212,8 @@ def call_llm(
resolved_provider, task, reason=reason)
else:
fb_client, fb_model, fb_label = _try_configured_fallback_chain(
task, resolved_provider or "auto", reason=reason)
task, resolved_provider or "auto", reason=reason,
failed_model=final_model, messages=messages, max_tokens=max_tokens)
if fb_client is None:
fb_client, fb_model, fb_label = _try_main_agent_model_fallback(
resolved_provider, task, reason=reason)
@ -5295,6 +5396,17 @@ async def async_call_llm(
kwargs["messages"] = _convert_openai_images_to_anthropic(kwargs["messages"])
try:
task_context = _get_auxiliary_task_config(task).get("context_length") if task else None
context_err = _context_length_error(
task=task or "call",
provider=resolved_provider,
model=final_model,
context_length=task_context,
messages=messages,
max_tokens=max_tokens,
)
if context_err is not None:
raise context_err
return _validate_llm_response(
await client.chat.completions.create(**kwargs), task)
except Exception as first_err:
@ -5446,12 +5558,17 @@ async def async_call_llm(
_is_payment_error(first_err)
or _is_connection_error(first_err)
or _is_rate_limit_error(first_err)
or _is_context_length_error(first_err)
)
# Capacity errors (payment/quota/connection) bypass the explicit-provider
# gate — the provider cannot serve the request regardless of user intent.
# See #26803: daily token quota must fall back like a 402 credit error.
is_auto = resolved_provider in {"auto", "", None}
is_capacity_error = _is_payment_error(first_err) or _is_connection_error(first_err)
is_capacity_error = (
_is_payment_error(first_err)
or _is_connection_error(first_err)
or _is_context_length_error(first_err)
)
if should_fallback and (is_auto or is_capacity_error):
if _is_payment_error(first_err):
reason = "payment error"
@ -5460,6 +5577,8 @@ async def async_call_llm(
)
elif _is_rate_limit_error(first_err):
reason = "rate limit"
elif _is_context_length_error(first_err):
reason = "context length"
else:
reason = "connection error"
logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback",
@ -5476,7 +5595,8 @@ async def async_call_llm(
resolved_provider, task, reason=reason)
else:
fb_client, fb_model, fb_label = _try_configured_fallback_chain(
task, resolved_provider or "auto", reason=reason)
task, resolved_provider or "auto", reason=reason,
failed_model=final_model, messages=messages, max_tokens=max_tokens)
if fb_client is None:
fb_client, fb_model, fb_label = _try_main_agent_model_fallback(
resolved_provider, task, reason=reason)

View File

@ -1402,6 +1402,83 @@ class TestAuxiliaryFallbackLayering:
assert main_client.chat.completions.create.called
def test_context_length_failure_uses_configured_chain_same_provider_different_model(self, monkeypatch):
"""Local auxiliary context_length failure should self-heal via fallback_chain."""
monkeypatch.setenv("OPENCODE_GO_API_KEY", "go-key")
primary_client = MagicMock()
chain_client = MagicMock()
chain_client.chat.completions.create.return_value = MagicMock(choices=[
MagicMock(message=MagicMock(content="from opencode gpt fallback"))
])
task_cfg = {
"provider": "opencode_go",
"model": "deepseek-v4-pro",
"context_length": 100,
"fallback_chain": [
{"provider": "opencode_go", "model": "gpt-5.5", "context_length": 10000},
],
}
def resolve_single(provider, model=None, base_url=None, api_key=None):
assert provider == "opencode_go"
assert model == "gpt-5.5"
return chain_client
with patch("agent.auxiliary_client._get_cached_client",
return_value=(primary_client, "deepseek-v4-pro")), \
patch("agent.auxiliary_client._resolve_task_provider_model",
return_value=("opencode_go", "deepseek-v4-pro", None, None, None)), \
patch("agent.auxiliary_client._get_auxiliary_task_config",
return_value=task_cfg), \
patch("agent.auxiliary_client._resolve_single_provider",
side_effect=resolve_single), \
patch("agent.auxiliary_client._try_main_agent_model_fallback") as main_fb:
result = call_llm(
task="compression",
messages=[{"role": "user", "content": "x" * 1000}],
max_tokens=2000,
)
primary_client.chat.completions.create.assert_not_called()
chain_client.chat.completions.create.assert_called_once()
main_fb.assert_not_called()
assert result.choices[0].message.content == "from opencode gpt fallback"
def test_configured_chain_skips_too_small_fallback_context(self):
"""fallback_chain should continue past entries that cannot fit the request."""
from agent.auxiliary_client import _try_configured_fallback_chain
too_small = MagicMock()
fits = MagicMock()
task_cfg = {
"fallback_chain": [
{"provider": "custom", "model": "tiny", "context_length": 100},
{"provider": "custom", "model": "gemma-local", "context_length": 10000},
]
}
def resolve_single(provider, model=None, base_url=None, api_key=None):
return too_small if model == "tiny" else fits
with patch("agent.auxiliary_client._get_auxiliary_task_config",
return_value=task_cfg), \
patch("agent.auxiliary_client._resolve_single_provider",
side_effect=resolve_single):
client, model, label = _try_configured_fallback_chain(
"compression",
"opencode_go",
reason="context length",
failed_model="deepseek-v4-pro",
messages=[{"role": "user", "content": "x" * 1000}],
max_tokens=2000,
)
assert client is fits
assert model == "gemma-local"
assert label == "fallback_chain[1](custom)"
def test_warning_emitted_when_all_fallbacks_exhausted(self, monkeypatch, caplog):
"""When chain AND main model both fail, a user-visible warning fires before re-raise."""
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")