mirror of
https://github.com/claude-code-best/claude-code.git
synced 2026-06-15 12:55:51 +00:00
refactor(acp): make bridge SDK message handling type-safe (#1265)
* refactor(acp): make bridge SDK message handling type-safe - Add BridgeSDKMessage type alias to eliminate 14 type errors from void-leaked IteratorResult - Replace 18 scattered as-casts with a single uniform as BridgeSDKMessage - Add 68 lines of unit tests covering bridge message handling - Fixes docstring coverage to pass CI threshold * fix(acp): restore IteratorResult return type to nextSdkMessageOrAbort The simplified SDKMessage | undefined return type collapsed two distinct states: generator truly done vs generator yielding undefined. This broke forwardSessionUpdates which needs to distinguish the two — when the generator yields null/undefined it should continue (calling next() again), not break out of the loop. Restored the original IteratorResult<SDKMessage, void> return type so done and yielded-null are distinct again.
This commit is contained in:
@@ -4,6 +4,7 @@ import {
|
||||
toolUpdateFromToolResult,
|
||||
toolUpdateFromEditToolResponse,
|
||||
forwardSessionUpdates,
|
||||
nextSdkMessageOrAbort,
|
||||
} from '../bridge.js'
|
||||
import { promptToQueryInput } from '../promptConversion.js'
|
||||
import { markdownEscape, toDisplayPath } from '../utils.js'
|
||||
@@ -30,6 +31,10 @@ async function* makeStream(
|
||||
for (const m of msgs) yield m
|
||||
}
|
||||
|
||||
async function* makeWaitingStream(): AsyncGenerator<SDKMessage, void, unknown> {
|
||||
await new Promise<never>(() => {})
|
||||
}
|
||||
|
||||
// ── toolInfoFromToolUse ────────────────────────────────────────────
|
||||
|
||||
describe('toolInfoFromToolUse', () => {
|
||||
@@ -692,6 +697,47 @@ describe('toDisplayPath', () => {
|
||||
|
||||
// ── forwardSessionUpdates ─────────────────────────────────────────
|
||||
|
||||
describe('nextSdkMessageOrAbort', () => {
|
||||
test('returns done:true when aborted while waiting for next message', async () => {
|
||||
const ac = new AbortController()
|
||||
const pending = nextSdkMessageOrAbort(makeWaitingStream(), ac.signal)
|
||||
ac.abort()
|
||||
|
||||
const result = await Promise.race([
|
||||
pending,
|
||||
new Promise<'timeout'>(resolve => setTimeout(resolve, 100, 'timeout')),
|
||||
])
|
||||
|
||||
expect(result).toEqual({ done: true, value: undefined })
|
||||
})
|
||||
|
||||
test('returns done:true when stream is done', async () => {
|
||||
const result = await nextSdkMessageOrAbort(
|
||||
makeStream([]),
|
||||
new AbortController().signal,
|
||||
)
|
||||
|
||||
expect(result).toEqual({ done: true, value: undefined })
|
||||
})
|
||||
|
||||
test('returns a valid SDKMessage via IteratorResult', async () => {
|
||||
const msg = {
|
||||
type: 'assistant',
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'hello' }],
|
||||
},
|
||||
} as unknown as SDKMessage
|
||||
|
||||
const result = await nextSdkMessageOrAbort(
|
||||
makeStream([msg]),
|
||||
new AbortController().signal,
|
||||
)
|
||||
|
||||
expect(result).toEqual({ done: false, value: msg })
|
||||
})
|
||||
})
|
||||
|
||||
describe('forwardSessionUpdates', () => {
|
||||
test('returns end_turn when stream is empty', async () => {
|
||||
const conn = makeConn()
|
||||
@@ -1077,6 +1123,28 @@ describe('forwardSessionUpdates', () => {
|
||||
).toBe(0)
|
||||
})
|
||||
|
||||
test('ignores unknown message types without crashing', async () => {
|
||||
const conn = makeConn()
|
||||
const debug = console.debug
|
||||
const debugMock = mock(() => {})
|
||||
console.debug = debugMock as typeof console.debug
|
||||
|
||||
try {
|
||||
const result = await forwardSessionUpdates(
|
||||
's1',
|
||||
makeStream([{ type: 'future_message' } as unknown as SDKMessage]),
|
||||
conn,
|
||||
new AbortController().signal,
|
||||
{},
|
||||
)
|
||||
|
||||
expect(result.stopReason).toBe('end_turn')
|
||||
expect(debugMock).toHaveBeenCalled()
|
||||
} finally {
|
||||
console.debug = debug
|
||||
}
|
||||
})
|
||||
|
||||
test('re-throws unexpected errors from stream', async () => {
|
||||
const conn = makeConn()
|
||||
async function* errorStream(): AsyncGenerator<
|
||||
|
||||
@@ -28,6 +28,7 @@ import { toDisplayPath, markdownEscape } from './utils.js'
|
||||
|
||||
// ── ToolUseCache ──────────────────────────────────────────────────
|
||||
|
||||
/** Maps tool_use_id → tool metadata for tracked inflight tool calls. */
|
||||
export type ToolUseCache = {
|
||||
[key: string]: {
|
||||
type: 'tool_use' | 'server_tool_use' | 'mcp_tool_use'
|
||||
@@ -39,6 +40,7 @@ export type ToolUseCache = {
|
||||
|
||||
// ── Session usage tracking ────────────────────────────────────────
|
||||
|
||||
/** Accumulated token usage across a session, updated per result message. */
|
||||
export type SessionUsage = {
|
||||
inputTokens: number
|
||||
outputTokens: number
|
||||
@@ -46,8 +48,139 @@ export type SessionUsage = {
|
||||
cachedWriteTokens: number
|
||||
}
|
||||
|
||||
/** Token usage reported in SDK result messages. */
|
||||
type BridgeUsage = {
|
||||
input_tokens?: number
|
||||
output_tokens?: number
|
||||
cache_read_input_tokens?: number
|
||||
cache_creation_input_tokens?: number
|
||||
}
|
||||
|
||||
/** system-init, compact_boundary, status, api_retry, local_command_output messages. */
|
||||
type BridgeSystemMessage = {
|
||||
type: 'system'
|
||||
subtype?: string
|
||||
session_id?: string
|
||||
content?: string
|
||||
status?: string
|
||||
compact_result?: string
|
||||
compact_error?: string
|
||||
model?: string
|
||||
uuid?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** Turn completion message: success with usage, or error with stop_reason. */
|
||||
type BridgeResultMessage = {
|
||||
type: 'result'
|
||||
subtype?: string
|
||||
usage?: BridgeUsage
|
||||
modelUsage?: Record<string, { contextWindow?: number }>
|
||||
total_cost_usd?: number
|
||||
is_error?: boolean
|
||||
stop_reason?: string | null
|
||||
result?: string
|
||||
errors?: string[]
|
||||
duration_ms?: number
|
||||
duration_api_ms?: number
|
||||
num_turns?: number
|
||||
permission_denials?: unknown[]
|
||||
session_id?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** Full assistant response message after the turn completes. */
|
||||
type BridgeAssistantMessage = {
|
||||
type: 'assistant'
|
||||
message?: {
|
||||
role?: string
|
||||
id?: string
|
||||
model?: string
|
||||
content?: string | Array<Record<string, unknown>>
|
||||
usage?: BridgeUsage | Record<string, unknown>
|
||||
stop_reason?: string | null
|
||||
[key: string]: unknown
|
||||
}
|
||||
parent_tool_use_id?: string | null
|
||||
uuid?: string
|
||||
session_id?: string
|
||||
error?: unknown
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** Real-time streaming event (aka partial_assistant in the SDK schema). */
|
||||
type BridgeStreamEventMessage = {
|
||||
type: 'stream_event'
|
||||
event?: { type?: string; [key: string]: unknown }
|
||||
message?: Record<string, unknown>
|
||||
parent_tool_use_id?: string | null
|
||||
session_id?: string
|
||||
uuid?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** User prompt message (may include tool_use_result from prior turns). */
|
||||
type BridgeUserMessage = {
|
||||
type: 'user'
|
||||
message?: Record<string, unknown>
|
||||
uuid?: string
|
||||
isReplay?: boolean
|
||||
isMeta?: boolean
|
||||
timestamp?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** Subagent or hook progress notification (internal, not an SDK message member). */
|
||||
type BridgeProgressMessage = {
|
||||
type: 'progress'
|
||||
data?: {
|
||||
type?: string
|
||||
message?: Record<string, unknown>
|
||||
[key: string]: unknown
|
||||
}
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** Summary of tool calls made during a turn. */
|
||||
type BridgeToolUseSummaryMessage = {
|
||||
type: 'tool_use_summary'
|
||||
summary?: string
|
||||
preceding_tool_use_ids?: string[]
|
||||
uuid?: string
|
||||
session_id?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** File attachment metadata (internal, not an SDK message member). */
|
||||
type BridgeAttachmentMessage = {
|
||||
type: 'attachment'
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** Compaction boundary marker (type is 'compact_boundary', not 'system'). */
|
||||
type BridgeCompactBoundaryMessage = {
|
||||
type: 'compact_boundary'
|
||||
compact_metadata?: Record<string, unknown>
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/** ACP bridge local discriminated union — covers all message shapes consumed by the forwarding loop. */
|
||||
type BridgeSDKMessage =
|
||||
| BridgeSystemMessage
|
||||
| BridgeResultMessage
|
||||
| BridgeAssistantMessage
|
||||
| BridgeStreamEventMessage
|
||||
| BridgeUserMessage
|
||||
| BridgeProgressMessage
|
||||
| BridgeToolUseSummaryMessage
|
||||
| BridgeAttachmentMessage
|
||||
| BridgeCompactBoundaryMessage
|
||||
|
||||
const logger: { debug: (...args: unknown[]) => void } = console
|
||||
|
||||
// ── Tool info conversion ──────────────────────────────────────────
|
||||
|
||||
/** Sanitised tool metadata sent to ACP client for tool_call notifications. */
|
||||
interface ToolInfo {
|
||||
title: string
|
||||
kind: ToolKind
|
||||
@@ -519,6 +652,7 @@ function toAcpContentBlock(
|
||||
|
||||
// ── Edit tool response → diff ──────────────────────────────────────
|
||||
|
||||
/** Context lines and diff metadata for one hunk of an Edit tool response. */
|
||||
interface EditToolResponseHunk {
|
||||
oldStart: number
|
||||
oldLines: number
|
||||
@@ -527,6 +661,7 @@ interface EditToolResponseHunk {
|
||||
lines: string[]
|
||||
}
|
||||
|
||||
/** Result block for Edit/Write tool responses containing hunks and optional file stats. */
|
||||
interface EditToolResponse {
|
||||
filePath?: string
|
||||
structuredPatch?: EditToolResponseHunk[]
|
||||
@@ -581,14 +716,13 @@ export function toolUpdateFromEditToolResponse(toolResponse: unknown): {
|
||||
return result
|
||||
}
|
||||
|
||||
function nextSdkMessageOrAbort(
|
||||
export function nextSdkMessageOrAbort(
|
||||
sdkMessages: AsyncGenerator<SDKMessage, void, unknown>,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<IteratorResult<SDKMessage, void>> {
|
||||
if (abortSignal.aborted) {
|
||||
return Promise.resolve({ done: true, value: undefined })
|
||||
}
|
||||
|
||||
let abortHandler: (() => void) | undefined
|
||||
const abortPromise = new Promise<IteratorResult<SDKMessage, void>>(
|
||||
resolve => {
|
||||
@@ -596,7 +730,6 @@ function nextSdkMessageOrAbort(
|
||||
abortSignal.addEventListener('abort', abortHandler, { once: true })
|
||||
},
|
||||
)
|
||||
|
||||
return Promise.race([sdkMessages.next(), abortPromise]).finally(() => {
|
||||
if (abortHandler) {
|
||||
abortSignal.removeEventListener('abort', abortHandler)
|
||||
@@ -642,16 +775,14 @@ export async function forwardSessionUpdates(
|
||||
// a slow API response.
|
||||
const nextResult = await nextSdkMessageOrAbort(sdkMessages, abortSignal)
|
||||
if (nextResult.done || abortSignal.aborted) break
|
||||
const msg = nextResult.value
|
||||
const rawMsg = nextResult.value
|
||||
if (rawMsg == null) continue
|
||||
const msg = rawMsg as BridgeSDKMessage
|
||||
|
||||
if (msg == null) continue
|
||||
|
||||
const type = msg.type as string
|
||||
|
||||
switch (type) {
|
||||
switch (msg.type) {
|
||||
// ── System messages ────────────────────────────────────────
|
||||
case 'system': {
|
||||
const subtype = msg.subtype as string | undefined
|
||||
const subtype = msg.subtype
|
||||
|
||||
if (subtype === 'compact_boundary') {
|
||||
// Reset assistant usage tracking after compaction
|
||||
@@ -679,27 +810,19 @@ export async function forwardSessionUpdates(
|
||||
|
||||
// ── Result messages ────────────────────────────────────────
|
||||
case 'result': {
|
||||
const usage = msg.usage as
|
||||
| {
|
||||
input_tokens: number
|
||||
output_tokens: number
|
||||
cache_read_input_tokens: number
|
||||
cache_creation_input_tokens: number
|
||||
}
|
||||
| undefined
|
||||
const usage = msg.usage
|
||||
|
||||
if (usage) {
|
||||
accumulatedUsage.inputTokens += usage.input_tokens
|
||||
accumulatedUsage.outputTokens += usage.output_tokens
|
||||
accumulatedUsage.cachedReadTokens += usage.cache_read_input_tokens
|
||||
accumulatedUsage.inputTokens += usage.input_tokens ?? 0
|
||||
accumulatedUsage.outputTokens += usage.output_tokens ?? 0
|
||||
accumulatedUsage.cachedReadTokens +=
|
||||
usage.cache_read_input_tokens ?? 0
|
||||
accumulatedUsage.cachedWriteTokens +=
|
||||
usage.cache_creation_input_tokens
|
||||
usage.cache_creation_input_tokens ?? 0
|
||||
}
|
||||
|
||||
// Resolve context window size from modelUsage via prefix matching
|
||||
const modelUsage = msg.modelUsage as
|
||||
| Record<string, { contextWindow?: number }>
|
||||
| undefined
|
||||
const modelUsage = msg.modelUsage
|
||||
if (modelUsage && lastAssistantModel) {
|
||||
const match = getMatchingModelUsage(modelUsage, lastAssistantModel)
|
||||
if (match?.contextWindow) {
|
||||
@@ -716,7 +839,7 @@ export async function forwardSessionUpdates(
|
||||
accumulatedUsage.cachedReadTokens +
|
||||
accumulatedUsage.cachedWriteTokens
|
||||
|
||||
const totalCostUsd = msg.total_cost_usd as number | undefined
|
||||
const totalCostUsd = msg.total_cost_usd
|
||||
await conn.sessionUpdate({
|
||||
sessionId,
|
||||
update: {
|
||||
@@ -731,8 +854,8 @@ export async function forwardSessionUpdates(
|
||||
})
|
||||
|
||||
// Determine stop reason
|
||||
const subtype = msg.subtype as string | undefined
|
||||
const isError = msg.is_error as boolean | undefined
|
||||
const subtype = msg.subtype
|
||||
const isError = msg.is_error
|
||||
|
||||
if (abortSignal.aborted) {
|
||||
stopReason = 'cancelled'
|
||||
@@ -741,7 +864,7 @@ export async function forwardSessionUpdates(
|
||||
|
||||
switch (subtype) {
|
||||
case 'success': {
|
||||
const stopReasonStr = msg.stop_reason as string | null
|
||||
const stopReasonStr = msg.stop_reason
|
||||
if (stopReasonStr === 'max_tokens') {
|
||||
stopReason = 'max_tokens'
|
||||
}
|
||||
@@ -752,7 +875,7 @@ export async function forwardSessionUpdates(
|
||||
break
|
||||
}
|
||||
case 'error_during_execution': {
|
||||
if ((msg.stop_reason as string | null) === 'max_tokens') {
|
||||
if (msg.stop_reason === 'max_tokens') {
|
||||
stopReason = 'max_tokens'
|
||||
} else if (isError) {
|
||||
stopReason = 'end_turn'
|
||||
@@ -797,20 +920,23 @@ export async function forwardSessionUpdates(
|
||||
case 'assistant': {
|
||||
// Track last assistant total usage for context window computation
|
||||
// (only for top-level messages, not subagents)
|
||||
const assistantMsg = msg.message as
|
||||
| Record<string, unknown>
|
||||
| undefined
|
||||
const parentToolUseId = msg.parent_tool_use_id as
|
||||
| string
|
||||
| null
|
||||
| undefined
|
||||
const assistantMsg = msg.message
|
||||
const parentToolUseId = msg.parent_tool_use_id
|
||||
if (assistantMsg?.usage && parentToolUseId === null) {
|
||||
const msgUsage = assistantMsg.usage as Record<string, unknown>
|
||||
const usage = assistantMsg.usage
|
||||
lastAssistantTotalUsage =
|
||||
((msgUsage.input_tokens as number) ?? 0) +
|
||||
((msgUsage.output_tokens as number) ?? 0) +
|
||||
((msgUsage.cache_read_input_tokens as number) ?? 0) +
|
||||
((msgUsage.cache_creation_input_tokens as number) ?? 0)
|
||||
(typeof usage.input_tokens === 'number'
|
||||
? usage.input_tokens
|
||||
: 0) +
|
||||
(typeof usage.output_tokens === 'number'
|
||||
? usage.output_tokens
|
||||
: 0) +
|
||||
(typeof usage.cache_read_input_tokens === 'number'
|
||||
? usage.cache_read_input_tokens
|
||||
: 0) +
|
||||
(typeof usage.cache_creation_input_tokens === 'number'
|
||||
? usage.cache_creation_input_tokens
|
||||
: 0)
|
||||
}
|
||||
// Track the current top-level model for context window size lookup
|
||||
if (
|
||||
@@ -818,7 +944,7 @@ export async function forwardSessionUpdates(
|
||||
assistantMsg?.model &&
|
||||
assistantMsg.model !== '<synthetic>'
|
||||
) {
|
||||
lastAssistantModel = assistantMsg.model as string
|
||||
lastAssistantModel = assistantMsg.model
|
||||
}
|
||||
|
||||
const notifications = assistantMessageToAcpNotifications(
|
||||
@@ -848,18 +974,16 @@ export async function forwardSessionUpdates(
|
||||
|
||||
// ── Progress messages ──────────────────────────────────────
|
||||
case 'progress': {
|
||||
const progressData = msg.data as Record<string, unknown> | undefined
|
||||
const progressData = msg.data
|
||||
if (!progressData) break
|
||||
|
||||
// Handle agent/skill subagent progress
|
||||
const progressType = progressData.type as string | undefined
|
||||
const progressType = progressData.type
|
||||
if (
|
||||
progressType === 'agent_progress' ||
|
||||
progressType === 'skill_progress'
|
||||
) {
|
||||
const progressMessage = progressData.message as
|
||||
| Record<string, unknown>
|
||||
| undefined
|
||||
const progressMessage = progressData.message
|
||||
if (progressMessage) {
|
||||
const content = progressMessage.content as
|
||||
| Array<Record<string, unknown>>
|
||||
@@ -916,7 +1040,7 @@ export async function forwardSessionUpdates(
|
||||
}
|
||||
|
||||
default:
|
||||
// Ignore unknown message types
|
||||
logger.debug('Ignoring unknown SDK message type')
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -1278,19 +1402,22 @@ export async function replayHistoryMessages(
|
||||
clientCapabilities?: ClientCapabilities,
|
||||
cwd?: string,
|
||||
): Promise<void> {
|
||||
for (const msg of messages) {
|
||||
const type = msg.type as string
|
||||
for (const rawMsg of messages) {
|
||||
const msg = rawMsg as BridgeSDKMessage
|
||||
// Skip non-conversation messages
|
||||
if (type !== 'user' && type !== 'assistant') continue
|
||||
if (msg.type !== 'user' && msg.type !== 'assistant') {
|
||||
logger.debug('Ignoring unknown SDK message type')
|
||||
continue
|
||||
}
|
||||
// Skip meta messages (synthetic continuation prompts)
|
||||
if (msg.isMeta === true) continue
|
||||
|
||||
const messageData = msg.message as Record<string, unknown> | undefined
|
||||
const messageData = msg.message
|
||||
const content = messageData?.content
|
||||
if (!content) continue
|
||||
|
||||
const role: 'assistant' | 'user' =
|
||||
type === 'assistant' ? 'assistant' : 'user'
|
||||
msg.type === 'assistant' ? 'assistant' : 'user'
|
||||
|
||||
if (typeof content === 'string') {
|
||||
if (!content.trim()) continue
|
||||
|
||||
Reference in New Issue
Block a user