"""Experimental Semantic RLE context engine.

MVP goals:

- keep the hot tail verbatim;
- collapse older chat into a deterministic factual ledger;
- mark superseded facts instead of silently forgetting them;
- redact likely credentials and IP-like sensitive strings before ledgering.

This plugin is intentionally deterministic and does not call cloud LLMs.
It is not enabled by default; select it explicitly with ``context.engine``.
"""

from __future__ import annotations

import hashlib
import hmac
import logging
import re
import secrets
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

from agent.context_engine import ContextEngine


Message = Dict[str, Any]


# Prefixed API-token shapes, plus any long opaque run of token characters.
_TOKEN_PATTERNS: tuple[re.Pattern[str], ...] = (
    re.compile(r"\b(?:sk|xox[baprs]?|gh[pousr]|hf|AIza|ya29|pat|tok)[-_][A-Za-z0-9_./+=-]{12,}\b"),
    re.compile(r"\b[A-Za-z0-9_./+=-]{32,}\b"),
)

# ``<name> = value`` / ``<name>: value`` secret assignments.
# Group 1 is the key (kept), group 2 the value (redacted).
# - The key may carry a prefix such as ``DB_PASSWORD`` or ``OPENAI_API_KEY``;
#   a bare ``\b`` before the keyword never matched there because ``_`` is a
#   word character, so those secrets leaked.
# - The value may start with an HTTP auth scheme (``Authorization: Bearer X``);
#   previously only the scheme word was redacted and the token itself leaked.
# The prefix repetition is bounded to keep matching linear on long
# hyphen-separated text.
_KEY_VALUE_SECRET_RE = re.compile(
    r"(?i)\b((?:[a-z0-9]+[_-]){0,10}(?:api[_-]?key|token|secret|password|passwd|authorization|bearer))\b"
    r"\s*[:=]\s*"
    r"((?:(?:bearer|basic|token)\s+)?[^\s,;]+)"
)

_IPV4_RE = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")

# ``server [is|=|:|->|to] <name>``. The separator words need ``\b`` and must
# not be captured as the value themselves. Otherwise ``server is 10.0.0.1``
# (which is ``server is [REDACTED_IP]`` after redaction) yielded the server
# "is". The redacted-IP placeholder is accepted as a value.
_SERVER_RE = re.compile(
    r"(?i)\bserver\s*(?:is\b|=|:|->|to\b)?\s*(?!(?:is|to)\b)"
    r"([A-Za-z0-9][A-Za-z0-9._-]{1,}|\[REDACTED_IP\])"
)

_DECISION_RE = re.compile(r"(?i)\b(decided|decision|решили|решение|choose|chosen|use|используем)\b")

# ``обяз`` is a Russian stem (обязательно, обязан, ...). It needs ``\w*`` to
# reach the closing ``\b``, otherwise it can never match a real word.
_OBLIGATION_RE = re.compile(r"(?i)\b(todo|надо|нужно|must|should|обяз\w*|follow up|сделай|сделать)\b")

_QUESTION_RE = re.compile(r"(?i)(\?\s*$|\b(unresolved|open question|вопрос|непонятно|уточнить)\b)")

# Heuristic "this message states a fact" detector. ``=``/``:`` sit outside the
# ``\b`` group: ``\b=\b`` only matched when flanked by word characters, so
# "key: value" or "a = b" were never detected. A colon that starts ``://`` is
# skipped so bare URLs do not count as facts.
_FACT_RE = re.compile(
    r"(?i)(?:\b(?:is|are|это|будет|uses|runs|host|server|model|provider)\b|[=:](?!//))"
)


@dataclass
class LedgerFact:
    """A single fact extracted from a cold turn, possibly later superseded."""

    # Grouping key; the engine replaces an older fact that has the same key.
    key: str
    value: str
    # Role of the message the fact came from (e.g. "user", "assistant").
    role: str
    # 1-based position of the source message among the compacted turns.
    turn_index: int
    active: bool = True
    # Value of the newer fact that replaced this one, when known.
    superseded_by: Optional[str] = None

    def line(self) -> str:
        """Format the fact as one markdown bullet for the ledger summary."""
        if self.active:
            status = "active"
        else:
            replacement = self.superseded_by or "newer fact"
            status = f"superseded by {replacement}"
        provenance = f"(turn {self.turn_index}, {self.role})"
        return f"- [{status}] {self.key}: {self.value} {provenance}"


class SemanticRLEEngine(ContextEngine):
    """Deterministic context engine for a hot-tail + semantic-ledger experiment.

    ``compress`` keeps system messages and the last ``hot_tail_messages``
    non-system messages verbatim. It replaces every older non-system message
    with a single system message holding a rule-based ledger: facts,
    decisions, obligations, questions and credential refs. No model is
    called; every step is regex/string based.
    """

    threshold_percent = 0.75
    protect_first_n = 1
    protect_last_n = 8

    # First line of every rendered ledger. ``compress`` also uses it to
    # recognise ledgers produced by an earlier call. Those are folded into the
    # new ledger instead of accumulating as extra system messages.
    _LEDGER_HEADER = "Semantic RLE context ledger (deterministic, older turns compacted)."

    def __init__(self, context_length: int = 200_000, hot_tail_messages: int = 8) -> None:
        """Create an engine.

        Args:
            context_length: Model context window in tokens. The compression
                threshold is ``threshold_percent`` of it.
            hot_tail_messages: Number of most recent non-system messages kept
                verbatim. Negative values are treated as 0.
        """
        self.context_length = context_length
        self.threshold_tokens = int(context_length * self.threshold_percent)
        self.hot_tail_messages = max(0, int(hot_tail_messages))
        self.last_prompt_tokens = 0
        self.last_completion_tokens = 0
        self.last_total_tokens = 0
        self.compression_count = 0
        self._last_ledger: dict[str, Any] = {}
        # Per-instance secret used to derive credential refs (see _sanitize).
        self._ref_key = secrets.token_bytes(32)

    @property
    def name(self) -> str:
        """Engine identifier used for selection and status reporting."""
        return "semantic_rle"

    def is_available(self) -> bool:
        """Always True: the engine has no external runtime dependencies."""
        return True

    def update_from_response(self, usage: Dict[str, Any]) -> None:
        """Record token usage reported for the latest response.

        Accepts ``prompt_tokens``/``completion_tokens`` or
        ``input_tokens``/``output_tokens``. A missing ``total_tokens`` is the
        sum of the other two. A ``None`` or empty *usage* records zeros.
        """
        usage = usage or {}
        self.last_prompt_tokens = int(usage.get("prompt_tokens") or usage.get("input_tokens") or 0)
        self.last_completion_tokens = int(usage.get("completion_tokens") or usage.get("output_tokens") or 0)
        self.last_total_tokens = int(
            usage.get("total_tokens") or (self.last_prompt_tokens + self.last_completion_tokens)
        )

    def should_compress(self, prompt_tokens: Optional[int] = None) -> bool:
        """Return True once the prompt size reaches ``threshold_tokens``.

        Uses *prompt_tokens* when given, otherwise the last recorded prompt
        size. A zero threshold (unknown context length) never triggers.
        """
        tokens = self.last_prompt_tokens if prompt_tokens is None else int(prompt_tokens)
        return bool(self.threshold_tokens and tokens >= self.threshold_tokens)

    def should_compress_preflight(self, messages: List[Message]) -> bool:
        """Pre-send check. It looks only at the message count, not token size."""
        return self.has_content_to_compress(messages)

    def has_content_to_compress(self, messages: List[Message]) -> bool:
        """Return True when there are more non-system messages than the hot tail keeps."""
        return len(self._non_system(messages)) > self.hot_tail_messages

    def update_model(
        self,
        model: str,
        context_length: int,
        base_url: str = "",
        api_key: str = "",
        provider: str = "",
        api_mode: str = "",
    ) -> None:
        """Refresh the context window and threshold after a model switch.

        Only *context_length* is used; a falsy value keeps the current length.
        The other parameters are accepted for interface compatibility.
        """
        self.context_length = int(context_length or self.context_length or 0)
        self.threshold_tokens = int(self.context_length * self.threshold_percent) if self.context_length else 0

    def compress(
        self,
        messages: List[Message],
        current_tokens: Optional[int] = None,
        focus_topic: Optional[str] = None,
    ) -> List[Message]:
        """Compact older turns into one ledger message, keeping the recent tail verbatim.

        Returns ``[system messages..., ledger message, hot tail...]``. Every
        non-ledger system message is moved ahead of the ledger. Ledger
        messages produced by an earlier call (recognised by
        ``_LEDGER_HEADER``) are carried into the new ledger instead of being
        kept as separate system messages. When nothing is old enough to
        compact, a shallow copy of *messages* is returned.

        The path is fail-closed: on an unexpected error the failure is logged
        and a shallow copy of the original list is returned rather than
        dropping context. *current_tokens* is accepted for interface
        compatibility and is not used.
        """
        self.compression_count += 1
        try:
            if not messages:
                return []

            copied = [dict(m) for m in messages]
            system_head: list[Message] = []
            previous_ledgers: list[str] = []
            non_system: list[Message] = []
            for msg in copied:
                if msg.get("role") != "system":
                    non_system.append(msg)
                    continue
                text = self._string_content(msg.get("content"))
                if text.startswith(self._LEDGER_HEADER):
                    previous_ledgers.append(text)
                else:
                    system_head.append(msg)

            if len(non_system) <= self.hot_tail_messages:
                return copied

            cut = self._tail_start(non_system)
            if cut <= 0:
                # The whole non-system history had to stay in the tail.
                return copied

            cold = non_system[:cut]
            hot_tail = non_system[cut:]
            ledger = self._build_ledger(cold, focus_topic=focus_topic)
            ledger["previous_ledgers"] = previous_ledgers
            self._last_ledger = ledger

            summary_message: Message = {
                "role": "system",
                "content": self._render_summary(ledger, focus_topic=focus_topic),
            }
            return [*system_head, summary_message, *hot_tail]
        except Exception:
            # Best-effort by design: never drop context because of a ledger
            # bug, but do not hide the failure either.
            logging.getLogger(__name__).warning(
                "semantic_rle compression failed; returning messages unchanged", exc_info=True
            )
            return [dict(m) for m in messages]

    def _tail_start(self, non_system: List[Message]) -> int:
        """Return the index in *non_system* where the verbatim hot tail starts."""
        # The split is computed as an index rather than by slicing with
        # ``-hot_tail_messages``, which broke at 0 because ``[-0:]`` keeps everything.
        cut = len(non_system) - max(0, self.hot_tail_messages)
        # Never start the tail on a tool result: the assistant message that
        # issued the call would be compacted away, orphaning the result.
        # Assumes OpenAI-style ``role == "tool"`` results; confirm for other providers.
        while 0 < cut < len(non_system) and non_system[cut].get("role") == "tool":
            cut -= 1
        return cut

    def get_status(self) -> Dict[str, Any]:
        """Extend the base status with engine name, tail size and last ledger section sizes."""
        status = super().get_status()
        status.update(
            {
                "engine": self.name,
                "hot_tail_messages": self.hot_tail_messages,
                "ledger_counts": {
                    key: len(value) for key, value in self._last_ledger.items() if isinstance(value, list)
                },
            }
        )
        return status

    @staticmethod
    def _non_system(messages: Iterable[Message]) -> list[Message]:
        """Return the messages whose role is not ``system``, in order."""
        return [m for m in messages if m.get("role") != "system"]

    def _build_ledger(self, messages: List[Message], focus_topic: Optional[str] = None) -> dict[str, Any]:
        """Scan *messages* (cold turns) and return the ledger dict.

        Turns are numbered from 1 in the order given. Every message is
        sanitized before anything is extracted from it. A later fact with the
        same key but a different value marks the earlier one as superseded.
        """
        facts_by_key: dict[str, LedgerFact] = {}
        superseded: list[LedgerFact] = []
        decisions: list[str] = []
        obligations: list[str] = []
        questions: list[str] = []
        credential_refs: list[str] = []
        retrieval_notes: list[str] = []

        for index, msg in enumerate(messages, start=1):
            role = str(msg.get("role", "unknown"))
            text = self._string_content(msg.get("content", ""))
            if not text.strip():
                continue
            sanitized, refs = self._sanitize(text)
            credential_refs.extend(refs)
            snippet = self._snippet(sanitized)

            fact = self._extract_fact(sanitized, role=role, turn_index=index)
            if fact:
                old = facts_by_key.get(fact.key)
                if old and old.value != fact.value:
                    old.active = False
                    old.superseded_by = fact.value
                    superseded.append(old)
                facts_by_key[fact.key] = fact

            if _DECISION_RE.search(sanitized):
                decisions.append(f"- {snippet} (turn {index}, {role})")
            if _OBLIGATION_RE.search(sanitized):
                obligations.append(f"- {snippet} (turn {index}, {role})")
            if _QUESTION_RE.search(sanitized):
                questions.append(f"- {snippet} (turn {index}, {role})")
            if focus_topic and focus_topic.lower() in sanitized.lower():
                retrieval_notes.append(f"- Focus match `{focus_topic}` at turn {index}: {snippet}")

        active_facts = [fact for fact in facts_by_key.values() if fact.active]
        return {
            "active_facts": active_facts,
            "decisions": self._dedupe(decisions),
            "obligations": self._dedupe(obligations),
            "superseded_facts": superseded,
            "unresolved_questions": self._dedupe(questions),
            "credential_refs": self._dedupe(credential_refs),
            "retrieval_notes": self._dedupe(retrieval_notes),
            "cold_turns_compacted": len(messages),
        }

    def _extract_fact(self, text: str, role: str, turn_index: int) -> Optional[LedgerFact]:
        """Return at most one fact for *text*: a server fact if present, else a keyed snippet."""
        server = _SERVER_RE.search(text)
        if server:
            return LedgerFact("server", server.group(1), role, turn_index)
        if not _FACT_RE.search(text):
            return None
        cleaned = self._snippet(text, limit=180)
        key = self._fact_key(cleaned)
        return LedgerFact(key, cleaned, role, turn_index)

    @staticmethod
    def _fact_key(text: str) -> str:
        """Derive a grouping key from up to six words before the first separator."""
        lower = text.lower()
        # Separator words need \b. Without it "this", "firmware" and "causes"
        # were split mid-word, producing keys like "th" that made unrelated
        # facts supersede each other.
        before_sep = re.split(
            r"\s*(?:\b(?:is|are|это|будет|uses|runs)\b|[=:])\s*", lower, maxsplit=1
        )[0]
        words = re.findall(r"[a-zа-яё0-9_-]+", before_sep)[:6]
        return " ".join(words) or "fact"

    @staticmethod
    def _string_content(content: Any) -> str:
        """Flatten message content (str, list of parts, or None) into plain text."""
        if content is None:
            # e.g. assistant tool-call messages with no text; avoid the literal "None".
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for item in content:
                if isinstance(item, dict):
                    if isinstance(item.get("text"), str):
                        parts.append(item["text"])
                    elif isinstance(item.get("content"), str):
                        parts.append(item["content"])
                elif isinstance(item, str):
                    parts.append(item)
            return "\n".join(parts)
        return str(content)

    def _credential_key(self) -> bytes:
        """Return the per-instance HMAC key, creating it if ``__init__`` was bypassed."""
        key = getattr(self, "_ref_key", None)
        if key is None:
            key = self._ref_key = secrets.token_bytes(32)
        return key

    def _sanitize(self, text: str) -> Tuple[str, list[str]]:
        """Redact likely secrets and IPv4 addresses from *text*.

        Returns the redacted text and the credential refs created, in
        replacement order. A ref is ``credential_ref:<kind>:<digest>`` where
        the digest is a keyed HMAC-SHA256 prefix. It is stable for this engine
        instance, so repeated secrets dedupe. Unlike a plain hash, the digest
        cannot be checked offline against guessed passwords.
        """
        refs: list[str] = []
        key = self._credential_key()

        def ref_for(raw: str, kind: str = "credential") -> str:
            digest = hmac.new(key, raw.encode("utf-8", "ignore"), hashlib.sha256).hexdigest()[:10]
            ref = f"credential_ref:{kind}:{digest}"
            refs.append(ref)
            return ref

        def replace_key_value(match: re.Match[str]) -> str:
            # Keep the key name, normalise the separator to "=", hide the value.
            return f"{match.group(1)}=<{ref_for(match.group(2))}>"

        sanitized = _KEY_VALUE_SECRET_RE.sub(replace_key_value, text)
        for pattern in _TOKEN_PATTERNS:
            sanitized = pattern.sub(lambda m: f"<{ref_for(m.group(0))}>", sanitized)
        sanitized = _IPV4_RE.sub("[REDACTED_IP]", sanitized)
        return sanitized, refs

    @staticmethod
    def _snippet(text: str, limit: int = 220) -> str:
        """Collapse whitespace and truncate to *limit* characters with an ellipsis."""
        compact = " ".join(text.split())
        if len(compact) <= limit:
            return compact
        return compact[: limit - 1].rstrip() + "…"

    @staticmethod
    def _dedupe(items: Iterable[str]) -> list[str]:
        """Drop repeated strings, keeping first-occurrence order."""
        seen: set[str] = set()
        result: list[str] = []
        for item in items:
            if item not in seen:
                seen.add(item)
                result.append(item)
        return result

    @classmethod
    def _render_summary(cls, ledger: dict[str, Any], focus_topic: Optional[str] = None) -> str:
        """Render *ledger* as the markdown text of the ledger system message.

        The first line is always ``_LEDGER_HEADER`` so a later ``compress``
        can recognise this message. Earlier ledgers, if any, are appended
        verbatim in a final section.
        """
        sections: list[str] = [
            cls._LEDGER_HEADER,
            f"Cold turns compacted: {ledger.get('cold_turns_compacted', 0)}.",
            "Hot tail messages after this block are preserved verbatim.",
        ]
        if focus_topic:
            sections.append(f"Compression focus: {focus_topic}")

        def add_section(title: str, lines: list[str]) -> None:
            sections.append(f"\n## {title}")
            sections.extend(lines or ["- None detected."])

        add_section("Active facts", [f.line() for f in ledger.get("active_facts", [])])
        add_section("Decisions", ledger.get("decisions", []))
        add_section("Obligations", ledger.get("obligations", []))
        add_section("Superseded facts", [f.line() for f in ledger.get("superseded_facts", [])])
        add_section("Unresolved questions", ledger.get("unresolved_questions", []))
        add_section("Credential refs", [f"- {ref}" for ref in ledger.get("credential_refs", [])])
        add_section("Retrieval notes", ledger.get("retrieval_notes", []))

        previous = ledger.get("previous_ledgers") or []
        if previous:
            sections.append("\n## Earlier ledgers (carried forward verbatim)")
            sections.extend(previous)
        return "\n".join(sections)


def register(ctx: Any) -> None:
    """Plugin entry point: hand a default-configured engine to *ctx*."""
    engine = SemanticRLEEngine()
    ctx.register_context_engine(engine)