feat(auxiliary): self-healing fallback chain with proactive context-length checks
- Add _context_length_error() for pre-API-call context window validation - Add _is_context_length_error() to detect context-length API errors - Enhance _try_configured_fallback_chain() with failed_model/messages/max_tokens - Fix skip logic: same provider + different model = valid self-healing rung - Integrate context-length checks into call_llm() and async_call_llm() - Trigger fallback chain on context_length errors (not just payment/connection)
This commit is contained in:
parent
2517917de3
commit
e5d74fa32a
@ -2831,16 +2831,74 @@ def _try_main_agent_model_fallback(
|
||||
return client, resolved_model or main_model, label
|
||||
|
||||
|
||||
def _coerce_positive_int(value: Any) -> Optional[int]:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return parsed if parsed > 0 else None
|
||||
|
||||
|
||||
def _estimate_auxiliary_request_tokens(messages: list, max_tokens: Optional[int] = None) -> int:
|
||||
"""Rough token estimate for local auxiliary context-window checks."""
|
||||
try:
|
||||
from agent.model_metadata import estimate_messages_tokens_rough
|
||||
input_tokens = estimate_messages_tokens_rough(messages or [])
|
||||
except Exception:
|
||||
input_tokens = 0
|
||||
for msg in messages or []:
|
||||
content = msg.get("content", "") if isinstance(msg, dict) else str(msg)
|
||||
input_tokens += max(1, len(str(content)) // 4)
|
||||
return input_tokens + (_coerce_positive_int(max_tokens) or 0)
|
||||
|
||||
|
||||
def _context_length_error(
|
||||
*,
|
||||
task: str,
|
||||
provider: str,
|
||||
model: Optional[str],
|
||||
context_length: Optional[int],
|
||||
messages: list,
|
||||
max_tokens: Optional[int],
|
||||
) -> Optional[ValueError]:
|
||||
ctx = _coerce_positive_int(context_length)
|
||||
if not ctx:
|
||||
return None
|
||||
estimated = _estimate_auxiliary_request_tokens(messages, max_tokens)
|
||||
if estimated <= ctx:
|
||||
return None
|
||||
return ValueError(
|
||||
f"Auxiliary {task or 'call'} request needs ~{estimated} tokens, "
|
||||
f"exceeding configured context_length={ctx} for "
|
||||
f"{provider or 'auto'}/{model or 'default'}"
|
||||
)
|
||||
|
||||
|
||||
def _is_context_length_error(exc: Exception) -> bool:
|
||||
text = str(exc).lower()
|
||||
return (
|
||||
"context_length" in text
|
||||
or "context length" in text
|
||||
or "context window" in text
|
||||
or "too many tokens" in text
|
||||
or "exceeding configured context" in text
|
||||
or "exceeds the max_model_len" in text
|
||||
)
|
||||
|
||||
|
||||
def _try_configured_fallback_chain(
|
||||
task: str,
|
||||
failed_provider: str,
|
||||
reason: str = "error",
|
||||
failed_model: Optional[str] = None,
|
||||
messages: Optional[list] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[Optional[Any], Optional[str], str]:
|
||||
"""Try user-configured fallback_chain for a specific auxiliary task.
|
||||
|
||||
Reads auxiliary.<task>.fallback_chain from config.yaml and tries each
|
||||
entry in order. Each entry must have at least ``provider``; ``model``,
|
||||
``base_url``, and ``api_key`` are optional.
|
||||
``base_url``, ``api_key``, and ``context_length`` are optional.
|
||||
|
||||
Returns:
|
||||
(client, model, provider_label) or (None, None, "") if no fallback.
|
||||
@ -2853,21 +2911,46 @@ def _try_configured_fallback_chain(
|
||||
if not chain or not isinstance(chain, list):
|
||||
return None, None, ""
|
||||
|
||||
skip = failed_provider.lower().strip()
|
||||
skip_provider = failed_provider.lower().strip()
|
||||
skip_model = str(failed_model or "").lower().strip()
|
||||
tried = []
|
||||
|
||||
for i, entry in enumerate(chain):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
fb_provider = str(entry.get("provider", "")).strip()
|
||||
if not fb_provider or fb_provider.lower() == skip:
|
||||
if not fb_provider:
|
||||
continue
|
||||
fb_model = str(entry.get("model", "")).strip() or None
|
||||
# Skip only the exact failed provider+model pair. Same provider with a
|
||||
# different model is a valid self-healing rung (e.g. opencode_go
|
||||
# deepseek-v4-pro -> opencode_go gpt-5.5).
|
||||
if fb_provider.lower() == skip_provider and (
|
||||
not skip_model or (fb_model or "").lower() == skip_model
|
||||
):
|
||||
continue
|
||||
fb_base_url = str(entry.get("base_url", "")).strip() or None
|
||||
fb_api_key = str(entry.get("api_key", "")).strip() or None
|
||||
|
||||
label = f"fallback_chain[{i}]({fb_provider})"
|
||||
|
||||
if messages is not None:
|
||||
context_err = _context_length_error(
|
||||
task=task,
|
||||
provider=fb_provider,
|
||||
model=fb_model,
|
||||
context_length=entry.get("context_length"),
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if context_err is not None:
|
||||
logger.info(
|
||||
"Auxiliary %s: skipping %s (%s) because it also exceeds context_length: %s",
|
||||
task, label, fb_model or "default", context_err,
|
||||
)
|
||||
tried.append(label)
|
||||
continue
|
||||
|
||||
try:
|
||||
fb_client = _resolve_single_provider(
|
||||
fb_provider, fb_model, fb_base_url, fb_api_key)
|
||||
@ -2889,7 +2972,6 @@ def _try_configured_fallback_chain(
|
||||
)
|
||||
return None, None, ""
|
||||
|
||||
|
||||
def _resolve_single_provider(
|
||||
provider: str,
|
||||
model: Optional[str] = None,
|
||||
@ -4889,6 +4971,17 @@ def call_llm(
|
||||
# Handle unsupported temperature, max_tokens vs max_completion_tokens retry,
|
||||
# then payment fallback.
|
||||
try:
|
||||
task_context = _get_auxiliary_task_config(task).get("context_length") if task else None
|
||||
context_err = _context_length_error(
|
||||
task=task or "call",
|
||||
provider=resolved_provider,
|
||||
model=final_model,
|
||||
context_length=task_context,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if context_err is not None:
|
||||
raise context_err
|
||||
return _validate_llm_response(
|
||||
client.chat.completions.create(**kwargs), task)
|
||||
except Exception as first_err:
|
||||
@ -5072,6 +5165,7 @@ def call_llm(
|
||||
_is_payment_error(first_err)
|
||||
or _is_connection_error(first_err)
|
||||
or _is_rate_limit_error(first_err)
|
||||
or _is_context_length_error(first_err)
|
||||
)
|
||||
# Respect explicit provider choice for transient errors (auth, request
|
||||
# validation, etc.) but allow fallback when the provider clearly cannot
|
||||
@ -5082,7 +5176,11 @@ def call_llm(
|
||||
is_auto = resolved_provider in {"auto", "", None}
|
||||
# Capacity errors bypass the explicit-provider gate: the provider
|
||||
# literally cannot serve this request regardless of user intent.
|
||||
is_capacity_error = _is_payment_error(first_err) or _is_connection_error(first_err)
|
||||
is_capacity_error = (
|
||||
_is_payment_error(first_err)
|
||||
or _is_connection_error(first_err)
|
||||
or _is_context_length_error(first_err)
|
||||
)
|
||||
if should_fallback and (is_auto or is_capacity_error):
|
||||
if _is_payment_error(first_err):
|
||||
reason = "payment error"
|
||||
@ -5095,6 +5193,8 @@ def call_llm(
|
||||
)
|
||||
elif _is_rate_limit_error(first_err):
|
||||
reason = "rate limit"
|
||||
elif _is_context_length_error(first_err):
|
||||
reason = "context length"
|
||||
else:
|
||||
reason = "connection error"
|
||||
logger.info("Auxiliary %s: %s on %s (%s), trying fallback",
|
||||
@ -5112,7 +5212,8 @@ def call_llm(
|
||||
resolved_provider, task, reason=reason)
|
||||
else:
|
||||
fb_client, fb_model, fb_label = _try_configured_fallback_chain(
|
||||
task, resolved_provider or "auto", reason=reason)
|
||||
task, resolved_provider or "auto", reason=reason,
|
||||
failed_model=final_model, messages=messages, max_tokens=max_tokens)
|
||||
if fb_client is None:
|
||||
fb_client, fb_model, fb_label = _try_main_agent_model_fallback(
|
||||
resolved_provider, task, reason=reason)
|
||||
@ -5295,6 +5396,17 @@ async def async_call_llm(
|
||||
kwargs["messages"] = _convert_openai_images_to_anthropic(kwargs["messages"])
|
||||
|
||||
try:
|
||||
task_context = _get_auxiliary_task_config(task).get("context_length") if task else None
|
||||
context_err = _context_length_error(
|
||||
task=task or "call",
|
||||
provider=resolved_provider,
|
||||
model=final_model,
|
||||
context_length=task_context,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
if context_err is not None:
|
||||
raise context_err
|
||||
return _validate_llm_response(
|
||||
await client.chat.completions.create(**kwargs), task)
|
||||
except Exception as first_err:
|
||||
@ -5446,12 +5558,17 @@ async def async_call_llm(
|
||||
_is_payment_error(first_err)
|
||||
or _is_connection_error(first_err)
|
||||
or _is_rate_limit_error(first_err)
|
||||
or _is_context_length_error(first_err)
|
||||
)
|
||||
# Capacity errors (payment/quota/connection) bypass the explicit-provider
|
||||
# gate — the provider cannot serve the request regardless of user intent.
|
||||
# See #26803: daily token quota must fall back like a 402 credit error.
|
||||
is_auto = resolved_provider in {"auto", "", None}
|
||||
is_capacity_error = _is_payment_error(first_err) or _is_connection_error(first_err)
|
||||
is_capacity_error = (
|
||||
_is_payment_error(first_err)
|
||||
or _is_connection_error(first_err)
|
||||
or _is_context_length_error(first_err)
|
||||
)
|
||||
if should_fallback and (is_auto or is_capacity_error):
|
||||
if _is_payment_error(first_err):
|
||||
reason = "payment error"
|
||||
@ -5460,6 +5577,8 @@ async def async_call_llm(
|
||||
)
|
||||
elif _is_rate_limit_error(first_err):
|
||||
reason = "rate limit"
|
||||
elif _is_context_length_error(first_err):
|
||||
reason = "context length"
|
||||
else:
|
||||
reason = "connection error"
|
||||
logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback",
|
||||
@ -5476,7 +5595,8 @@ async def async_call_llm(
|
||||
resolved_provider, task, reason=reason)
|
||||
else:
|
||||
fb_client, fb_model, fb_label = _try_configured_fallback_chain(
|
||||
task, resolved_provider or "auto", reason=reason)
|
||||
task, resolved_provider or "auto", reason=reason,
|
||||
failed_model=final_model, messages=messages, max_tokens=max_tokens)
|
||||
if fb_client is None:
|
||||
fb_client, fb_model, fb_label = _try_main_agent_model_fallback(
|
||||
resolved_provider, task, reason=reason)
|
||||
|
||||
@ -1402,6 +1402,83 @@ class TestAuxiliaryFallbackLayering:
|
||||
|
||||
assert main_client.chat.completions.create.called
|
||||
|
||||
def test_context_length_failure_uses_configured_chain_same_provider_different_model(self, monkeypatch):
|
||||
"""Local auxiliary context_length failure should self-heal via fallback_chain."""
|
||||
monkeypatch.setenv("OPENCODE_GO_API_KEY", "go-key")
|
||||
|
||||
primary_client = MagicMock()
|
||||
chain_client = MagicMock()
|
||||
chain_client.chat.completions.create.return_value = MagicMock(choices=[
|
||||
MagicMock(message=MagicMock(content="from opencode gpt fallback"))
|
||||
])
|
||||
|
||||
task_cfg = {
|
||||
"provider": "opencode_go",
|
||||
"model": "deepseek-v4-pro",
|
||||
"context_length": 100,
|
||||
"fallback_chain": [
|
||||
{"provider": "opencode_go", "model": "gpt-5.5", "context_length": 10000},
|
||||
],
|
||||
}
|
||||
|
||||
def resolve_single(provider, model=None, base_url=None, api_key=None):
|
||||
assert provider == "opencode_go"
|
||||
assert model == "gpt-5.5"
|
||||
return chain_client
|
||||
|
||||
with patch("agent.auxiliary_client._get_cached_client",
|
||||
return_value=(primary_client, "deepseek-v4-pro")), \
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("opencode_go", "deepseek-v4-pro", None, None, None)), \
|
||||
patch("agent.auxiliary_client._get_auxiliary_task_config",
|
||||
return_value=task_cfg), \
|
||||
patch("agent.auxiliary_client._resolve_single_provider",
|
||||
side_effect=resolve_single), \
|
||||
patch("agent.auxiliary_client._try_main_agent_model_fallback") as main_fb:
|
||||
result = call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "x" * 1000}],
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
primary_client.chat.completions.create.assert_not_called()
|
||||
chain_client.chat.completions.create.assert_called_once()
|
||||
main_fb.assert_not_called()
|
||||
assert result.choices[0].message.content == "from opencode gpt fallback"
|
||||
|
||||
def test_configured_chain_skips_too_small_fallback_context(self):
|
||||
"""fallback_chain should continue past entries that cannot fit the request."""
|
||||
from agent.auxiliary_client import _try_configured_fallback_chain
|
||||
|
||||
too_small = MagicMock()
|
||||
fits = MagicMock()
|
||||
task_cfg = {
|
||||
"fallback_chain": [
|
||||
{"provider": "custom", "model": "tiny", "context_length": 100},
|
||||
{"provider": "custom", "model": "gemma-local", "context_length": 10000},
|
||||
]
|
||||
}
|
||||
|
||||
def resolve_single(provider, model=None, base_url=None, api_key=None):
|
||||
return too_small if model == "tiny" else fits
|
||||
|
||||
with patch("agent.auxiliary_client._get_auxiliary_task_config",
|
||||
return_value=task_cfg), \
|
||||
patch("agent.auxiliary_client._resolve_single_provider",
|
||||
side_effect=resolve_single):
|
||||
client, model, label = _try_configured_fallback_chain(
|
||||
"compression",
|
||||
"opencode_go",
|
||||
reason="context length",
|
||||
failed_model="deepseek-v4-pro",
|
||||
messages=[{"role": "user", "content": "x" * 1000}],
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
assert client is fits
|
||||
assert model == "gemma-local"
|
||||
assert label == "fallback_chain[1](custom)"
|
||||
|
||||
def test_warning_emitted_when_all_fallbacks_exhausted(self, monkeypatch, caplog):
|
||||
"""When chain AND main model both fail, a user-visible warning fires before re-raise."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user