From eec19ca22f67eeb3f2705d970ac7d59212d1bba8 Mon Sep 17 00:00:00 2001 From: Trung Nguyen Date: Tue, 19 May 2026 10:53:20 +0200 Subject: [PATCH 1/2] refactor(runtime): skip auto-compaction for non-token overflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Auto-compaction is only useful when the rejection is a token-count overflow — summarising older turns reduces the input token count. For wire-level overflow ([OverflowKindWire]) the request body itself exceeds the provider's cap, and the latest turn alone is over the limit; the compaction call would have to send the same oversized history and would also be rejected. For media overflow ([OverflowKindMedia]) we have no media-stripping during compaction today, so a retry would resend the same attachment and fail again. In both cases the recovery attempt always fails, then we surface the error anyway, while having spent an extra provider call and several seconds of wall-clock latency. This change skips compaction for those two kinds and surfaces the error directly. The token-overflow path is unchanged. --- pkg/runtime/loop_steps.go | 44 +++++++++++++++------ pkg/runtime/loop_steps_test.go | 72 ++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 12 deletions(-) diff --git a/pkg/runtime/loop_steps.go b/pkg/runtime/loop_steps.go index 7ab158a8b..740d50f27 100644 --- a/pkg/runtime/loop_steps.go +++ b/pkg/runtime/loop_steps.go @@ -161,22 +161,42 @@ func (r *LocalRuntime) handleStreamError( // request instead of surfacing raw errors. We allow at most // r.maxOverflowCompactions consecutive attempts to avoid an infinite // loop when compaction cannot reduce the context enough. + // + // Compaction only helps for token-count overflow ([OverflowKindTokens]): + // summarising older turns reduces the input token count. 
+ // + // For wire-level overflow ([OverflowKindWire]) the request body itself + // is over the provider's cap; the latest turn alone is too large and + // would still have to be sent during compaction. For media overflow + // ([OverflowKindMedia]) we have no media-stripping path today, so a + // retry would resend the same oversized attachment. In both cases the + // recovery attempt always fails, so we skip it and surface the error + // directly — the user can act on it (smaller paste, smaller file) + // faster, without burning a second provider call. if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction && *overflowCompactions < r.maxOverflowCompactions { - *overflowCompactions++ - slog.WarnContext(ctx, "Context window overflow detected, attempting auto-compaction", + kind := modelerrors.OverflowKindOf(err) + if kind == modelerrors.OverflowKindTokens { + *overflowCompactions++ + slog.WarnContext(ctx, "Context window overflow detected, attempting auto-compaction", + "agent", a.Name(), + "session_id", sess.ID, + "input_tokens", sess.InputTokens, + "output_tokens", sess.OutputTokens, + "context_limit", contextLimit, + "attempt", *overflowCompactions, + ) + events.Emit(Warning( + "The conversation has exceeded the model's context window. Automatically compacting the conversation history...", + a.Name(), + )) + r.compactWithReason(ctx, sess, "", compactionReasonOverflow, events) + return streamErrorRetry + } + slog.InfoContext(ctx, "Skipping auto-compaction for non-token overflow", "agent", a.Name(), "session_id", sess.ID, - "input_tokens", sess.InputTokens, - "output_tokens", sess.OutputTokens, - "context_limit", contextLimit, - "attempt", *overflowCompactions, + "overflow_kind", string(kind), ) - events.Emit(Warning( - "The conversation has exceeded the model's context window. 
Automatically compacting the conversation history...", - a.Name(), - )) - r.compactWithReason(ctx, sess, "", compactionReasonOverflow, events) - return streamErrorRetry } streamSpan.RecordError(err) diff --git a/pkg/runtime/loop_steps_test.go b/pkg/runtime/loop_steps_test.go index 8c60a2c25..f2115acb6 100644 --- a/pkg/runtime/loop_steps_test.go +++ b/pkg/runtime/loop_steps_test.go @@ -292,3 +292,75 @@ func TestHandleStreamError_GenericError_FatalAndEmitsError(t *testing.T) { } assert.True(t, sawError, "generic error should emit ErrorEvent") } + +// TestHandleStreamError_WireOverflowSkipsCompaction verifies that wire-level +// overflow does not trigger auto-compaction. Compaction would resend the same +// oversized request that just got rejected, so it is guaranteed to fail; we +// surface the error directly instead. +func TestHandleStreamError_WireOverflowSkipsCompaction(t *testing.T) { + t.Parallel() + + rt, a := newTestRuntime(t) + sess := session.New() + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("HTTP 413: Payload Too Large"), + Kind: modelerrors.OverflowKindWire, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, a, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + + assert.Equal(t, streamErrorFatal, outcome, "wire overflow must not trigger compaction retry") + assert.Equal(t, 0, overflowCount, "wire overflow must not bump the compaction counter") + + got := drainEvents(events) + var sawError bool + var sawWarning bool + var errCode string + for _, ev := range got { + switch e := ev.(type) { + case *ErrorEvent: + sawError = true + errCode = e.Code + case *WarningEvent: + sawWarning = true + } + } + assert.True(t, sawError, "wire overflow should emit an ErrorEvent") + assert.False(t, sawWarning, "wire overflow should not emit the compaction warning") + assert.Equal(t, 
ErrorCodeRequestTooLarge, errCode, "ErrorEvent.Code should distinguish wire overflow") +} + +// TestHandleStreamError_MediaOverflowSkipsCompaction verifies the same skip +// behaviour for media-size rejections. Without media-stripping during +// compaction, the offending attachment would be resent and fail again. +func TestHandleStreamError_MediaOverflowSkipsCompaction(t *testing.T) { + t.Parallel() + + rt, a := newTestRuntime(t) + sess := session.New() + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("image exceeds 5 MB maximum"), + Kind: modelerrors.OverflowKindMedia, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, a, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + + assert.Equal(t, streamErrorFatal, outcome, "media overflow must not trigger compaction retry") + assert.Equal(t, 0, overflowCount, "media overflow must not bump the compaction counter") + + var errCode string + for _, ev := range drainEvents(events) { + if e, ok := ev.(*ErrorEvent); ok { + errCode = e.Code + } + } + assert.Equal(t, ErrorCodeMediaTooLarge, errCode, "ErrorEvent.Code should distinguish media overflow") +} From 693a6c1ccf29b6886a9d2798ba50b96f4af847bb Mon Sep 17 00:00:00 2001 From: Trung Nguyen Date: Tue, 19 May 2026 13:50:39 +0200 Subject: [PATCH 2/2] feat(runtime): scrub oversized user message after wire/media overflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the provider rejects a request because the body itself is over the wire-size cap or contains an oversized attachment, the offending user message stays verbatim in the session. Every subsequent call reloads that message as part of the conversation history and trips the same limit. The session is effectively dead until the user starts over. 
Add a hygiene step that runs on wire- and media-overflow rejections: walk back to the latest user message, replace each media part (image, file, document) with a text placeholder that records what was attached, and replace plain-text content over 1 MiB with a size-noting placeholder. The rewrite is in-memory only: it keeps the current process healthy, but the session store row is not updated, so a restart mid-session reloads the original oversized payload. Mirroring the rewrite to the store is tracked as a separate change. A Warning event is emitted so the UI can tell the user that their previous message was rewritten in place. The fatal ErrorEvent for the original rejection is still emitted — scrubbing is in addition to surfacing the error, not instead of it. Token-overflow is unchanged: it still goes through auto-compaction, which is the correct mechanism for that shape of failure. --- pkg/runtime/loop_steps.go | 50 ++-- pkg/runtime/overflow_recovery.go | 289 +++++++++++++++++++ pkg/runtime/overflow_recovery_test.go | 390 ++++++++++++++++++++++++++ pkg/session/session.go | 49 ++++ 4 files changed, 756 insertions(+), 22 deletions(-) create mode 100644 pkg/runtime/overflow_recovery.go create mode 100644 pkg/runtime/overflow_recovery_test.go diff --git a/pkg/runtime/loop_steps.go b/pkg/runtime/loop_steps.go index 740d50f27..7be44cf93 100644 --- a/pkg/runtime/loop_steps.go +++ b/pkg/runtime/loop_steps.go @@ -156,26 +156,28 @@ func (r *LocalRuntime) handleStreamError( return streamErrorFatal } - // Auto-recovery: if the error is a context overflow and session - // compaction is enabled, compact the conversation and retry the - // request instead of surfacing raw errors. We allow at most - // r.maxOverflowCompactions consecutive attempts to avoid an infinite - // loop when compaction cannot reduce the context enough. + // Overflow handling has two independent concerns: // - // Compaction only helps for token-count overflow ([OverflowKindTokens]): - // summarising older turns reduces the input token count. + // 1. 
Auto-compaction (token overflow only): summarise older + // turns to fit the context window, then retry. Gated by + // r.sessionCompaction and the per-run attempt cap. // - // For wire-level overflow ([OverflowKindWire]) the request body itself - // is over the provider's cap; the latest turn alone is too large and - // would still have to be sent during compaction. For media overflow - // ([OverflowKindMedia]) we have no media-stripping path today, so a - // retry would resend the same oversized attachment. In both cases the - // recovery attempt always fails, so we skip it and surface the error - // directly — the user can act on it (smaller paste, smaller file) - // faster, without burning a second provider call. - if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction && *overflowCompactions < r.maxOverflowCompactions { + // 2. Session hygiene (wire/media overflow): rewrite the + // offending user message so the same oversized payload + // cannot reload on the next call and re-poison the session. + // Always runs when the kind warrants it, independent of + // the compaction config — the hygiene step does not retry + // and is correct even when compaction is disabled. + if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok { kind := modelerrors.OverflowKindOf(err) - if kind == modelerrors.OverflowKindTokens { + + // Token overflow: compaction is the right recovery — older + // turns can be summarised to free up context. Wire/media do + // not benefit from compaction (the latest turn alone is + // over the cap; resending it during compaction would just + // fail again), so we fall through to the hygiene step + // below for those. 
+ if kind == modelerrors.OverflowKindTokens && r.sessionCompaction && *overflowCompactions < r.maxOverflowCompactions { *overflowCompactions++ slog.WarnContext(ctx, "Context window overflow detected, attempting auto-compaction", "agent", a.Name(), @@ -192,11 +194,15 @@ func (r *LocalRuntime) handleStreamError( r.compactWithReason(ctx, sess, "", compactionReasonOverflow, events) return streamErrorRetry } - slog.InfoContext(ctx, "Skipping auto-compaction for non-token overflow", - "agent", a.Name(), - "session_id", sess.ID, - "overflow_kind", string(kind), - ) + + // Hygiene scrub for wire/media overflow. Runs independently + // of r.sessionCompaction: this rewrites a single message in + // place, it does not retry, and the same-process + // session-poisoning bug it fixes occurs regardless of + // whether the user opted into auto-compaction. + if kind == modelerrors.OverflowKindWire || kind == modelerrors.OverflowKindMedia { + r.recoverFromOversizedTurn(ctx, sess, kind, events) + } } streamSpan.RecordError(err) diff --git a/pkg/runtime/overflow_recovery.go b/pkg/runtime/overflow_recovery.go new file mode 100644 index 000000000..97475083a --- /dev/null +++ b/pkg/runtime/overflow_recovery.go @@ -0,0 +1,289 @@ +package runtime + +import ( + "context" + "fmt" + "log/slog" + "path/filepath" + "strings" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/session" +) + +// maxScrubbedTextBytes is the threshold at and below which a user +// message's plain-text content is preserved verbatim during scrubbing. +// Above the threshold the content is replaced with a placeholder that +// records the original size — keeping it verbatim would re-poison the +// session on the next turn by carrying the offending payload back into +// the provider request. +// +// The value is intentionally well below any major provider's wire-size +// cap. Smaller payloads pass through unchanged. 
+const maxScrubbedTextBytes = 1 << 20 // 1 MiB + +// scrubReport summarises what scrubMessage rewrote, for observability. +// All counters are zero on a no-op scrub. +type scrubReport struct { + // textReplaced is true when [chat.Message.Content] was over + // [maxScrubbedTextBytes] and has been replaced. + textReplaced bool + // originalBytes is the size of the original plain text, set + // only when textReplaced is true. + originalBytes int64 + // partsReplaced counts how many MultiContent parts were + // rewritten: media parts (image/file/document) are always + // counted; oversized text parts are also counted when they + // exceeded [maxScrubbedTextBytes]. + partsReplaced int +} + +func (r scrubReport) didScrub() bool { + return r.textReplaced || r.partsReplaced > 0 +} + +// scrubMessage returns a copy of msg in which media parts (image_url, +// file, document) are replaced with text placeholders and oversized +// plain-text content is replaced with a size-noting placeholder. The +// returned scrubReport describes what changed; when it reports +// [scrubReport.didScrub] false the message is byte-identical to msg. +// +// scrubMessage is pure: it does not consult the session, the model, +// or any context. Callers decide *when* to apply it (post-failure +// recovery, manual cleanup, etc.); this function only describes +// *how*. +func scrubMessage(msg chat.Message) (chat.Message, scrubReport) { + var report scrubReport + + out := msg + + // Plain text: replace only when oversized so we don't lose the + // user's intent for normal-sized messages. + if len(out.Content) > maxScrubbedTextBytes { + report.textReplaced = true + report.originalBytes = int64(len(out.Content)) + out.Content = oversizedTextPlaceholder(int64(len(out.Content))) + } + + // Multi-content parts: rewrite each media part in place. 
+ if len(out.MultiContent) > 0 { + newParts := make([]chat.MessagePart, len(out.MultiContent)) + for i, part := range out.MultiContent { + rewritten, replaced := scrubMessagePart(part) + if replaced { + report.partsReplaced++ + } + newParts[i] = rewritten + } + out.MultiContent = newParts + } + + return out, report +} + +// scrubMessagePart replaces a single attachment part with a text +// placeholder. Returns the rewritten part and whether anything +// changed; small text parts and unrecognised types pass through +// unchanged. +// +// Text parts over [maxScrubbedTextBytes] are themselves rewritten to +// a size-noting placeholder — a single text part inside MultiContent +// can be just as poisoning as oversized [chat.Message.Content]. +// +// The placeholder describes what was attached (kind, name, size) +// without preserving any of the content, so the rewritten message +// can never re-trip the provider's media-size limits. +func scrubMessagePart(part chat.MessagePart) (chat.MessagePart, bool) { + switch part.Type { + case chat.MessagePartTypeText: + if int64(len(part.Text)) > maxScrubbedTextBytes { + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: oversizedTextPlaceholder(int64(len(part.Text))), + }, true + } + return part, false + + case chat.MessagePartTypeImageURL: + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: imagePlaceholder(part), + }, true + + case chat.MessagePartTypeFile: + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: filePlaceholder(part), + }, true + + case chat.MessagePartTypeDocument: + return chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: documentPlaceholder(part), + }, true + } + // Unknown part type — leave it alone rather than risk dropping + // data we don't recognise. 
+ return part, false +} + +func oversizedTextPlaceholder(originalBytes int64) string { + return fmt.Sprintf( + "[previous message was %s of text — too large for the AI provider; "+ + "content was removed from the session so the conversation can continue]", + humanByteSize(originalBytes), + ) +} + +func imagePlaceholder(part chat.MessagePart) string { + if part.ImageURL == nil { + return "[image attachment removed: too large for the AI provider]" + } + if name := imageDisplayName(part.ImageURL.URL); name != "" { + return fmt.Sprintf("[image %q removed: too large for the AI provider]", name) + } + return "[image attachment removed: too large for the AI provider]" +} + +func filePlaceholder(part chat.MessagePart) string { + if part.File == nil { + return "[file attachment removed: too large for the AI provider]" + } + name := part.File.Path + if name != "" { + name = filepath.Base(name) + } + if name == "" { + name = part.File.FileID + } + if name == "" { + return "[file attachment removed: too large for the AI provider]" + } + return fmt.Sprintf("[file %q removed: too large for the AI provider]", name) +} + +func documentPlaceholder(part chat.MessagePart) string { + if part.Document == nil { + return "[document attachment removed: too large for the AI provider]" + } + doc := part.Document + if doc.Name != "" && doc.Size > 0 { + return fmt.Sprintf("[document %q (%s) removed: too large for the AI provider]", + doc.Name, humanByteSize(doc.Size)) + } + if doc.Name != "" { + return fmt.Sprintf("[document %q removed: too large for the AI provider]", doc.Name) + } + if doc.MimeType != "" { + return fmt.Sprintf("[%s attachment removed: too large for the AI provider]", doc.MimeType) + } + return "[document attachment removed: too large for the AI provider]" +} + +// imageDisplayName extracts a short display name from an image URL. 
+// Returns "" for data: URIs (where the URL itself is the payload and +// not user-meaningful) and falls back to the URL path's basename +// otherwise. +func imageDisplayName(url string) string { + if url == "" || strings.HasPrefix(url, "data:") { + return "" + } + // Strip query / fragment. + if i := strings.IndexAny(url, "?#"); i >= 0 { + url = url[:i] + } + base := filepath.Base(url) + if base == "/" || base == "." { + return "" + } + return base +} + +// humanByteSize renders n bytes as a short decimal string with binary +// units (KiB, MiB, GiB). Used for placeholder text only; precision is +// limited to one decimal place since this is informational. +func humanByteSize(n int64) string { + const ( + kib = 1 << 10 + mib = 1 << 20 + gib = 1 << 30 + ) + switch { + case n >= gib: + return fmt.Sprintf("%.1f GiB", float64(n)/float64(gib)) + case n >= mib: + return fmt.Sprintf("%.1f MiB", float64(n)/float64(mib)) + case n >= kib: + return fmt.Sprintf("%.1f KiB", float64(n)/float64(kib)) + } + return fmt.Sprintf("%d B", n) +} + +// recoverFromOversizedTurn rewrites the latest user message in sess so +// that the offending content (oversized text, media attachments) is +// neutralised. This is the runtime's in-memory hygiene step after a +// wire- or media-overflow rejection: without it, the same oversized +// turn re-sends on every subsequent call within this process and the +// conversation cannot continue. +// +// Scope: +// - In-memory only. The session store row is NOT updated; a +// docker-agent restart mid-session will reload the original +// oversized payload from disk. Mirroring the rewrite to the +// store requires propagating Message.ID from Store.AddMessage +// back into the in-memory session, which is an independent +// persistence-layer fix tracked as a separate change. +// - Only called for [modelerrors.OverflowKindWire] and +// [modelerrors.OverflowKindMedia]. Token overflow is handled by +// auto-compaction (a different mechanism). 
+// - Mutates only the most recent user message. Earlier turns are +// left alone — the heuristic is that the latest turn is the one +// that just tripped the provider; older turns must have been +// accepted at some point. +func (r *LocalRuntime) recoverFromOversizedTurn( + ctx context.Context, + sess *session.Session, + kind modelerrors.OverflowKind, + events EventSink, +) { + var report scrubReport + rewrote := sess.RewriteLatestUserMessage(func(msg chat.Message) (chat.Message, bool) { + rewritten, r := scrubMessage(msg) + if !r.didScrub() { + return msg, false + } + report = r + return rewritten, true + }) + if !rewrote { + // Nothing oversized to scrub (e.g. the offending content was + // already small, or the session has no user message yet). + return + } + + slog.InfoContext(ctx, "Scrubbed oversized user message after overflow", + "session_id", sess.ID, + "overflow_kind", string(kind), + "text_replaced", report.textReplaced, + "original_text_bytes", report.originalBytes, + "parts_replaced", report.partsReplaced, + ) + emitScrubNotice(events, report) +} + +// emitScrubNotice surfaces an informational warning so the UI can show +// the user that their message was rewritten in place. Without this the +// recovery is silent and the user sees only "Your message is too +// large" — they wouldn't know that the offending content has been +// removed from the conversation history. +func emitScrubNotice(events EventSink, report scrubReport) { + if events == nil || !report.didScrub() { + return + } + events.Emit(Warning( + "Your previous message was too large and has been rewritten in the "+ + "conversation history. 
Send a smaller message to continue.", + "", + )) +} diff --git a/pkg/runtime/overflow_recovery_test.go b/pkg/runtime/overflow_recovery_test.go new file mode 100644 index 000000000..3bad3ec46 --- /dev/null +++ b/pkg/runtime/overflow_recovery_test.go @@ -0,0 +1,390 @@ +package runtime + +import ( + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" +) + +// --- scrubMessage / scrubMessagePart unit tests --- + +func TestScrubMessage_TextBelowThreshold_PassesThrough(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + Content: strings.Repeat("a", 1024), + } + out, report := scrubMessage(msg) + assert.False(t, report.didScrub(), "small text must not be scrubbed") + assert.Equal(t, msg, out, "small text must pass through byte-identical") +} + +func TestScrubMessage_OversizedText_Replaced(t *testing.T) { + t.Parallel() + + original := strings.Repeat("z", maxScrubbedTextBytes+1) + msg := chat.Message{Role: chat.MessageRoleUser, Content: original} + + out, report := scrubMessage(msg) + assert.True(t, report.textReplaced) + assert.Equal(t, int64(len(original)), report.originalBytes) + assert.NotEqual(t, original, out.Content, "oversized text must be rewritten") + assert.Contains(t, out.Content, "too large", "placeholder must signal the cause") +} + +func TestScrubMessage_ImageURLPart_BecomesTextPlaceholder(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "look at this"}, + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "https://example.com/foo.png"}}, + }, + } + + out, 
report := scrubMessage(msg) + require.True(t, report.didScrub()) + assert.Equal(t, 1, report.partsReplaced) + require.Len(t, out.MultiContent, 2) + assert.Equal(t, chat.MessagePartTypeText, out.MultiContent[0].Type, "text part untouched") + assert.Equal(t, "look at this", out.MultiContent[0].Text) + assert.Equal(t, chat.MessagePartTypeText, out.MultiContent[1].Type, "media replaced by text") + assert.Contains(t, out.MultiContent[1].Text, "foo.png", "placeholder must include the name when available") +} + +func TestScrubMessage_DataURLImage_NameOmitted(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "data:image/png;base64,AAAA"}}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Equal(t, chat.MessagePartTypeText, out.MultiContent[0].Type) + // data: URI is not user-meaningful — the placeholder should not + // leak the base64 payload and should still describe what was + // removed. 
+ assert.NotContains(t, out.MultiContent[0].Text, "AAAA") + assert.Contains(t, out.MultiContent[0].Text, "image attachment") +} + +func TestScrubMessage_FilePart_BecomesPlaceholder(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeFile, File: &chat.MessageFile{Path: "/tmp/big.log"}}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Contains(t, out.MultiContent[0].Text, "big.log") +} + +func TestScrubMessage_DocumentPart_IncludesSize(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeDocument, Document: &chat.Document{ + Name: "report.pdf", MimeType: "application/pdf", Size: 3 * 1024 * 1024, + }}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Contains(t, out.MultiContent[0].Text, "report.pdf") + assert.Contains(t, out.MultiContent[0].Text, "MiB", "size should be human-readable") +} + +func TestScrubMessage_SmallTextPart_PassesThrough(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "just text"}, + }, + } + out, report := scrubMessage(msg) + assert.False(t, report.didScrub(), "small text in multi-content must not be scrubbed") + assert.Equal(t, msg, out) +} + +// TestScrubMessage_OversizedTextPart_Replaced verifies that an +// oversized text blob inside MultiContent is rewritten just like a +// top-level Content payload. A pure-text overflow can arrive as +// either, and the scrub must catch both shapes. 
+func TestScrubMessage_OversizedTextPart_Replaced(t *testing.T) { + t.Parallel() + + original := strings.Repeat("q", maxScrubbedTextBytes+1) + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeText, Text: "preserved preamble"}, + {Type: chat.MessagePartTypeText, Text: original}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.didScrub()) + assert.Equal(t, 1, report.partsReplaced) + require.Len(t, out.MultiContent, 2) + assert.Equal(t, "preserved preamble", out.MultiContent[0].Text, + "small text parts must pass through untouched") + assert.NotEqual(t, original, out.MultiContent[1].Text, + "oversized text part must be rewritten") + assert.Contains(t, out.MultiContent[1].Text, "too large") +} + +func TestScrubMessage_MultipleMediaParts_AllReplaced(t *testing.T) { + t.Parallel() + + msg := chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "https://a/1.png"}}, + {Type: chat.MessagePartTypeFile, File: &chat.MessageFile{Path: "/tmp/2.log"}}, + {Type: chat.MessagePartTypeDocument, Document: &chat.Document{Name: "3.pdf"}}, + }, + } + out, report := scrubMessage(msg) + assert.Equal(t, 3, report.partsReplaced) + for _, part := range out.MultiContent { + assert.Equal(t, chat.MessagePartTypeText, part.Type) + } +} + +func TestScrubMessage_TextAndMediaTogether(t *testing.T) { + t.Parallel() + + oversized := strings.Repeat("x", maxScrubbedTextBytes+512) + msg := chat.Message{ + Role: chat.MessageRoleUser, + Content: oversized, + MultiContent: []chat.MessagePart{ + {Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "https://a/x.png"}}, + }, + } + out, report := scrubMessage(msg) + assert.True(t, report.textReplaced) + assert.Equal(t, 1, report.partsReplaced) + assert.NotEqual(t, oversized, out.Content) + assert.Equal(t, chat.MessagePartTypeText, 
out.MultiContent[0].Type) +} + +// --- recoverFromOversizedTurn integration tests --- + +func TestRecoverFromOversizedTurn_NoUserMessage_NoOp(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + events := make(chan Event, 4) + + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + assert.Empty(t, drainEvents(events), "empty session should produce no events") +} + +func TestRecoverFromOversizedTurn_SmallMessage_NoOp(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + sess.AddMessage(session.UserMessage("hello")) + events := make(chan Event, 4) + + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + assert.Empty(t, drainEvents(events)) + assert.Equal(t, "hello", sess.GetLastUserMessageContent(), "small message stays verbatim") +} + +func TestRecoverFromOversizedTurn_OversizedText_Rewrites(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + original := strings.Repeat("Y", maxScrubbedTextBytes+1) + sess.AddMessage(session.UserMessage(original)) + events := make(chan Event, 4) + + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + rewritten := sess.GetLastUserMessageContent() + assert.NotEqual(t, original, rewritten, "oversized text should have been rewritten") + assert.Contains(t, rewritten, "too large") + + var sawWarning bool + for _, ev := range drainEvents(events) { + if _, ok := ev.(*WarningEvent); ok { + sawWarning = true + } + } + assert.True(t, sawWarning, "scrub should emit a Warning so the UI can inform the user") +} + +func TestRecoverFromOversizedTurn_OnlyLatestUserMessage(t *testing.T) { + t.Parallel() + + rt, _ := newTestRuntime(t) + sess := session.New() + + old := strings.Repeat("o", maxScrubbedTextBytes+1) + sess.AddMessage(session.UserMessage(old)) + // Subsequent assistant + user messages — the 
scrub must only + // touch the most recent user turn. + sess.AddMessage(&session.Message{Message: chat.Message{Role: chat.MessageRoleAssistant, Content: "ok"}}) + sess.AddMessage(session.UserMessage("short")) + + events := make(chan Event, 4) + rt.recoverFromOversizedTurn(t.Context(), sess, modelerrors.OverflowKindWire, NewChannelSink(events)) + + // The latest user message ("short") is small — nothing to scrub. + assert.Equal(t, "short", sess.GetLastUserMessageContent()) + // The OLDER oversized message must NOT have been touched. + all := sess.GetAllMessages() + require.GreaterOrEqual(t, len(all), 1) + assert.Equal(t, old, all[0].Message.Content, + "older user messages must not be scrubbed — only the latest is suspect") + assert.Empty(t, drainEvents(events)) +} + +// --- handleStreamError integration --- + +func TestHandleStreamError_WireOverflowScrubsSession(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model"} + root := agent.New("root", "test", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + + sess := session.New() + original := strings.Repeat("L", maxScrubbedTextBytes+10) + sess.AddMessage(session.UserMessage(original)) + + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("HTTP 413: Payload Too Large"), + Kind: modelerrors.OverflowKindWire, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, root, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + + assert.Equal(t, streamErrorFatal, outcome) + assert.NotEqual(t, original, sess.GetLastUserMessageContent(), + "wire overflow must scrub the offending user turn so future calls in this process don't re-fail") + + // The ErrorEvent for the rejection MUST still be emitted — + // scrubbing is in addition to the error, 
not instead of it. + var sawErrorEvent bool + for _, ev := range drainEvents(events) { + if e, ok := ev.(*ErrorEvent); ok { + sawErrorEvent = true + assert.Equal(t, ErrorCodeRequestTooLarge, e.Code, + "wire overflow still surfaces the request-too-large code") + } + } + assert.True(t, sawErrorEvent) +} + +// TestHandleStreamError_WireOverflowScrubsEvenWithCompactionDisabled +// pins that the hygiene scrub for wire/media overflow is independent +// of the session-compaction config. Compaction is irrelevant here — +// the scrub rewrites a single message and does not retry, and the +// in-process bug it fixes happens regardless of whether the user +// opted into auto-compaction. +func TestHandleStreamError_WireOverflowScrubsEvenWithCompactionDisabled(t *testing.T) { + t.Parallel() + + prov := &mockProvider{id: "test/mock-model"} + root := agent.New("root", "test", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + ) + require.NoError(t, err) + + sess := session.New() + original := strings.Repeat("W", maxScrubbedTextBytes+10) + sess.AddMessage(session.UserMessage(original)) + + events := make(chan Event, 16) + _, sp := noop.NewTracerProvider().Tracer("t").Start(t.Context(), "x") + + overflow := &modelerrors.ContextOverflowError{ + Underlying: errors.New("HTTP 413: Payload Too Large"), + Kind: modelerrors.OverflowKindWire, + } + overflowCount := 0 + + outcome := rt.handleStreamError(t.Context(), sess, root, overflow, 1000, &overflowCount, sp, NewChannelSink(events)) + assert.Equal(t, streamErrorFatal, outcome) + assert.NotEqual(t, original, sess.GetLastUserMessageContent(), + "scrub must run even when session compaction is disabled") +} + +// --- Session.RewriteLatestUserMessage contract --- + +func TestSessionRewriteLatestUserMessage_FindsMostRecentUser(t *testing.T) { + t.Parallel() + + sess := session.New() + sess.AddMessage(session.UserMessage("first")) 
+ sess.AddMessage(&session.Message{Message: chat.Message{Role: chat.MessageRoleAssistant, Content: "reply"}}) + sess.AddMessage(session.UserMessage("second")) + + var seen string + ok := sess.RewriteLatestUserMessage(func(m chat.Message) (chat.Message, bool) { + seen = m.Content + m.Content = "scrubbed" + return m, true + }) + assert.True(t, ok) + assert.Equal(t, "second", seen, "rewrite should target the latest user message") + assert.Equal(t, "scrubbed", sess.GetLastUserMessageContent()) +} + +func TestSessionRewriteLatestUserMessage_OptOut_DoesNotMutate(t *testing.T) { + t.Parallel() + + sess := session.New() + sess.AddMessage(session.UserMessage("keep")) + + ok := sess.RewriteLatestUserMessage(func(m chat.Message) (chat.Message, bool) { + return m, false + }) + assert.False(t, ok) + assert.Equal(t, "keep", sess.GetLastUserMessageContent()) +} + +func TestSessionRewriteLatestUserMessage_NoUserMessages(t *testing.T) { + t.Parallel() + + sess := session.New() + ok := sess.RewriteLatestUserMessage(func(m chat.Message) (chat.Message, bool) { + return m, true + }) + assert.False(t, ok) +} diff --git a/pkg/session/session.go b/pkg/session/session.go index e2f426827..49f09948e 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -431,6 +431,55 @@ func (s *Session) ApplyCompaction(inputTokens, outputTokens int64, item Item) { s.mu.Unlock() } +// RewriteLatestUserMessage atomically rewrites the most recent user +// message in s by passing its chat.Message to rewrite and replacing +// it with the returned value. The slot is found and updated under +// s.mu so concurrent readers (snapshotItems, persistence) cannot +// observe a torn state. +// +// rewrite is called at most once. If it returns false the message is +// left unchanged. The boolean return reports whether anything was +// rewritten — false when there is no user message in s or when +// rewrite opted out. 
+// +// This is the runtime's hook for in-place message hygiene after a +// failure that would otherwise poison the session — see the wire/ +// media overflow recovery in pkg/runtime. +// +// Scope: only items at the top level of s.Messages are considered. +// Sub-session items (Item.SubSession) are skipped; the function does +// not recurse into them. Callers that route user turns through a +// sub-session must rewrite the target message on the sub-session +// directly. The contract is intentional — sub-sessions own their +// own conversation transcript and the parent should not reach into +// them without going through their store. +// +// The rewrite is in-memory only. To mirror it to the session store +// (so it survives a docker-agent restart), the caller must follow up +// with a separate [Store.UpdateMessage] — which today requires a +// persistence ID that the runtime does not yet round-trip through +// [Store.AddMessage]. Closing that gap is a separate piece of work. +func (s *Session) RewriteLatestUserMessage(rewrite func(chat.Message) (chat.Message, bool)) bool { + s.mu.Lock() + defer s.mu.Unlock() + for i := range slices.Backward(s.Messages) { + item := &s.Messages[i] + if !item.IsMessage() { + continue + } + if item.Message.Message.Role != chat.MessageRoleUser { + continue + } + newMsg, ok := rewrite(item.Message.Message) + if !ok { + return false + } + item.Message.Message = newMsg + return true + } + return false +} + // AddSubSession adds a sub-session to the session func (s *Session) AddSubSession(subSession *Session) { s.mu.Lock()