From bee711f431b38e74b0a2cb014e52791270ffb38c Mon Sep 17 00:00:00 2001 From: James F <47167674+GhostDragon124@users.noreply.github.com> Date: Tue, 9 Jun 2026 21:49:05 +0800 Subject: [PATCH] refactor(acp): make bridge SDK message handling type-safe (#1265) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(acp): make bridge SDK message handling type-safe - Add BridgeSDKMessage type alias to eliminate 14 type errors from void-leaked IteratorResult - Replace 18 scattered as-casts with a single uniform as BridgeSDKMessage - Add 68 lines of unit tests covering bridge message handling - Fixes docstring coverage to pass CI threshold * fix(acp): restore IteratorResult return type to nextSdkMessageOrAbort The simplified SDKMessage | undefined return type collapsed two distinct states: generator truly done vs generator yielding undefined. This broke forwardSessionUpdates which needs to distinguish the two — when the generator yields null/undefined it should continue (calling next() again), not break out of the loop. Restored the original IteratorResult return type so done and yielded-null are distinct again. --- src/services/acp/__tests__/bridge.test.ts | 68 +++++++ src/services/acp/bridge.ts | 235 +++++++++++++++++----- 2 files changed, 249 insertions(+), 54 deletions(-) diff --git a/src/services/acp/__tests__/bridge.test.ts b/src/services/acp/__tests__/bridge.test.ts index 99900ff3e..b5a9cece2 100644 --- a/src/services/acp/__tests__/bridge.test.ts +++ b/src/services/acp/__tests__/bridge.test.ts @@ -4,6 +4,7 @@ import { toolUpdateFromToolResult, toolUpdateFromEditToolResponse, forwardSessionUpdates, + nextSdkMessageOrAbort, } from '../bridge.js' import { promptToQueryInput } from '../promptConversion.js' import { markdownEscape, toDisplayPath } from '../utils.js' @@ -30,6 +31,10 @@ async function* makeStream( for (const m of msgs) yield m } +async function* makeWaitingStream(): AsyncGenerator { + await new Promise(() => {}) +} + // ── toolInfoFromToolUse ──────────────────────────────────────────── describe('toolInfoFromToolUse', () => { @@ -692,6 +697,47 @@ describe('toDisplayPath', () => { // ── forwardSessionUpdates ───────────────────────────────────────── +describe('nextSdkMessageOrAbort', () => { + test('returns done:true when aborted while waiting for next message', async () => { + const ac = new AbortController() + const pending = nextSdkMessageOrAbort(makeWaitingStream(), ac.signal) + ac.abort() + + const result = await Promise.race([ + pending, + new Promise<'timeout'>(resolve => setTimeout(resolve, 100, 'timeout')), + ]) + + expect(result).toEqual({ done: true, value: undefined }) + }) + + test('returns done:true when stream is done', async () => { + const result = await nextSdkMessageOrAbort( + makeStream([]), + new AbortController().signal, + ) + + expect(result).toEqual({ done: true, value: undefined }) + }) + + test('returns a valid SDKMessage via IteratorResult', async () => { + const msg = { + type: 'assistant', + message: { + role: 'assistant', + content: [{ type: 'text', text: 'hello' }], + }, + } as unknown as SDKMessage + + const result = await nextSdkMessageOrAbort( + makeStream([msg]), + new AbortController().signal, + ) + + expect(result).toEqual({ done: false, value: msg }) + }) +}) + describe('forwardSessionUpdates', () => { test('returns end_turn when stream is empty', async () => { const conn = makeConn() @@ -1077,6 +1123,28 @@ describe('forwardSessionUpdates', () => { ).toBe(0) }) + test('ignores unknown message types without crashing', async () => { + const conn = makeConn() + const debug = console.debug + const debugMock = mock(() => {}) + console.debug = debugMock as typeof console.debug + + try { + const result = await forwardSessionUpdates( + 's1', + makeStream([{ type: 'future_message' } as unknown as SDKMessage]), + conn, + new AbortController().signal, + {}, + ) + + expect(result.stopReason).toBe('end_turn') + expect(debugMock).toHaveBeenCalled() + } finally { + console.debug = debug + } + }) + test('re-throws unexpected errors from stream', async () => { const conn = makeConn() async function* errorStream(): AsyncGenerator< diff --git a/src/services/acp/bridge.ts b/src/services/acp/bridge.ts index 58c6d284b..2c51c0f51 100644 --- a/src/services/acp/bridge.ts +++ b/src/services/acp/bridge.ts @@ -28,6 +28,7 @@ import { toDisplayPath, markdownEscape } from './utils.js' // ── ToolUseCache ────────────────────────────────────────────────── +/** Maps tool_use_id → tool metadata for tracked inflight tool calls. */ export type ToolUseCache = { [key: string]: { type: 'tool_use' | 'server_tool_use' | 'mcp_tool_use' @@ -39,6 +40,7 @@ export type ToolUseCache = { // ── Session usage tracking ──────────────────────────────────────── +/** Accumulated token usage across a session, updated per result message. */ export type SessionUsage = { inputTokens: number outputTokens: number @@ -46,8 +48,139 @@ export type SessionUsage = { cachedWriteTokens: number } +/** Token usage reported in SDK result messages. */ +type BridgeUsage = { + input_tokens?: number + output_tokens?: number + cache_read_input_tokens?: number + cache_creation_input_tokens?: number +} + +/** system-init, compact_boundary, status, api_retry, local_command_output messages. */ +type BridgeSystemMessage = { + type: 'system' + subtype?: string + session_id?: string + content?: string + status?: string + compact_result?: string + compact_error?: string + model?: string + uuid?: string + [key: string]: unknown +} + +/** Turn completion message: success with usage, or error with stop_reason. */ +type BridgeResultMessage = { + type: 'result' + subtype?: string + usage?: BridgeUsage + modelUsage?: Record + total_cost_usd?: number + is_error?: boolean + stop_reason?: string | null + result?: string + errors?: string[] + duration_ms?: number + duration_api_ms?: number + num_turns?: number + permission_denials?: unknown[] + session_id?: string + [key: string]: unknown +} + +/** Full assistant response message after the turn completes. */ +type BridgeAssistantMessage = { + type: 'assistant' + message?: { + role?: string + id?: string + model?: string + content?: string | Array> + usage?: BridgeUsage | Record + stop_reason?: string | null + [key: string]: unknown + } + parent_tool_use_id?: string | null + uuid?: string + session_id?: string + error?: unknown + [key: string]: unknown +} + +/** Real-time streaming event (aka partial_assistant in the SDK schema). */ +type BridgeStreamEventMessage = { + type: 'stream_event' + event?: { type?: string; [key: string]: unknown } + message?: Record + parent_tool_use_id?: string | null + session_id?: string + uuid?: string + [key: string]: unknown +} + +/** User prompt message (may include tool_use_result from prior turns). */ +type BridgeUserMessage = { + type: 'user' + message?: Record + uuid?: string + isReplay?: boolean + isMeta?: boolean + timestamp?: string + [key: string]: unknown +} + +/** Subagent or hook progress notification (internal, not an SDK message member). */ +type BridgeProgressMessage = { + type: 'progress' + data?: { + type?: string + message?: Record + [key: string]: unknown + } + [key: string]: unknown +} + +/** Summary of tool calls made during a turn. */ +type BridgeToolUseSummaryMessage = { + type: 'tool_use_summary' + summary?: string + preceding_tool_use_ids?: string[] + uuid?: string + session_id?: string + [key: string]: unknown +} + +/** File attachment metadata (internal, not an SDK message member). */ +type BridgeAttachmentMessage = { + type: 'attachment' + [key: string]: unknown +} + +/** Compaction boundary marker (type is 'compact_boundary', not 'system'). */ +type BridgeCompactBoundaryMessage = { + type: 'compact_boundary' + compact_metadata?: Record + [key: string]: unknown +} + +/** ACP bridge local discriminated union — covers all message shapes consumed by the forwarding loop. */ +type BridgeSDKMessage = + | BridgeSystemMessage + | BridgeResultMessage + | BridgeAssistantMessage + | BridgeStreamEventMessage + | BridgeUserMessage + | BridgeProgressMessage + | BridgeToolUseSummaryMessage + | BridgeAttachmentMessage + | BridgeCompactBoundaryMessage + +const logger: { debug: (...args: unknown[]) => void } = console + // ── Tool info conversion ────────────────────────────────────────── +/** Sanitised tool metadata sent to ACP client for tool_call notifications. */ interface ToolInfo { title: string kind: ToolKind @@ -519,6 +652,7 @@ function toAcpContentBlock( // ── Edit tool response → diff ────────────────────────────────────── +/** Context lines and diff metadata for one hunk of an Edit tool response. */ interface EditToolResponseHunk { oldStart: number oldLines: number @@ -527,6 +661,7 @@ interface EditToolResponseHunk { lines: string[] } +/** Result block for Edit/Write tool responses containing hunks and optional file stats. */ interface EditToolResponse { filePath?: string structuredPatch?: EditToolResponseHunk[] @@ -581,14 +716,13 @@ export function toolUpdateFromEditToolResponse(toolResponse: unknown): { return result } -function nextSdkMessageOrAbort( +export function nextSdkMessageOrAbort( sdkMessages: AsyncGenerator, abortSignal: AbortSignal, ): Promise> { if (abortSignal.aborted) { return Promise.resolve({ done: true, value: undefined }) } - let abortHandler: (() => void) | undefined const abortPromise = new Promise>( resolve => { @@ -596,7 +730,6 @@ function nextSdkMessageOrAbort( abortSignal.addEventListener('abort', abortHandler, { once: true }) }, ) - return Promise.race([sdkMessages.next(), abortPromise]).finally(() => { if (abortHandler) { abortSignal.removeEventListener('abort', abortHandler) @@ -642,16 +775,14 @@ export async function forwardSessionUpdates( // a slow API response. const nextResult = await nextSdkMessageOrAbort(sdkMessages, abortSignal) if (nextResult.done || abortSignal.aborted) break - const msg = nextResult.value + const rawMsg = nextResult.value + if (rawMsg == null) continue + const msg = rawMsg as BridgeSDKMessage - if (msg == null) continue - - const type = msg.type as string - - switch (type) { + switch (msg.type) { // ── System messages ──────────────────────────────────────── case 'system': { - const subtype = msg.subtype as string | undefined + const subtype = msg.subtype if (subtype === 'compact_boundary') { // Reset assistant usage tracking after compaction @@ -679,27 +810,19 @@ export async function forwardSessionUpdates( // ── Result messages ──────────────────────────────────────── case 'result': { - const usage = msg.usage as - | { - input_tokens: number - output_tokens: number - cache_read_input_tokens: number - cache_creation_input_tokens: number - } - | undefined + const usage = msg.usage if (usage) { - accumulatedUsage.inputTokens += usage.input_tokens - accumulatedUsage.outputTokens += usage.output_tokens - accumulatedUsage.cachedReadTokens += usage.cache_read_input_tokens + accumulatedUsage.inputTokens += usage.input_tokens ?? 0 + accumulatedUsage.outputTokens += usage.output_tokens ?? 0 + accumulatedUsage.cachedReadTokens += + usage.cache_read_input_tokens ?? 0 accumulatedUsage.cachedWriteTokens += - usage.cache_creation_input_tokens + usage.cache_creation_input_tokens ?? 0 } // Resolve context window size from modelUsage via prefix matching - const modelUsage = msg.modelUsage as - | Record - | undefined + const modelUsage = msg.modelUsage if (modelUsage && lastAssistantModel) { const match = getMatchingModelUsage(modelUsage, lastAssistantModel) if (match?.contextWindow) { @@ -716,7 +839,7 @@ export async function forwardSessionUpdates( accumulatedUsage.cachedReadTokens + accumulatedUsage.cachedWriteTokens - const totalCostUsd = msg.total_cost_usd as number | undefined + const totalCostUsd = msg.total_cost_usd await conn.sessionUpdate({ sessionId, update: { @@ -731,8 +854,8 @@ export async function forwardSessionUpdates( }) // Determine stop reason - const subtype = msg.subtype as string | undefined - const isError = msg.is_error as boolean | undefined + const subtype = msg.subtype + const isError = msg.is_error if (abortSignal.aborted) { stopReason = 'cancelled' @@ -741,7 +864,7 @@ export async function forwardSessionUpdates( switch (subtype) { case 'success': { - const stopReasonStr = msg.stop_reason as string | null + const stopReasonStr = msg.stop_reason if (stopReasonStr === 'max_tokens') { stopReason = 'max_tokens' } @@ -752,7 +875,7 @@ export async function forwardSessionUpdates( break } case 'error_during_execution': { - if ((msg.stop_reason as string | null) === 'max_tokens') { + if (msg.stop_reason === 'max_tokens') { stopReason = 'max_tokens' } else if (isError) { stopReason = 'end_turn' @@ -797,20 +920,23 @@ export async function forwardSessionUpdates( case 'assistant': { // Track last assistant total usage for context window computation // (only for top-level messages, not subagents) - const assistantMsg = msg.message as - | Record - | undefined - const parentToolUseId = msg.parent_tool_use_id as - | string - | null - | undefined + const assistantMsg = msg.message + const parentToolUseId = msg.parent_tool_use_id if (assistantMsg?.usage && parentToolUseId === null) { - const msgUsage = assistantMsg.usage as Record + const usage = assistantMsg.usage lastAssistantTotalUsage = - ((msgUsage.input_tokens as number) ?? 0) + - ((msgUsage.output_tokens as number) ?? 0) + - ((msgUsage.cache_read_input_tokens as number) ?? 0) + - ((msgUsage.cache_creation_input_tokens as number) ?? 0) + (typeof usage.input_tokens === 'number' + ? usage.input_tokens + : 0) + + (typeof usage.output_tokens === 'number' + ? usage.output_tokens + : 0) + + (typeof usage.cache_read_input_tokens === 'number' + ? usage.cache_read_input_tokens + : 0) + + (typeof usage.cache_creation_input_tokens === 'number' + ? usage.cache_creation_input_tokens + : 0) } // Track the current top-level model for context window size lookup if ( @@ -818,7 +944,7 @@ export async function forwardSessionUpdates( assistantMsg?.model && assistantMsg.model !== '' ) { - lastAssistantModel = assistantMsg.model as string + lastAssistantModel = assistantMsg.model } const notifications = assistantMessageToAcpNotifications( @@ -848,18 +974,16 @@ export async function forwardSessionUpdates( // ── Progress messages ────────────────────────────────────── case 'progress': { - const progressData = msg.data as Record | undefined + const progressData = msg.data if (!progressData) break // Handle agent/skill subagent progress - const progressType = progressData.type as string | undefined + const progressType = progressData.type if ( progressType === 'agent_progress' || progressType === 'skill_progress' ) { - const progressMessage = progressData.message as - | Record - | undefined + const progressMessage = progressData.message if (progressMessage) { const content = progressMessage.content as | Array> @@ -916,7 +1040,7 @@ export async function forwardSessionUpdates( } default: - // Ignore unknown message types + logger.debug('Ignoring unknown SDK message type') break } } @@ -1278,19 +1402,22 @@ export async function replayHistoryMessages( clientCapabilities?: ClientCapabilities, cwd?: string, ): Promise { - for (const msg of messages) { - const type = msg.type as string + for (const rawMsg of messages) { + const msg = rawMsg as BridgeSDKMessage // Skip non-conversation messages - if (type !== 'user' && type !== 'assistant') continue + if (msg.type !== 'user' && msg.type !== 'assistant') { + logger.debug('Ignoring unknown SDK message type') + continue + } // Skip meta messages (synthetic continuation prompts) if (msg.isMeta === true) continue - const messageData = msg.message as Record | undefined + const messageData = msg.message const content = messageData?.content if (!content) continue const role: 'assistant' | 'user' = - type === 'assistant' ? 'assistant' : 'user' + msg.type === 'assistant' ? 'assistant' : 'user' if (typeof content === 'string') { if (!content.trim()) continue