fix(compressor): ABC compliance — total_tokens, api_mode, logger consistency
This commit is contained in:
parent
75643a6154
commit
8b2adead78
@ -1429,6 +1429,7 @@ def init_agent(
|
|||||||
base_url=agent.base_url,
|
base_url=agent.base_url,
|
||||||
api_key=getattr(agent, "api_key", ""),
|
api_key=getattr(agent, "api_key", ""),
|
||||||
provider=agent.provider,
|
provider=agent.provider,
|
||||||
|
api_mode=agent.api_mode,
|
||||||
)
|
)
|
||||||
if not agent.quiet_mode:
|
if not agent.quiet_mode:
|
||||||
_ra().logger.info("Using context engine: %s", _selected_engine.name)
|
_ra().logger.info("Using context engine: %s", _selected_engine.name)
|
||||||
|
|||||||
@ -609,6 +609,7 @@ class ContextCompressor(ContextEngine):
|
|||||||
"""Update tracked token usage from API response."""
|
"""Update tracked token usage from API response."""
|
||||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||||
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
||||||
|
self.last_total_tokens = usage.get("total_tokens", self.last_prompt_tokens + self.last_completion_tokens)
|
||||||
|
|
||||||
def should_compress(self, prompt_tokens: int = None) -> bool:
|
def should_compress(self, prompt_tokens: int = None) -> bool:
|
||||||
"""Check if context exceeds the compression threshold.
|
"""Check if context exceeds the compression threshold.
|
||||||
@ -897,7 +898,7 @@ class ContextCompressor(ContextEngine):
|
|||||||
into the warning log.
|
into the warning log.
|
||||||
"""
|
"""
|
||||||
self._summary_model_fallen_back = True
|
self._summary_model_fallen_back = True
|
||||||
logging.warning(
|
logger.warning(
|
||||||
"Summary model '%s' %s (%s). "
|
"Summary model '%s' %s (%s). "
|
||||||
"Falling back to main model '%s' for compression.",
|
"Falling back to main model '%s' for compression.",
|
||||||
self.summary_model, reason, e, self.model,
|
self.summary_model, reason, e, self.model,
|
||||||
@ -1086,7 +1087,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
|||||||
# No provider configured — long cooldown, unlikely to self-resolve
|
# No provider configured — long cooldown, unlikely to self-resolve
|
||||||
self._summary_failure_cooldown_until = time.monotonic() + _SUMMARY_FAILURE_COOLDOWN_SECONDS
|
self._summary_failure_cooldown_until = time.monotonic() + _SUMMARY_FAILURE_COOLDOWN_SECONDS
|
||||||
self._last_summary_error = "no auxiliary LLM provider configured"
|
self._last_summary_error = "no auxiliary LLM provider configured"
|
||||||
logging.warning("Context compression: no provider available for "
|
logger.warning("Context compression: no provider available for "
|
||||||
"summary. Middle turns will be dropped without summary "
|
"summary. Middle turns will be dropped without summary "
|
||||||
"for %d seconds.",
|
"for %d seconds.",
|
||||||
_SUMMARY_FAILURE_COOLDOWN_SECONDS)
|
_SUMMARY_FAILURE_COOLDOWN_SECONDS)
|
||||||
@ -1182,7 +1183,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio
|
|||||||
if len(err_text) > 220:
|
if len(err_text) > 220:
|
||||||
err_text = err_text[:217].rstrip() + "..."
|
err_text = err_text[:217].rstrip() + "..."
|
||||||
self._last_summary_error = err_text
|
self._last_summary_error = err_text
|
||||||
logging.warning(
|
logger.warning(
|
||||||
"Failed to generate context summary: %s. "
|
"Failed to generate context summary: %s. "
|
||||||
"Further summary attempts paused for %d seconds.",
|
"Further summary attempts paused for %d seconds.",
|
||||||
e,
|
e,
|
||||||
|
|||||||
@ -200,6 +200,7 @@ class ContextEngine(ABC):
|
|||||||
base_url: str = "",
|
base_url: str = "",
|
||||||
api_key: str = "",
|
api_key: str = "",
|
||||||
provider: str = "",
|
provider: str = "",
|
||||||
|
api_mode: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Called when the user switches models or on fallback activation.
|
"""Called when the user switches models or on fallback activation.
|
||||||
|
|
||||||
|
|||||||
22
tests/agent/test_last_total_tokens.py
Normal file
22
tests/agent/test_last_total_tokens.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
"""Test that last_total_tokens is correctly set by ContextCompressor."""
|
||||||
|
|
||||||
|
from agent.context_compressor import ContextCompressor
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_from_response_sets_total_tokens():
|
||||||
|
"""ABC contract: last_total_tokens must be set from API response."""
|
||||||
|
c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000)
|
||||||
|
|
||||||
|
c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130})
|
||||||
|
assert c.last_total_tokens == 130
|
||||||
|
|
||||||
|
c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30})
|
||||||
|
assert c.last_total_tokens == 130
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_reset_clears_total_tokens():
|
||||||
|
"""on_session_reset must zero total_tokens."""
|
||||||
|
c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000)
|
||||||
|
c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130})
|
||||||
|
c.on_session_reset()
|
||||||
|
assert c.last_total_tokens == 0
|
||||||
@ -87,5 +87,4 @@ def test_plugin_engine_update_model_args():
|
|||||||
assert kw["context_length"] == 131_072
|
assert kw["context_length"] == 131_072
|
||||||
assert "model" in kw
|
assert "model" in kw
|
||||||
assert "provider" in kw
|
assert "provider" in kw
|
||||||
# Should NOT pass api_mode — the ABC doesn't accept it
|
assert "api_mode" in kw
|
||||||
assert "api_mode" not in kw
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user