refactor(acp): make bridge SDK message handling type-safe (#1265)

* refactor(acp): make bridge SDK message handling type-safe

- Add BridgeSDKMessage type alias to eliminate 14 type errors from void-leaked IteratorResult
- Replace 18 scattered as-casts with a single uniform as BridgeSDKMessage
- Add 68 lines of unit tests covering bridge message handling
- Fixes docstring coverage to pass CI threshold

* fix(acp): restore IteratorResult return type to nextSdkMessageOrAbort

The simplified SDKMessage | undefined return type collapsed two distinct
states: generator truly done vs generator yielding undefined. This broke
forwardSessionUpdates which needs to distinguish the two — when the
generator yields null/undefined it should continue (calling next() again),
not break out of the loop.

Restored the original IteratorResult<SDKMessage, void> return type so
done and yielded-null are distinct again.
This commit is contained in:
James F
2026-06-09 21:49:05 +08:00
committed by GitHub
parent 4d930eb4eb
commit bee711f431
2 changed files with 249 additions and 54 deletions

View File

@@ -4,6 +4,7 @@ import {
toolUpdateFromToolResult,
toolUpdateFromEditToolResponse,
forwardSessionUpdates,
nextSdkMessageOrAbort,
} from '../bridge.js'
import { promptToQueryInput } from '../promptConversion.js'
import { markdownEscape, toDisplayPath } from '../utils.js'
@@ -30,6 +31,10 @@ async function* makeStream(
for (const m of msgs) yield m
}
async function* makeWaitingStream(): AsyncGenerator<SDKMessage, void, unknown> {
await new Promise<never>(() => {})
}
// ── toolInfoFromToolUse ────────────────────────────────────────────
describe('toolInfoFromToolUse', () => {
@@ -692,6 +697,47 @@ describe('toDisplayPath', () => {
// ── forwardSessionUpdates ─────────────────────────────────────────
describe('nextSdkMessageOrAbort', () => {
test('returns done:true when aborted while waiting for next message', async () => {
const ac = new AbortController()
const pending = nextSdkMessageOrAbort(makeWaitingStream(), ac.signal)
ac.abort()
const result = await Promise.race([
pending,
new Promise<'timeout'>(resolve => setTimeout(resolve, 100, 'timeout')),
])
expect(result).toEqual({ done: true, value: undefined })
})
test('returns done:true when stream is done', async () => {
const result = await nextSdkMessageOrAbort(
makeStream([]),
new AbortController().signal,
)
expect(result).toEqual({ done: true, value: undefined })
})
test('returns a valid SDKMessage via IteratorResult', async () => {
const msg = {
type: 'assistant',
message: {
role: 'assistant',
content: [{ type: 'text', text: 'hello' }],
},
} as unknown as SDKMessage
const result = await nextSdkMessageOrAbort(
makeStream([msg]),
new AbortController().signal,
)
expect(result).toEqual({ done: false, value: msg })
})
})
describe('forwardSessionUpdates', () => {
test('returns end_turn when stream is empty', async () => {
const conn = makeConn()
@@ -1077,6 +1123,28 @@ describe('forwardSessionUpdates', () => {
).toBe(0)
})
test('ignores unknown message types without crashing', async () => {
const conn = makeConn()
const debug = console.debug
const debugMock = mock(() => {})
console.debug = debugMock as typeof console.debug
try {
const result = await forwardSessionUpdates(
's1',
makeStream([{ type: 'future_message' } as unknown as SDKMessage]),
conn,
new AbortController().signal,
{},
)
expect(result.stopReason).toBe('end_turn')
expect(debugMock).toHaveBeenCalled()
} finally {
console.debug = debug
}
})
test('re-throws unexpected errors from stream', async () => {
const conn = makeConn()
async function* errorStream(): AsyncGenerator<

View File

@@ -28,6 +28,7 @@ import { toDisplayPath, markdownEscape } from './utils.js'
// ── ToolUseCache ──────────────────────────────────────────────────
/** Maps tool_use_id → tool metadata for tracked inflight tool calls. */
export type ToolUseCache = {
[key: string]: {
type: 'tool_use' | 'server_tool_use' | 'mcp_tool_use'
@@ -39,6 +40,7 @@ export type ToolUseCache = {
// ── Session usage tracking ────────────────────────────────────────
/** Accumulated token usage across a session, updated per result message. */
export type SessionUsage = {
inputTokens: number
outputTokens: number
@@ -46,8 +48,139 @@ export type SessionUsage = {
cachedWriteTokens: number
}
/** Token usage reported in SDK result messages. */
type BridgeUsage = {
input_tokens?: number
output_tokens?: number
cache_read_input_tokens?: number
cache_creation_input_tokens?: number
}
/** system-init, compact_boundary, status, api_retry, local_command_output messages. */
type BridgeSystemMessage = {
type: 'system'
subtype?: string
session_id?: string
content?: string
status?: string
compact_result?: string
compact_error?: string
model?: string
uuid?: string
[key: string]: unknown
}
/** Turn completion message: success with usage, or error with stop_reason. */
type BridgeResultMessage = {
type: 'result'
subtype?: string
usage?: BridgeUsage
modelUsage?: Record<string, { contextWindow?: number }>
total_cost_usd?: number
is_error?: boolean
stop_reason?: string | null
result?: string
errors?: string[]
duration_ms?: number
duration_api_ms?: number
num_turns?: number
permission_denials?: unknown[]
session_id?: string
[key: string]: unknown
}
/** Full assistant response message after the turn completes. */
type BridgeAssistantMessage = {
type: 'assistant'
message?: {
role?: string
id?: string
model?: string
content?: string | Array<Record<string, unknown>>
usage?: BridgeUsage | Record<string, unknown>
stop_reason?: string | null
[key: string]: unknown
}
parent_tool_use_id?: string | null
uuid?: string
session_id?: string
error?: unknown
[key: string]: unknown
}
/** Real-time streaming event (aka partial_assistant in the SDK schema). */
type BridgeStreamEventMessage = {
type: 'stream_event'
event?: { type?: string; [key: string]: unknown }
message?: Record<string, unknown>
parent_tool_use_id?: string | null
session_id?: string
uuid?: string
[key: string]: unknown
}
/** User prompt message (may include tool_use_result from prior turns). */
type BridgeUserMessage = {
type: 'user'
message?: Record<string, unknown>
uuid?: string
isReplay?: boolean
isMeta?: boolean
timestamp?: string
[key: string]: unknown
}
/** Subagent or hook progress notification (internal, not an SDK message member). */
type BridgeProgressMessage = {
type: 'progress'
data?: {
type?: string
message?: Record<string, unknown>
[key: string]: unknown
}
[key: string]: unknown
}
/** Summary of tool calls made during a turn. */
type BridgeToolUseSummaryMessage = {
type: 'tool_use_summary'
summary?: string
preceding_tool_use_ids?: string[]
uuid?: string
session_id?: string
[key: string]: unknown
}
/** File attachment metadata (internal, not an SDK message member). */
type BridgeAttachmentMessage = {
type: 'attachment'
[key: string]: unknown
}
/** Compaction boundary marker (type is 'compact_boundary', not 'system'). */
type BridgeCompactBoundaryMessage = {
type: 'compact_boundary'
compact_metadata?: Record<string, unknown>
[key: string]: unknown
}
/** ACP bridge local discriminated union — covers all message shapes consumed by the forwarding loop. */
type BridgeSDKMessage =
| BridgeSystemMessage
| BridgeResultMessage
| BridgeAssistantMessage
| BridgeStreamEventMessage
| BridgeUserMessage
| BridgeProgressMessage
| BridgeToolUseSummaryMessage
| BridgeAttachmentMessage
| BridgeCompactBoundaryMessage
const logger: { debug: (...args: unknown[]) => void } = console
// ── Tool info conversion ──────────────────────────────────────────
/** Sanitised tool metadata sent to ACP client for tool_call notifications. */
interface ToolInfo {
title: string
kind: ToolKind
@@ -519,6 +652,7 @@ function toAcpContentBlock(
// ── Edit tool response → diff ──────────────────────────────────────
/** Context lines and diff metadata for one hunk of an Edit tool response. */
interface EditToolResponseHunk {
oldStart: number
oldLines: number
@@ -527,6 +661,7 @@ interface EditToolResponseHunk {
lines: string[]
}
/** Result block for Edit/Write tool responses containing hunks and optional file stats. */
interface EditToolResponse {
filePath?: string
structuredPatch?: EditToolResponseHunk[]
@@ -581,14 +716,13 @@ export function toolUpdateFromEditToolResponse(toolResponse: unknown): {
return result
}
function nextSdkMessageOrAbort(
export function nextSdkMessageOrAbort(
sdkMessages: AsyncGenerator<SDKMessage, void, unknown>,
abortSignal: AbortSignal,
): Promise<IteratorResult<SDKMessage, void>> {
if (abortSignal.aborted) {
return Promise.resolve({ done: true, value: undefined })
}
let abortHandler: (() => void) | undefined
const abortPromise = new Promise<IteratorResult<SDKMessage, void>>(
resolve => {
@@ -596,7 +730,6 @@ function nextSdkMessageOrAbort(
abortSignal.addEventListener('abort', abortHandler, { once: true })
},
)
return Promise.race([sdkMessages.next(), abortPromise]).finally(() => {
if (abortHandler) {
abortSignal.removeEventListener('abort', abortHandler)
@@ -642,16 +775,14 @@ export async function forwardSessionUpdates(
// a slow API response.
const nextResult = await nextSdkMessageOrAbort(sdkMessages, abortSignal)
if (nextResult.done || abortSignal.aborted) break
const msg = nextResult.value
const rawMsg = nextResult.value
if (rawMsg == null) continue
const msg = rawMsg as BridgeSDKMessage
if (msg == null) continue
const type = msg.type as string
switch (type) {
switch (msg.type) {
// ── System messages ────────────────────────────────────────
case 'system': {
const subtype = msg.subtype as string | undefined
const subtype = msg.subtype
if (subtype === 'compact_boundary') {
// Reset assistant usage tracking after compaction
@@ -679,27 +810,19 @@ export async function forwardSessionUpdates(
// ── Result messages ────────────────────────────────────────
case 'result': {
const usage = msg.usage as
| {
input_tokens: number
output_tokens: number
cache_read_input_tokens: number
cache_creation_input_tokens: number
}
| undefined
const usage = msg.usage
if (usage) {
accumulatedUsage.inputTokens += usage.input_tokens
accumulatedUsage.outputTokens += usage.output_tokens
accumulatedUsage.cachedReadTokens += usage.cache_read_input_tokens
accumulatedUsage.inputTokens += usage.input_tokens ?? 0
accumulatedUsage.outputTokens += usage.output_tokens ?? 0
accumulatedUsage.cachedReadTokens +=
usage.cache_read_input_tokens ?? 0
accumulatedUsage.cachedWriteTokens +=
usage.cache_creation_input_tokens
usage.cache_creation_input_tokens ?? 0
}
// Resolve context window size from modelUsage via prefix matching
const modelUsage = msg.modelUsage as
| Record<string, { contextWindow?: number }>
| undefined
const modelUsage = msg.modelUsage
if (modelUsage && lastAssistantModel) {
const match = getMatchingModelUsage(modelUsage, lastAssistantModel)
if (match?.contextWindow) {
@@ -716,7 +839,7 @@ export async function forwardSessionUpdates(
accumulatedUsage.cachedReadTokens +
accumulatedUsage.cachedWriteTokens
const totalCostUsd = msg.total_cost_usd as number | undefined
const totalCostUsd = msg.total_cost_usd
await conn.sessionUpdate({
sessionId,
update: {
@@ -731,8 +854,8 @@ export async function forwardSessionUpdates(
})
// Determine stop reason
const subtype = msg.subtype as string | undefined
const isError = msg.is_error as boolean | undefined
const subtype = msg.subtype
const isError = msg.is_error
if (abortSignal.aborted) {
stopReason = 'cancelled'
@@ -741,7 +864,7 @@ export async function forwardSessionUpdates(
switch (subtype) {
case 'success': {
const stopReasonStr = msg.stop_reason as string | null
const stopReasonStr = msg.stop_reason
if (stopReasonStr === 'max_tokens') {
stopReason = 'max_tokens'
}
@@ -752,7 +875,7 @@ export async function forwardSessionUpdates(
break
}
case 'error_during_execution': {
if ((msg.stop_reason as string | null) === 'max_tokens') {
if (msg.stop_reason === 'max_tokens') {
stopReason = 'max_tokens'
} else if (isError) {
stopReason = 'end_turn'
@@ -797,20 +920,23 @@ export async function forwardSessionUpdates(
case 'assistant': {
// Track last assistant total usage for context window computation
// (only for top-level messages, not subagents)
const assistantMsg = msg.message as
| Record<string, unknown>
| undefined
const parentToolUseId = msg.parent_tool_use_id as
| string
| null
| undefined
const assistantMsg = msg.message
const parentToolUseId = msg.parent_tool_use_id
if (assistantMsg?.usage && parentToolUseId === null) {
const msgUsage = assistantMsg.usage as Record<string, unknown>
const usage = assistantMsg.usage
lastAssistantTotalUsage =
((msgUsage.input_tokens as number) ?? 0) +
((msgUsage.output_tokens as number) ?? 0) +
((msgUsage.cache_read_input_tokens as number) ?? 0) +
((msgUsage.cache_creation_input_tokens as number) ?? 0)
(typeof usage.input_tokens === 'number'
? usage.input_tokens
: 0) +
(typeof usage.output_tokens === 'number'
? usage.output_tokens
: 0) +
(typeof usage.cache_read_input_tokens === 'number'
? usage.cache_read_input_tokens
: 0) +
(typeof usage.cache_creation_input_tokens === 'number'
? usage.cache_creation_input_tokens
: 0)
}
// Track the current top-level model for context window size lookup
if (
@@ -818,7 +944,7 @@ export async function forwardSessionUpdates(
assistantMsg?.model &&
assistantMsg.model !== '<synthetic>'
) {
lastAssistantModel = assistantMsg.model as string
lastAssistantModel = assistantMsg.model
}
const notifications = assistantMessageToAcpNotifications(
@@ -848,18 +974,16 @@ export async function forwardSessionUpdates(
// ── Progress messages ──────────────────────────────────────
case 'progress': {
const progressData = msg.data as Record<string, unknown> | undefined
const progressData = msg.data
if (!progressData) break
// Handle agent/skill subagent progress
const progressType = progressData.type as string | undefined
const progressType = progressData.type
if (
progressType === 'agent_progress' ||
progressType === 'skill_progress'
) {
const progressMessage = progressData.message as
| Record<string, unknown>
| undefined
const progressMessage = progressData.message
if (progressMessage) {
const content = progressMessage.content as
| Array<Record<string, unknown>>
@@ -916,7 +1040,7 @@ export async function forwardSessionUpdates(
}
default:
// Ignore unknown message types
logger.debug('Ignoring unknown SDK message type')
break
}
}
@@ -1278,19 +1402,22 @@ export async function replayHistoryMessages(
clientCapabilities?: ClientCapabilities,
cwd?: string,
): Promise<void> {
for (const msg of messages) {
const type = msg.type as string
for (const rawMsg of messages) {
const msg = rawMsg as BridgeSDKMessage
// Skip non-conversation messages
if (type !== 'user' && type !== 'assistant') continue
if (msg.type !== 'user' && msg.type !== 'assistant') {
logger.debug('Ignoring unknown SDK message type')
continue
}
// Skip meta messages (synthetic continuation prompts)
if (msg.isMeta === true) continue
const messageData = msg.message as Record<string, unknown> | undefined
const messageData = msg.message
const content = messageData?.content
if (!content) continue
const role: 'assistant' | 'user' =
type === 'assistant' ? 'assistant' : 'user'
msg.type === 'assistant' ? 'assistant' : 'user'
if (typeof content === 'string') {
if (!content.trim()) continue