feat(state.db): persist platform_message_id; restore yuanbao exact-id recall

PR #29211 dropped JSONL gateway transcripts and noted that the platform's
own `message_id` field (used by Yuanbao's recall guard to redact a
message by exact platform id) was no longer preserved — falling back to
content-match.  That fallback works for the common case but redacts the
wrong row when two messages share text (or fails to match when content
is post-processed).

Restore exact-id matching by giving state.db a column for it:

- New `platform_message_id TEXT` column on the messages table
  (SCHEMA_VERSION bump 11 → 12; column added via declarative reconciler
  on existing DBs, no version-gated migration block needed)
- Partial index `idx_messages_platform_msg_id` on
  (session_id, platform_message_id) to keep recall's point-lookup cheap
  even on large sessions
- `append_message()` and `replace_messages()` accept the new value:
  the gateway-facing `append_to_transcript` in `gateway/session.py`
  forwards either `message["platform_message_id"]` or the legacy
  `message["message_id"]` key (yuanbao's existing convention)
- `get_messages_as_conversation()` surfaces the column back on the
  message dict as `message_id` so platform code reads the same shape
  it used to read from JSONL
- Yuanbao `_patch_transcript`: restore branch A1 (exact id match)
  ahead of A2 (content match) ahead of B (system-note).  Both branches
  log which one fired so operators can tell from gateway.log whether
  recall hit the canonical path or had to fall back.

Tests:
- New low-level round-trip tests in `test_hermes_state.py` for both
  `append_message` and `replace_messages` paths
- The PR's `test_yuanbao_recall_db_only.py` was rewritten to assert
  the new contract: branch A1 (id match) works against DB-only
  transcripts, and branch A2 (content match) still recovers rows that
  were observed without a platform id (e.g. agent-processed @bot
  messages where run.py doesn't carry msg_id through)
This commit is contained in:
Teknium 2026-05-20 12:55:01 -07:00
parent 0cc1a1d2d9
commit 31a0100104
5 changed files with 185 additions and 38 deletions

View File

@ -1410,33 +1410,43 @@ class RecallGuardMiddleware(InboundMiddleware):
logger.warning("[%s] Recall: failed to resolve session: %s", adapter.name, exc) logger.warning("[%s] Recall: failed to resolve session: %s", adapter.name, exc)
return return
# Load transcript from canonical store (state.db). See Branch A below # Load transcript from canonical store (state.db). Since PR #29278
# for why we can no longer match by platform `message_id`. # added a ``platform_message_id`` column to the messages table and
# ``append_to_transcript`` wires the incoming dict's ``message_id``
# into it, ``load_transcript`` returns rows with ``message_id`` set
# for any message that was observed with one — Branch A1 (exact id
# match) is the canonical path again.
try: try:
transcript = store.load_transcript(sid) transcript = store.load_transcript(sid)
except Exception as exc: except Exception as exc:
logger.warning("[%s] Recall: failed to load transcript: %s", adapter.name, exc) logger.warning("[%s] Recall: failed to load transcript: %s", adapter.name, exc)
return return
# Branch A: content-match redaction. state.db does NOT preserve the # Branch A1: exact platform message_id match. Authoritative when the
# platform `message_id` (only its own autoincrement primary key), so we # row was persisted with a platform_message_id (observed group
# cannot redact by exact id. Match by content instead. Most yuanbao # messages and any inbound message whose adapter carried a msg_id).
# recalls carry the recalled text via `recalled_content`, which is
# sufficient for any non-duplicate message.
#
# TODO: add a `platform_message_id` column to state.db messages to
# restore exact-id matching. Tracked separately.
target = None target = None
if recalled_content: branch_label = ""
for entry in transcript:
if entry.get("message_id") == recalled_id:
target = entry
branch_label = "branch A1: id match"
break
# Branch A2: content-match fallback for messages that lack an exact
# platform id on the row — e.g. agent-processed @bot messages
# (run.py doesn't carry msg_id through) or older rows persisted
# before the platform_message_id column existed.
if target is None and recalled_content:
for entry in transcript: for entry in transcript:
if entry.get("role") == "user" and entry.get("content") == recalled_content: if entry.get("role") == "user" and entry.get("content") == recalled_content:
target = entry target = entry
branch_label = "branch A2: content match"
break break
if target is not None: if target is not None:
target["content"] = cls._REDACTED target["content"] = cls._REDACTED
try: try:
store.rewrite_transcript(sid, transcript) store.rewrite_transcript(sid, transcript)
logger.info("[%s] Recall: redacted msg_id=%s (branch A: content match)", adapter.name, recalled_id) logger.info("[%s] Recall: redacted msg_id=%s (%s)", adapter.name, recalled_id, branch_label)
except Exception as exc: except Exception as exc:
logger.warning("[%s] Recall: rewrite_transcript failed: %s", adapter.name, exc) logger.warning("[%s] Recall: rewrite_transcript failed: %s", adapter.name, exc)
return return

View File

@ -1271,6 +1271,12 @@ class SessionStore:
reasoning_details=message.get("reasoning_details") if message.get("role") == "assistant" else None, reasoning_details=message.get("reasoning_details") if message.get("role") == "assistant" else None,
codex_reasoning_items=message.get("codex_reasoning_items") if message.get("role") == "assistant" else None, codex_reasoning_items=message.get("codex_reasoning_items") if message.get("role") == "assistant" else None,
codex_message_items=message.get("codex_message_items") if message.get("role") == "assistant" else None, codex_message_items=message.get("codex_message_items") if message.get("role") == "assistant" else None,
# Platform-side message id (yuanbao msg_id, telegram update_id, …).
# Accept either explicit ``platform_message_id`` or the legacy
# ``message_id`` key the JSONL transcript used.
platform_message_id=(
message.get("platform_message_id") or message.get("message_id")
),
) )
except Exception as e: except Exception as e:
logger.debug("Session DB operation failed: %s", e) logger.debug("Session DB operation failed: %s", e)

View File

@ -33,7 +33,7 @@ T = TypeVar("T")
DEFAULT_DB_PATH = get_hermes_home() / "state.db" DEFAULT_DB_PATH = get_hermes_home() / "state.db"
SCHEMA_VERSION = 11 SCHEMA_VERSION = 12
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# WAL-compatibility fallback # WAL-compatibility fallback
@ -236,7 +236,8 @@ CREATE TABLE IF NOT EXISTS messages (
reasoning_content TEXT, reasoning_content TEXT,
reasoning_details TEXT, reasoning_details TEXT,
codex_reasoning_items TEXT, codex_reasoning_items TEXT,
codex_message_items TEXT codex_message_items TEXT,
platform_message_id TEXT
); );
CREATE TABLE IF NOT EXISTS state_meta ( CREATE TABLE IF NOT EXISTS state_meta (
@ -571,6 +572,19 @@ class SessionDB:
# column gets created here. # column gets created here.
self._reconcile_columns(cursor) self._reconcile_columns(cursor)
# Indexes that reference reconciler-added columns must be created
# AFTER _reconcile_columns runs — declaring them in SCHEMA_SQL
# makes the initial executescript fail on legacy DBs (the index's
# WHERE clause references a column that doesn't exist yet).
try:
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_messages_platform_msg_id "
"ON messages(session_id, platform_message_id) "
"WHERE platform_message_id IS NOT NULL"
)
except sqlite3.OperationalError as exc:
logger.debug("idx_messages_platform_msg_id create skipped: %s", exc)
# ── Schema version bookkeeping ───────────────────────────────── # ── Schema version bookkeeping ─────────────────────────────────
# Bump to current so future data migrations (if any) can gate on # Bump to current so future data migrations (if any) can gate on
# version. No version-gated column additions remain. # version. No version-gated column additions remain.
@ -1445,12 +1459,19 @@ class SessionDB:
reasoning_details: Any = None, reasoning_details: Any = None,
codex_reasoning_items: Any = None, codex_reasoning_items: Any = None,
codex_message_items: Any = None, codex_message_items: Any = None,
platform_message_id: str = None,
) -> int: ) -> int:
""" """
Append a message to a session. Returns the message row ID. Append a message to a session. Returns the message row ID.
Also increments the session's message_count (and tool_call_count Also increments the session's message_count (and tool_call_count
if role is 'tool' or tool_calls is present). if role is 'tool' or tool_calls is present).
``platform_message_id`` is the external messaging platform's own
message ID (e.g. Telegram update_id, Yuanbao msg_id). It is
independent of the SQLite autoincrement primary key and is used by
platform-specific flows like yuanbao's recall guard to redact a
message by its platform-side identifier.
""" """
# Serialize structured fields to JSON before entering the write txn # Serialize structured fields to JSON before entering the write txn
reasoning_details_json = ( reasoning_details_json = (
@ -1480,8 +1501,8 @@ class SessionDB:
"""INSERT INTO messages (session_id, role, content, tool_call_id, """INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason, tool_calls, tool_name, timestamp, token_count, finish_reason,
reasoning, reasoning_content, reasoning_details, codex_reasoning_items, reasoning, reasoning_content, reasoning_details, codex_reasoning_items,
codex_message_items) codex_message_items, platform_message_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
( (
session_id, session_id,
role, role,
@ -1497,6 +1518,7 @@ class SessionDB:
reasoning_details_json, reasoning_details_json,
codex_items_json, codex_items_json,
codex_message_items_json, codex_message_items_json,
platform_message_id,
), ),
) )
msg_id = cursor.lastrowid msg_id = cursor.lastrowid
@ -1558,13 +1580,18 @@ class SessionDB:
json.dumps(codex_message_items) if codex_message_items else None json.dumps(codex_message_items) if codex_message_items else None
) )
tool_calls_json = json.dumps(tool_calls) if tool_calls else None tool_calls_json = json.dumps(tool_calls) if tool_calls else None
# Accept either `platform_message_id` (new explicit name) or
# `message_id` (yuanbao's existing convention on message dicts).
platform_msg_id = (
msg.get("platform_message_id") or msg.get("message_id")
)
conn.execute( conn.execute(
"""INSERT INTO messages (session_id, role, content, tool_call_id, """INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason, tool_calls, tool_name, timestamp, token_count, finish_reason,
reasoning, reasoning_content, reasoning_details, codex_reasoning_items, reasoning, reasoning_content, reasoning_details, codex_reasoning_items,
codex_message_items) codex_message_items, platform_message_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
( (
session_id, session_id,
role, role,
@ -1580,6 +1607,7 @@ class SessionDB:
reasoning_details_json, reasoning_details_json,
codex_items_json, codex_items_json,
codex_message_items_json, codex_message_items_json,
platform_msg_id,
), ),
) )
total_messages += 1 total_messages += 1
@ -1897,7 +1925,7 @@ class SessionDB:
rows = self._conn.execute( rows = self._conn.execute(
"SELECT role, content, tool_call_id, tool_calls, tool_name, " "SELECT role, content, tool_call_id, tool_calls, tool_name, "
"finish_reason, reasoning, reasoning_content, reasoning_details, " "finish_reason, reasoning, reasoning_content, reasoning_details, "
"codex_reasoning_items, codex_message_items " "codex_reasoning_items, codex_message_items, platform_message_id "
f"FROM messages WHERE session_id IN ({placeholders}) ORDER BY id", f"FROM messages WHERE session_id IN ({placeholders}) ORDER BY id",
tuple(session_ids), tuple(session_ids),
).fetchall() ).fetchall()
@ -1918,6 +1946,13 @@ class SessionDB:
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []") logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []")
msg["tool_calls"] = [] msg["tool_calls"] = []
# Surface the platform-side message id (e.g. yuanbao msg_id,
# telegram update_id) so platform-specific flows like recall
# can match by external identifier instead of having to fall
# back to content-match heuristics. Exposed as ``message_id``
# for backward compatibility with the JSONL transcript shape.
if row["platform_message_id"]:
msg["message_id"] = row["platform_message_id"]
# Restore reasoning fields on assistant messages so providers # Restore reasoning fields on assistant messages so providers
# that replay reasoning (OpenRouter, OpenAI, Nous) receive # that replay reasoning (OpenRouter, OpenAI, Nous) receive
# coherent multi-turn reasoning context. # coherent multi-turn reasoning context.

View File

@ -1,31 +1,88 @@
"""Yuanbao recall: branch A (content-match) works against DB-only transcripts.""" """Yuanbao recall: branch A1 (exact id) and A2 (content-match) against DB-only transcripts.
state.db persists the platform-side ``message_id`` via the
``platform_message_id`` column (added in the salvage of PR #29211) and
``load_transcript`` surfaces it back on each message dict as ``message_id``
so the recall guard's exact-id match path stays canonical even with the
JSONL file gone. When a row has no platform id (e.g. agent-processed
@bot messages whose adapter didn't carry a msg_id, or pre-column legacy
rows), recall falls through to content-match.
"""
from gateway.session import SessionStore from gateway.session import SessionStore
from gateway.config import GatewayConfig from gateway.config import GatewayConfig
def test_recall_content_match_finds_target_in_db_transcript(tmp_path, monkeypatch): def _pin_db(monkeypatch, tmp_path):
"""state.db doesn't preserve message_id, so recall uses content-match. """Force SessionDB() to write into tmp_path instead of the real ~/.hermes."""
Pin DEFAULT_DB_PATH to tmp_path so SessionDB() can't write to the real
~/.hermes/state.db. (Module-level constant snapshot, see test_load_transcript_db_only.)
"""
import hermes_state import hermes_state
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db") monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
def test_recall_branch_a1_exact_id_match_round_trips_through_db(tmp_path, monkeypatch):
"""A user message persisted with ``message_id`` must round-trip through
state.db so recall can find and redact it by exact id (branch A1)."""
_pin_db(monkeypatch, tmp_path)
config = GatewayConfig() config = GatewayConfig()
store = SessionStore(sessions_dir=tmp_path, config=config) store = SessionStore(sessions_dir=tmp_path, config=config)
sid = "test-yuanbao-recall" sid = "test-yuanbao-recall-a1"
store._db.create_session(session_id=sid, source="yuanbao:group:G") store._db.create_session(session_id=sid, source="yuanbao:group:G")
store.append_to_transcript(sid, {"role": "user", "content": "sensitive content", "timestamp": 1.0}) store.append_to_transcript(sid, {
store.append_to_transcript(sid, {"role": "assistant", "content": "ack", "timestamp": 2.0}) "role": "user",
"content": "sensitive content",
"timestamp": 1.0,
"message_id": "platform-msg-abc",
})
store.append_to_transcript(sid, {
"role": "assistant",
"content": "ack",
"timestamp": 2.0,
})
# DB-only history carries no platform message_id (PR #29211 dropped that path).
history = store.load_transcript(sid) history = store.load_transcript(sid)
assert all("message_id" not in msg for msg in history) # The user row must carry its platform id back so the recall guard can
# match by exact id; the assistant row had no platform id so it should
# not gain one spuriously.
user_msg = next(m for m in history if m["role"] == "user")
assistant_msg = next(m for m in history if m["role"] == "assistant")
assert user_msg.get("message_id") == "platform-msg-abc"
assert "message_id" not in assistant_msg
# Branch A: content match finds the target row that recall would redact. # Branch A1: locate the row by exact platform id — no content heuristics.
target = next((m for m in history target = next(
if m.get("role") == "user" and m.get("content") == "sensitive content"), None) (m for m in history if m.get("message_id") == "platform-msg-abc"),
None,
)
assert target is not None
assert target["content"] == "sensitive content"
def test_recall_branch_a2_content_match_when_no_platform_id(tmp_path, monkeypatch):
"""Rows that lack a platform_message_id (e.g. agent-processed @bot
messages) still match by content as a fallback."""
_pin_db(monkeypatch, tmp_path)
config = GatewayConfig()
store = SessionStore(sessions_dir=tmp_path, config=config)
sid = "test-yuanbao-recall-a2"
store._db.create_session(session_id=sid, source="yuanbao:group:G")
# No message_id on the dict — simulates an agent-processed message
# that did not carry the platform msg_id through.
store.append_to_transcript(sid, {
"role": "user",
"content": "sensitive content",
"timestamp": 1.0,
})
history = store.load_transcript(sid)
assert all("message_id" not in m for m in history)
# Branch A2: content match recovers the target.
target = next(
(m for m in history
if m.get("role") == "user" and m.get("content") == "sensitive content"),
None,
)
assert target is not None assert target is not None
# Caller would then redact: target["content"] = REDACTED; store.rewrite_transcript(sid, history)

View File

@ -316,6 +316,42 @@ class TestMessageStorage:
assert conv[0] == {"role": "user", "content": "Hello"} assert conv[0] == {"role": "user", "content": "Hello"}
assert conv[1] == {"role": "assistant", "content": "Hi!"} assert conv[1] == {"role": "assistant", "content": "Hi!"}
def test_platform_message_id_round_trips(self, db):
"""Platform-side message ids (yuanbao msg_id, telegram update_id, …)
survive append get_messages_as_conversation under the
``message_id`` key so platform recall flows can match by exact id."""
db.create_session(session_id="s_pmi", source="yuanbao")
db.append_message(
"s_pmi",
role="user",
content="hi",
platform_message_id="abc-123",
)
db.append_message("s_pmi", role="assistant", content="hello")
conv = db.get_messages_as_conversation("s_pmi")
user_msg = next(m for m in conv if m["role"] == "user")
assistant_msg = next(m for m in conv if m["role"] == "assistant")
assert user_msg.get("message_id") == "abc-123"
# Assistant row had no platform id — must not gain one spuriously.
assert "message_id" not in assistant_msg
def test_replace_messages_preserves_platform_message_id(self, db):
"""``rewrite_transcript`` (which goes through replace_messages) must
keep the platform_message_id round-trip working for /retry, /undo,
/compress and yuanbao's recall rewrite path."""
db.create_session(session_id="s_rep", source="yuanbao")
db.replace_messages(
"s_rep",
[
{"role": "user", "content": "x", "message_id": "ext-1"},
{"role": "assistant", "content": "y"},
],
)
conv = db.get_messages_as_conversation("s_rep")
assert next(m for m in conv if m["role"] == "user").get("message_id") == "ext-1"
assert "message_id" not in next(m for m in conv if m["role"] == "assistant")
def test_get_messages_as_conversation_includes_ancestor_chain(self, db): def test_get_messages_as_conversation_includes_ancestor_chain(self, db):
db.create_session("root", "tui") db.create_session("root", "tui")
db.append_message("root", role="user", content="first prompt") db.append_message("root", role="user", content="first prompt")
@ -1462,9 +1498,10 @@ class TestSchemaInit:
assert "schema_version" in tables assert "schema_version" in tables
def test_schema_version(self, db): def test_schema_version(self, db):
from hermes_state import SCHEMA_VERSION
cursor = db._conn.execute("SELECT version FROM schema_version") cursor = db._conn.execute("SELECT version FROM schema_version")
version = cursor.fetchone()[0] version = cursor.fetchone()[0]
assert version == 11 assert version == SCHEMA_VERSION
def test_title_column_exists(self, db): def test_title_column_exists(self, db):
"""Verify the title column was created in the sessions table.""" """Verify the title column was created in the sessions table."""
@ -1760,8 +1797,9 @@ class TestSchemaInit:
migrated_db = SessionDB(db_path=db_path) migrated_db = SessionDB(db_path=db_path)
# Verify migration # Verify migration
from hermes_state import SCHEMA_VERSION
cursor = migrated_db._conn.execute("SELECT version FROM schema_version") cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
assert cursor.fetchone()[0] == 11 assert cursor.fetchone()[0] == SCHEMA_VERSION
# Verify title column exists and is NULL for existing sessions # Verify title column exists and is NULL for existing sessions
session = migrated_db.get_session("existing") session = migrated_db.get_session("existing")
@ -2952,11 +2990,12 @@ class TestFTS5ToolCallMigration:
assert len(session_db.search_messages("LEGACYARG")) == 1, \ assert len(session_db.search_messages("LEGACYARG")) == 1, \
"v11 migration must backfill tool_calls JSON into FTS" "v11 migration must backfill tool_calls JSON into FTS"
# schema_version bumped # schema_version bumped
from hermes_state import SCHEMA_VERSION
row = session_db._conn.execute( row = session_db._conn.execute(
"SELECT version FROM schema_version LIMIT 1" "SELECT version FROM schema_version LIMIT 1"
).fetchone() ).fetchone()
version = row["version"] if hasattr(row, "keys") else row[0] version = row["version"] if hasattr(row, "keys") else row[0]
assert version == 11 assert version == SCHEMA_VERSION
finally: finally:
session_db.close() session_db.close()