diff --git a/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts b/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts new file mode 100644 index 000000000..5429c6ce8 --- /dev/null +++ b/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts @@ -0,0 +1,180 @@ +import { describe, expect, test } from 'bun:test' +import type { Message } from 'src/types/message.js' +import { filterIncompleteToolCalls } from '../filterIncompleteToolCalls.js' + +describe('filterIncompleteToolCalls', () => { + test('drops assistant tool uses that do not have matching results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', id: 'missing', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { role: 'user', content: 'continue' }, + }, + ] as unknown as Message[] + + expect( + filterIncompleteToolCalls(messages).map(message => String(message.uuid)), + ).toEqual(['u1']) + }) + + test('preserves assistant text when dropping orphan tool uses', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [ + { type: 'text', text: 'I will read the file.' }, + { type: 'tool_use', id: 'missing', name: 'Read' }, + ], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered).toHaveLength(1) + const first = filtered[0]! + const content = first.message!.content + expect( + Array.isArray(content) ? content.map(block => block.type) : [], + ).toEqual(['text']) + }) + + test('keeps completed parallel tool calls when dropping an orphan', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [ + { type: 'tool_use', id: 'done', name: 'Read' }, + { type: 'tool_use', id: 'missing', name: 'Grep' }, + ], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', tool_use_id: 'done', content: 'ok' }], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered.map(message => String(message.uuid))).toEqual(['a1', 'u1']) + const first = filtered[0]! + const content = first.message!.content + expect( + Array.isArray(content) + ? content.map(block => + block.type === 'tool_use' ? block.id : block.type, + ) + : [], + ).toEqual(['done']) + }) + + test('keeps assistant tool uses that have matching results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', id: 'done', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', tool_use_id: 'done', content: 'ok' }], + }, + }, + ] as unknown as Message[] + + expect( + filterIncompleteToolCalls(messages).map(message => String(message.uuid)), + ).toEqual(['a1', 'u1']) + }) + + test('drops orphan tool results when their tool use was removed', () => { + const messages = [ + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'tool_result', tool_use_id: 'missing', content: 'late' }, + ], + }, + }, + ] as unknown as Message[] + + expect(filterIncompleteToolCalls(messages)).toEqual([]) + }) + + test('keeps user text while dropping orphan tool results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { role: 'assistant', content: 'done' }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'text', text: 'keep this' }, + { type: 'tool_result', tool_use_id: 'missing', content: 'late' }, + ], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered.map(message => String(message.uuid))).toEqual(['a1', 'u1']) + const content = filtered[1]!.message!.content + expect(Array.isArray(content) ? content : []).toEqual([ + { type: 'text', text: 'keep this' }, + ]) + }) + + test('drops malformed tool blocks without ids', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', content: 'late' }], + }, + }, + ] as unknown as Message[] + + expect(filterIncompleteToolCalls(messages)).toEqual([]) + }) +}) diff --git a/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts b/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts new file mode 100644 index 000000000..7e30754ee --- /dev/null +++ b/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts @@ -0,0 +1,110 @@ +import type { + AssistantMessage, + Message, + UserMessage, +} from 'src/types/message.js' + +/** + * Removes invalid or orphaned tool_use/tool_result blocks while preserving + * completed tool-call pairs. This is intentionally block-level, not + * message-level, so completed parallel tool calls stay paired with results. + */ +export function filterIncompleteToolCalls(messages: Message[]): Message[] { + const toolUseIdsWithResults = new Set() + + for (const message of messages) { + if (message?.type === 'user') { + const userMessage = message as UserMessage + const content = userMessage.message.content + if (Array.isArray(content)) { + for (const block of content) { + if (block.type === 'tool_result' && block.tool_use_id) { + toolUseIdsWithResults.add(block.tool_use_id) + } + } + } + } + } + + const retainedToolUseIds = new Set() + const withoutOrphanToolUses: Message[] = [] + + for (const message of messages) { + if (message?.type === 'assistant') { + const assistantMessage = message as AssistantMessage + const content = assistantMessage.message.content + if (Array.isArray(content)) { + let changed = false + const filteredContent = content.filter(block => { + if (block.type !== 'tool_use') return true + if (!block.id) { + changed = true + return false + } + if (toolUseIdsWithResults.has(block.id)) { + retainedToolUseIds.add(block.id) + return true + } + changed = true + return false + }) + + if (!changed) { + withoutOrphanToolUses.push(message) + continue + } + if (filteredContent.length > 0) { + withoutOrphanToolUses.push({ + ...assistantMessage, + message: { + ...assistantMessage.message, + content: filteredContent, + }, + }) + } + continue + } + } + withoutOrphanToolUses.push(message) + } + + const filteredMessages: Message[] = [] + for (const message of withoutOrphanToolUses) { + if (message?.type !== 'user') { + filteredMessages.push(message) + continue + } + const userMessage = message as UserMessage + const content = userMessage.message.content + if (!Array.isArray(content)) { + filteredMessages.push(message) + continue + } + let changed = false + const filteredContent = content.filter(block => { + if (block.type !== 'tool_result') return true + if (!block.tool_use_id) { + changed = true + return false + } + if (retainedToolUseIds.has(block.tool_use_id)) return true + changed = true + return false + }) + if (!changed) { + filteredMessages.push(message) + continue + } + if (filteredContent.length > 0) { + filteredMessages.push({ + ...userMessage, + message: { + ...userMessage.message, + content: filteredContent, + }, + }) + } + } + + return filteredMessages +} diff --git a/packages/builtin-tools/src/tools/AgentTool/runAgent.ts b/packages/builtin-tools/src/tools/AgentTool/runAgent.ts index baeed9022..de55b53f8 100644 --- a/packages/builtin-tools/src/tools/AgentTool/runAgent.ts +++ b/packages/builtin-tools/src/tools/AgentTool/runAgent.ts @@ -86,8 +86,11 @@ import { import type { ContentReplacementState } from 'src/utils/toolResultStorage.js' import { createAgentId } from 'src/utils/uuid.js' import { resolveAgentTools } from './agentToolUtils.js' +import { filterIncompleteToolCalls } from './filterIncompleteToolCalls.js' import { type AgentDefinition, isBuiltInAgent } from './loadAgentsDir.js' +export { filterIncompleteToolCalls } from './filterIncompleteToolCalls.js' + /** * Initialize agent-specific MCP servers * Agents can define their own MCP servers in their frontmatter that are additive @@ -886,50 +889,6 @@ export async function* runAgent({ } } -/** - * Filters out assistant messages with incomplete tool calls (tool uses without results). - * This prevents API errors when sending messages with orphaned tool calls. - */ -export function filterIncompleteToolCalls(messages: Message[]): Message[] { - // Build a set of tool use IDs that have results - const toolUseIdsWithResults = new Set() - - for (const message of messages) { - if (message?.type === 'user') { - const userMessage = message as UserMessage - const content = userMessage.message.content - if (Array.isArray(content)) { - for (const block of content) { - if (block.type === 'tool_result' && block.tool_use_id) { - toolUseIdsWithResults.add(block.tool_use_id) - } - } - } - } - } - - // Filter out assistant messages that contain tool calls without results - return messages.filter(message => { - if (message?.type === 'assistant') { - const assistantMessage = message as AssistantMessage - const content = assistantMessage.message.content - if (Array.isArray(content)) { - // Check if this assistant message has any tool uses without results - const hasIncompleteToolCall = content.some( - block => - block.type === 'tool_use' && - block.id && - !toolUseIdsWithResults.has(block.id), - ) - // Exclude messages with incomplete tool calls - return !hasIncompleteToolCall - } - } - // Keep all non-assistant messages and assistant messages without tool calls - return true - }) -} - async function getAgentSystemPrompt( agentDefinition: AgentDefinition, toolUseContext: Pick, diff --git a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts index e4868bc53..68f531034 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts @@ -131,15 +131,16 @@ export type SendMessageToolOutput = | ResponseOutput const UDS_INLINE_TOKEN_MARKER = '#token=' -const UDS_INLINE_TOKEN_REJECTED_KEY = '__udsInlineTokenRejected' function stripInlineUdsToken(target: string): string { - const markerIndex = target.lastIndexOf(UDS_INLINE_TOKEN_MARKER) + const markerIndex = target.indexOf(UDS_INLINE_TOKEN_MARKER) return markerIndex === -1 ? target : target.slice(0, markerIndex) } function hasInlineUdsToken(to: string): boolean { const addr = parseAddress(to) + // Empty-token markers are still inline-token attempts. Observable input + // redaction preserves "#token=" so cloned inputs remain rejected. return ( addr.scheme === 'uds' && addr.target.includes(UDS_INLINE_TOKEN_MARKER) ) @@ -151,20 +152,17 @@ function recipientForDisplay(to: string): string { return `uds:${stripInlineUdsToken(addr.target)}` } -function markAndRedactInlineUdsToken( - input: { to: string } & Record, -): void { - if (!hasInlineUdsToken(input.to)) return - input.to = recipientForDisplay(input.to) - input[UDS_INLINE_TOKEN_REJECTED_KEY] = true +function redactInlineUdsTokenForRejection(to: string): string { + const addr = parseAddress(to) + if (addr.scheme !== 'uds') return to + const markerIndex = addr.target.indexOf(UDS_INLINE_TOKEN_MARKER) + if (markerIndex === -1) return to + return `uds:${addr.target.slice(0, markerIndex)}${UDS_INLINE_TOKEN_MARKER}` } -function wasInlineUdsTokenRejected(input: unknown): boolean { - return ( - typeof input === 'object' && - input !== null && - (input as Record)[UDS_INLINE_TOKEN_REJECTED_KEY] === true - ) +function redactObservableInlineUdsToken(input: { to: string }): void { + if (!hasInlineUdsToken(input.to)) return + input.to = redactInlineUdsTokenForRejection(input.to) } function findTeammateColor( @@ -580,9 +578,7 @@ export const SendMessageTool: Tool = backfillObservableInput(input) { if (typeof input.to !== 'string') return - markAndRedactInlineUdsToken( - input as { to: string } & Record, - ) + redactObservableInlineUdsToken(input as { to: string }) if ('type' in input) return if (input.to === '*') { @@ -620,7 +616,10 @@ export const SendMessageTool: Tool = case 'shutdown_response': return `shutdown_response ${input.message.approve ? 'approve' : 'reject'} ${input.message.request_id}` case 'plan_approval_response': - return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${recipient}` + const planApprovalDecision = input.message.approve + ? 'approve' + : 'reject' + return `plan_approval ${planApprovalDecision} to ${recipient}` } }, @@ -674,7 +673,7 @@ export const SendMessageTool: Tool = } if ( addr.scheme === 'uds' && - (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) + hasInlineUdsToken(input.to) ) { return { result: false, @@ -808,10 +807,7 @@ export const SendMessageTool: Tool = async call(input, context, canUseTool, assistantMessage) { if (typeof input.message === 'string') { const addr = parseAddress(input.to) - if ( - addr.scheme === 'uds' && - (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) - ) { + if (addr.scheme === 'uds' && hasInlineUdsToken(input.to)) { return { data: { success: false, @@ -841,10 +837,10 @@ export const SendMessageTool: Tool = const { postInterClaudeMessage } = require('src/bridge/peerSessions.js') as typeof import('src/bridge/peerSessions.js') /* eslint-enable @typescript-eslint/no-require-imports */ - const result = (await postInterClaudeMessage( + const result = await postInterClaudeMessage( addr.target, input.message, - )) as { ok: boolean; error?: string } + ) as { ok: boolean; error?: string } const preview = input.summary || truncate(input.message, 50) return { data: { @@ -856,16 +852,6 @@ export const SendMessageTool: Tool = } } if (addr.scheme === 'uds') { - const recipient = recipientForDisplay(input.to) - if (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) { - return { - data: { - success: false, - message: - 'uds addresses must not include inline auth tokens; use the ListPeers address', - }, - } - } /* eslint-disable @typescript-eslint/no-require-imports */ const { sendToUdsSocket } = require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') @@ -876,14 +862,14 @@ export const SendMessageTool: Tool = return { data: { success: true, - message: `”${preview}” → ${recipient}`, + message: `”${preview}” → ${input.to}`, }, } } catch (e) { return { data: { success: false, - message: `Failed to send to ${recipient}: ${errorMessage(e)}`, + message: `Failed to send to ${input.to}: ${errorMessage(e)}`, }, } } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts index a0ab2af0d..e0ce1a823 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts @@ -1,8 +1,8 @@ import { describe, expect, test } from 'bun:test' +import { SendMessageTool } from '../SendMessageTool.js' describe('SendMessageTool UDS recipient handling', () => { test('redacts inline UDS tokens before classifier and observable paths', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const tokenAddress = 'uds:/tmp/peer.sock#token=secret-token' const observableInput = { @@ -12,6 +12,7 @@ describe('SendMessageTool UDS recipient handling', () => { SendMessageTool.backfillObservableInput!(observableInput) expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') expect(JSON.stringify(observableInput)).not.toContain('secret-token') expect( SendMessageTool.toAutoClassifierInput({ @@ -22,7 +23,6 @@ describe('SendMessageTool UDS recipient handling', () => { }) test('keeps redacted UDS token rejection through observable backfill', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const observableInput = { to: 'uds:/tmp/peer.sock#token=secret-token', message: { @@ -35,7 +35,7 @@ describe('SendMessageTool UDS recipient handling', () => { SendMessageTool.backfillObservableInput!(observableInput) - expect(observableInput.to).toBe('uds:/tmp/peer.sock') + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') expect(observableInput.type).toBe('plan_approval_response') expect(observableInput.request_id).toBe('req-1') @@ -55,8 +55,37 @@ describe('SendMessageTool UDS recipient handling', () => { expect(result.message).toContain('inline auth tokens') }) + test('keeps inline-token rejection when observable input is cloned', async () => { + const observableInput = { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + } as Record + + SendMessageTool.backfillObservableInput!(observableInput) + const clonedInput = { + to: observableInput.to, + message: observableInput.message, + summary: 'hello peer', + } + + const validation = await SendMessageTool.validateInput!( + clonedInput as never, + {} as never, + ) + const result = await SendMessageTool.call( + clonedInput as never, + {} as never, + undefined as never, + undefined as never, + ) + + expect(validation.result).toBe(false) + expect(result.data.success).toBe(false) + expect(JSON.stringify(clonedInput)).not.toContain('secret-token') + expect(JSON.stringify(result)).not.toContain('secret-token') + }) + test('redacts UDS tokens in structured classifier text', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const to = 'uds:/tmp/peer.sock#token=secret-token' expect( @@ -75,10 +104,50 @@ describe('SendMessageTool UDS recipient handling', () => { }, }), ).toBe('plan_approval approve to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'plan_approval_response', + request_id: 'req-2', + approve: false, + }, + }), + ).toBe('plan_approval reject to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'shutdown_response', + request_id: 'shutdown-1', + approve: false, + }, + }), + ).toBe('shutdown_response reject shutdown-1') + }) + + test('redacts from the first inline UDS token marker', async () => { + const tokenAddress = 'uds:/tmp/peer.sock#token=first#token=second' + + const observableInput = { + to: tokenAddress, + message: 'hello', + } as Record + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(JSON.stringify(observableInput)).not.toContain('first') + expect(JSON.stringify(observableInput)).not.toContain('second') + expect( + SendMessageTool.toAutoClassifierInput({ + to: tokenAddress, + message: 'hello', + }), + ).toBe('to uds:/tmp/peer.sock: hello') }) test('rejects inline UDS tokens during validation', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const result = await SendMessageTool.validateInput!( { to: 'uds:/tmp/peer.sock#token=secret-token', @@ -96,7 +165,6 @@ describe('SendMessageTool UDS recipient handling', () => { }) test('rejects inline UDS tokens during execution without leaking them', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const result = await SendMessageTool.call( { to: 'uds:/tmp/peer.sock#token=secret-token', diff --git a/src/commands/poor/__tests__/poorMode.test.ts b/src/commands/poor/__tests__/poorMode.test.ts index c2a80f3cf..539c804e1 100644 --- a/src/commands/poor/__tests__/poorMode.test.ts +++ b/src/commands/poor/__tests__/poorMode.test.ts @@ -5,7 +5,8 @@ * After the fix, it reads from / writes to settings.json via * getInitialSettings() and updateSettingsForSource(). */ -import { describe, expect, test, beforeEach, mock } from 'bun:test' +import { afterAll, describe, expect, test, beforeEach, mock } from 'bun:test' +import * as settingsModule from '../../../utils/settings/settings.js' // ── Mocks must be declared before the module under test is imported ────────── @@ -13,24 +14,48 @@ let mockSettings: Record = {} let lastUpdate: { source: string; patch: Record } | null = null mock.module('src/utils/settings/settings.js', () => ({ + loadManagedFileSettings: () => ({ settings: null, errors: [] }), + getManagedFileSettingsPresence: () => ({ + hasBase: false, + hasDropIns: false, + }), + parseSettingsFile: () => ({ settings: null, errors: [] }), + getSettingsRootPathForSource: () => '', + getSettingsFilePathForSource: () => undefined, + getRelativeSettingsFilePathForSource: () => '', getInitialSettings: () => mockSettings, + getSettingsForSource: () => mockSettings, + getPolicySettingsOrigin: () => null, + getSettingsWithErrors: () => ({ settings: mockSettings, errors: [] }), + getSettingsWithSources: () => ({ effective: mockSettings, sources: [] }), + getSettings_DEPRECATED: () => mockSettings, + settingsMergeCustomizer: () => undefined, + getManagedSettingsKeysForLogging: () => [], + // Keep unrelated exports aligned with the real settings module so this + // full-surface mock cannot change later test files if Bun keeps it alive. + hasAutoModeOptIn: () => true, + hasSkipDangerousModePermissionPrompt: () => false, + getAutoModeConfig: () => undefined, + getUseAutoModeDuringPlan: () => true, + rawSettingsContainsKey: (key: string) => key in mockSettings, updateSettingsForSource: (source: string, patch: Record) => { lastUpdate = { source, patch } mockSettings = { ...mockSettings, ...patch } }, })) -// Import AFTER mocks are registered -const { isPoorModeActive, setPoorMode } = await import('../poorMode.js') +afterAll(() => { + mock.restore() + mock.module('src/utils/settings/settings.js', () => settingsModule) +}) -// ── Helpers ────────────────────────────────────────────────────────────────── - -/** Reset module-level singleton between tests by re-importing a fresh copy. */ -async function freshModule() { - // Bun caches modules; we manipulate the exported functions directly since - // the singleton `poorModeActive` is reset to null only on first import. - // Instead we test the observable behaviour through set/get pairs. -} +// Import AFTER mocks are registered. The query suffix gives this file its own +// module instance so cross-file poorMode.js mocks cannot replace the subject +// under test during Bun's shared coverage run. +const poorModeModulePath = '../poorMode.js?poorModeTest' +const { isPoorModeActive, setPoorMode } = (await import( + poorModeModulePath +)) as typeof import('../poorMode.js') // ── Tests ──────────────────────────────────────────────────────────────────── diff --git a/src/services/AgentSummary/__tests__/agentSummary.test.ts b/src/services/AgentSummary/__tests__/agentSummary.test.ts index 0ab070080..cffc614a3 100644 --- a/src/services/AgentSummary/__tests__/agentSummary.test.ts +++ b/src/services/AgentSummary/__tests__/agentSummary.test.ts @@ -1,16 +1,14 @@ -import { - afterAll, - afterEach, - beforeEach, - describe, - expect, - mock, - test, -} from 'bun:test' -import { debugMock } from '../../../../tests/mocks/debug' -import { logMock } from '../../../../tests/mocks/log' +import { beforeEach, describe, expect, test } from 'bun:test' import { asAgentId } from '../../../types/ids.js' -import type { CacheSafeParams } from '../../../utils/forkedAgent.js' +import type { Message } from '../../../types/message.js' +import type { + CacheSafeParams, + ForkedAgentResult, +} from '../../../utils/forkedAgent.js' +import { + type AgentSummaryDependencies, + startAgentSummarization, +} from '../agentSummary.js' const transcriptMessages = [ { type: 'user', message: { content: 'start' }, uuid: 'u1' }, @@ -20,114 +18,195 @@ const transcriptMessages = [ uuid: 'a1', }, { type: 'user', message: { content: 'continue' }, uuid: 'u2' }, -] +] as unknown as Message[] -let poorModeActive = false -let forkCalls = 0 -let updateCalls: Array<{ taskId: string; summary: string }> = [] -let transcript = { messages: transcriptMessages } -const sessionStorageSnapshot = { - ...(require('../../../utils/sessionStorage.ts') as Record), +type ForkCall = { + cacheSafeParams: CacheSafeParams } -mock.module('src/commands/poor/poorMode.js', () => ({ - isPoorModeActive: () => poorModeActive, -})) - -mock.module('src/tasks/LocalAgentTask/LocalAgentTask.js', () => ({ - updateAgentSummary: (taskId: string, summary: string) => { - updateCalls.push({ taskId, summary }) - }, -})) - -mock.module( - '@claude-code-best/builtin-tools/tools/AgentTool/runAgent.js', - () => ({ - filterIncompleteToolCalls: (messages: T) => messages, - }), -) - -mock.module('src/utils/debug.js', debugMock) -mock.module('src/utils/log.js', logMock) - -mock.module('src/utils/forkedAgent.js', () => ({ - runForkedAgent: async () => { - forkCalls += 1 - return { - messages: [ - { - type: 'assistant', - message: { - content: [{ type: 'text', text: 'Reading udsClient.ts' }], - }, - }, - ], - } - }, -})) - -mock.module('src/utils/sessionStorage.js', () => ({ - ...sessionStorageSnapshot, - getAgentTranscript: async () => transcript, -})) - -afterAll(() => { - mock.module('src/utils/sessionStorage.js', () => - require('../../../utils/sessionStorage.ts'), - ) -}) - describe('startAgentSummarization', () => { - const realSetTimeout = globalThis.setTimeout - const realClearTimeout = globalThis.clearTimeout - let scheduled: - | ((...args: Parameters void)>) => void) - | undefined + let scheduled: (() => void | Promise) | undefined + let handle: { stop: () => void } | undefined + let forkCalls: ForkCall[] + let updateCalls: Array<{ taskId: string; summary: string }> + let transcriptMessagesForTest: Message[] + let debugLogs: string[] + let loggedErrors: Error[] + let clearedHandles: unknown[] - beforeEach(() => { - poorModeActive = false - forkCalls = 0 - updateCalls = [] - transcript = { messages: transcriptMessages } - scheduled = undefined - globalThis.setTimeout = ((callback: TimerHandler) => { - scheduled = callback as (...args: unknown[]) => void - return 1 as unknown as ReturnType - }) as unknown as typeof setTimeout - globalThis.clearTimeout = (() => undefined) as typeof clearTimeout - }) - - afterEach(() => { - globalThis.setTimeout = realSetTimeout - globalThis.clearTimeout = realClearTimeout - }) - - test('summarizes bounded transcript once and skips unchanged fingerprints', async () => { - const { startAgentSummarization } = await import('../agentSummary.js') - - const handle = startAgentSummarization( + function startTestSummarization( + dependencies: AgentSummaryDependencies = {}, + ): { stop: () => void } { + return startAgentSummarization( 'task-1', asAgentId('a0000000000000000'), { - forkContextMessages: [{ type: 'user', message: { content: 'old' } }], + forkContextMessages: [ + { type: 'user', message: { content: 'stale' }, uuid: 'old' }, + ], model: 'claude-test', } as unknown as CacheSafeParams, () => undefined, + { + clearTimeout: ((timeoutId: unknown) => { + clearedHandles.push(timeoutId) + }) as typeof clearTimeout, + getAgentTranscript: async () => ({ + messages: transcriptMessagesForTest, + contentReplacements: [], + }), + isPoorModeActive: () => false, + logError: error => { + loggedErrors.push( + error instanceof Error ? error : new Error(String(error)), + ) + }, + logForDebugging: message => { + debugLogs.push(message) + }, + runForkedAgent: async (args: ForkCall) => { + forkCalls.push(args) + return { + messages: [ + { + type: 'assistant', + message: { + content: [{ type: 'text', text: 'Reading udsClient.ts' }], + }, + }, + ], + } as unknown as ForkedAgentResult + }, + setTimeout: ((callback: TimerHandler) => { + if (typeof callback !== 'function') { + throw new Error('Expected timer callback') + } + scheduled = callback as () => void | Promise + return 1 as unknown as ReturnType + }) as unknown as typeof setTimeout, + updateAgentSummary: (taskId: string, summary: string) => { + updateCalls.push({ taskId, summary }) + }, + ...dependencies, + }, ) + } + + beforeEach(() => { + forkCalls = [] + updateCalls = [] + scheduled = undefined + handle = undefined + transcriptMessagesForTest = transcriptMessages + debugLogs = [] + loggedErrors = [] + clearedHandles = [] + }) + + test('summarizes bounded transcript once and skips unchanged fingerprints', async () => { + handle = startTestSummarization() expect(typeof scheduled).toBe('function') await scheduled!() - expect(forkCalls).toBe(1) + expect(forkCalls).toHaveLength(1) expect(updateCalls).toEqual([ { taskId: 'task-1', summary: 'Reading udsClient.ts' }, ]) + const forkContext = forkCalls[0].cacheSafeParams.forkContextMessages ?? [] + expect(forkContext.map(message => String(message.uuid))).toEqual([ + 'u1', + 'a1', + 'u2', + ]) + expect(forkContext.some(message => String(message.uuid) === 'old')).toBe( + false, + ) + await scheduled!() - expect(forkCalls).toBe(1) + expect(forkCalls).toHaveLength(1) expect(updateCalls).toHaveLength(1) + }) + + test('skips summarization when filtering leaves too little bounded context', async () => { + transcriptMessagesForTest = [ + { type: 'user', message: { content: 'start' }, uuid: 'u1' }, + { + type: 'assistant', + uuid: 'a1', + message: { + content: [{ type: 'tool_use', id: 'missing', name: 'Read' }], + }, + }, + { type: 'user', message: { content: 'continue' }, uuid: 'u2' }, + ] as unknown as Message[] + + handle = startTestSummarization() + + expect(typeof scheduled).toBe('function') + await scheduled!() + + expect(forkCalls).toEqual([]) + expect(updateCalls).toEqual([]) + expect(debugLogs).toContain( + '[AgentSummary] Skipping summary for task-1: no bounded context available', + ) + }) + + test('skips summarization before building context when transcript is too short', async () => { + transcriptMessagesForTest = transcriptMessages.slice(0, 2) + handle = startTestSummarization() + + expect(typeof scheduled).toBe('function') + await scheduled!() + + expect(forkCalls).toEqual([]) + expect(updateCalls).toEqual([]) + expect(debugLogs).toContain( + '[AgentSummary] Skipping summary for task-1: not enough messages (2)', + ) + }) + + test('skips and reschedules while poor mode is active', async () => { + handle = startTestSummarization({ + isPoorModeActive: () => true, + }) + + expect(typeof scheduled).toBe('function') + await scheduled!() + + expect(forkCalls).toEqual([]) + expect(updateCalls).toEqual([]) + expect(debugLogs).toContain( + '[AgentSummary] Skipping summary — poor mode active', + ) + }) + + test('logs summary errors and keeps the next timer owned by the summarizer', async () => { + const error = new Error('fork failed') + handle = startTestSummarization({ + runForkedAgent: async () => { + throw error + }, + }) + + expect(typeof scheduled).toBe('function') + await scheduled!() + + expect(loggedErrors).toEqual([error]) + expect(updateCalls).toEqual([]) + }) + + test('stop clears the pending summary timer', () => { + handle = startTestSummarization() handle.stop() + + expect(debugLogs).toContain( + '[AgentSummary] Stopping summarization for task-1', + ) + expect(clearedHandles).toEqual([1]) }) }) diff --git a/src/services/AgentSummary/__tests__/summaryContext.test.ts b/src/services/AgentSummary/__tests__/summaryContext.test.ts index 1c701d14b..5aafcf3c1 100644 --- a/src/services/AgentSummary/__tests__/summaryContext.test.ts +++ b/src/services/AgentSummary/__tests__/summaryContext.test.ts @@ -1,6 +1,8 @@ import { describe, expect, test } from 'bun:test' import type { Message } from '../../../types/message.js' import { + buildSummaryContext, + estimateMessageChars, getSummaryContextFingerprint, MAX_SUMMARY_CONTEXT_CHARS, selectSummaryContextMessages, @@ -75,6 +77,21 @@ describe('selectSummaryContextMessages', () => { expect(selected).toEqual([]) }) + test('stops at an older oversized message after keeping the recent suffix', () => { + const messages = [ + makeMessage('user', 'u1', 'x'.repeat(5_000)), + makeMessage('user', 'u2', 'small prompt'), + makeMessage('assistant', 'a2', 'small answer'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2', 'a2']) + }) + test('drops leading orphan tool results after bounding', () => { const messages = [ makeMessage('assistant', 'a0', 'older assistant'), @@ -102,6 +119,35 @@ describe('selectSummaryContextMessages', () => { }) describe('getSummaryContextFingerprint', () => { + test('estimates circular messages as unbounded', () => { + const circular = makeMessage('assistant', 'a1', 'cycle') as Message & { + self?: unknown + } + circular.self = circular + + expect(estimateMessageChars(circular)).toBe(Number.POSITIVE_INFINITY) + }) + + test('ignores non-json primitive fields in size estimates', () => { + const message = makeMessage('assistant', 'a1', 'metadata') as Message & { + skipUndefined?: undefined + skipFunction?: () => void + skipSymbol?: symbol + } + message.skipUndefined = undefined + message.skipFunction = () => undefined + message.skipSymbol = Symbol('ignored') + + expect(estimateMessageChars(message)).toBeGreaterThan(0) + }) + + test('treats unsupported top-level primitives as zero-size estimates', () => { + expect( + estimateMessageChars((() => undefined) as unknown as Message), + ).toBe(0) + expect(estimateMessageChars(1n as unknown as Message)).toBe(0) + }) + test('returns null for an empty transcript', () => { expect(getSummaryContextFingerprint([])).toBeNull() }) @@ -146,4 +192,77 @@ describe('getSummaryContextFingerprint', () => { expect(first).not.toBe(second) }) + + test('fingerprints circular message references without recursing forever', () => { + const circular = makeMessage('assistant', 'a1', 'cycle') as Message & { + self?: unknown + } + circular.self = circular + + expect(getSummaryContextFingerprint([circular])).toContain(':a1:') + }) +}) + +describe('buildSummaryContext', () => { + test('returns bounded messages and fingerprint for summarizable context', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { content: [{ type: 'text', text: 'working' }] }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + + const result = buildSummaryContext(messages, null) + + expect(result.skipReason).toBeUndefined() + expect(result.messages.map(message => String(message.uuid))).toEqual([ + 'u1', + 'a1', + 'u2', + ]) + expect(result.fingerprint).toContain('3:u2:') + }) + + test('reports unchanged contexts by fingerprint', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { content: [{ type: 'text', text: 'working' }] }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + const first = buildSummaryContext(messages, null) + + const second = buildSummaryContext(messages, first.fingerprint) + + expect(second.skipReason).toBe('unchanged') + expect(second.fingerprint).toBe(first.fingerprint) + }) + + test('filters incomplete tool calls before deciding context is too small', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { + content: [{ type: 'tool_use', id: 'missing', name: 'Read' }], + }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + + const result = buildSummaryContext(messages, null) + + expect(result.skipReason).toBe('too_small') + expect(result.messages.map(message => String(message.uuid))).toEqual([ + 'u1', + 'u2', + ]) + }) }) diff --git a/src/services/AgentSummary/__tests__/summaryPrompt.test.ts b/src/services/AgentSummary/__tests__/summaryPrompt.test.ts new file mode 100644 index 000000000..9e8f03cac --- /dev/null +++ b/src/services/AgentSummary/__tests__/summaryPrompt.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, test } from 'bun:test' +import { + buildSummaryPrompt, + createSummaryPromptMessage, +} from '../summaryPrompt.js' + +describe('buildSummaryPrompt', () => { + test('builds the first summary prompt without previous-summary pressure', () => { + const prompt = buildSummaryPrompt(null) + + expect(prompt).toContain('Describe your most recent action') + expect(prompt).toContain('Good: "Reading runAgent.ts"') + expect(prompt).not.toContain('Previous:') + }) + + test('asks for a new summary when a previous one exists', () => { + const prompt = buildSummaryPrompt('Reading udsMessaging.ts') + + expect(prompt).toContain('Previous: "Reading udsMessaging.ts"') + expect(prompt).toContain('say something NEW') + }) +}) + +describe('createSummaryPromptMessage', () => { + test('creates the minimal user message shape used by forked summaries', () => { + const message = createSummaryPromptMessage('Summarize progress') + + expect(message.type).toBe('user') + expect(message.message.role).toBe('user') + expect(message.message.content).toBe('Summarize progress') + expect(message.uuid).toBeString() + expect(message.timestamp).toBeString() + }) +}) diff --git a/src/services/AgentSummary/agentSummary.ts b/src/services/AgentSummary/agentSummary.ts index 2232e839d..d212a5c72 100644 --- a/src/services/AgentSummary/agentSummary.ts +++ b/src/services/AgentSummary/agentSummary.ts @@ -13,7 +13,6 @@ import type { TaskContext } from '../../Task.js' import { isPoorModeActive } from '../../commands/poor/poorMode.js' import { updateAgentSummary } from '../../tasks/LocalAgentTask/LocalAgentTask.js' -import { filterIncompleteToolCalls } from '@claude-code-best/builtin-tools/tools/AgentTool/runAgent.js' import type { AgentId } from '../../types/ids.js' import { logForDebugging } from '../../utils/debug.js' import { @@ -21,38 +20,32 @@ import { runForkedAgent, } from '../../utils/forkedAgent.js' import { logError } from '../../utils/log.js' -import { createUserMessage } from '../../utils/messages.js' import { getAgentTranscript } from '../../utils/sessionStorage.js' +import { buildSummaryContext } from './summaryContext.js' import { - getSummaryContextFingerprint, - selectSummaryContextMessages, -} from './summaryContext.js' + buildSummaryPrompt, + createSummaryPromptMessage, +} from './summaryPrompt.js' const SUMMARY_INTERVAL_MS = 30_000 -function buildSummaryPrompt(previousSummary: string | null): string { - const prevLine = previousSummary - ? `\nPrevious: "${previousSummary}" — say something NEW.\n` - : '' - - return `Describe your most recent action in 3-5 words using present tense (-ing). Name the file or function, not the branch. Do not use tools. -${prevLine} -Good: "Reading runAgent.ts" -Good: "Fixing null check in validate.ts" -Good: "Running auth module tests" -Good: "Adding retry logic to fetchUser" - -Bad (past tense): "Analyzed the branch diff" -Bad (too vague): "Investigating the issue" -Bad (too long): "Reviewing full branch diff and AgentTool.tsx integration" -Bad (branch name): "Analyzed adam/background-summary branch diff"` -} +export type AgentSummaryDependencies = Partial<{ + clearTimeout: typeof clearTimeout + getAgentTranscript: typeof getAgentTranscript + isPoorModeActive: typeof isPoorModeActive + logError: typeof logError + logForDebugging: typeof logForDebugging + runForkedAgent: typeof runForkedAgent + setTimeout: typeof setTimeout + updateAgentSummary: typeof updateAgentSummary +}> export function startAgentSummarization( taskId: string, agentId: AgentId, cacheSafeParams: CacheSafeParams, setAppState: TaskContext['setAppState'], + dependencies: AgentSummaryDependencies = {}, ): { stop: () => void } { // Drop forkContextMessages from the closure — runSummary rebuilds it each // tick from getAgentTranscript(). Without this, the original fork messages @@ -63,46 +56,53 @@ export function startAgentSummarization( let stopped = false let previousSummary: string | null = null let lastHandledTranscriptFingerprint: string | null = null + const clearTimeoutImpl = dependencies.clearTimeout ?? clearTimeout + const getAgentTranscriptImpl = + dependencies.getAgentTranscript ?? getAgentTranscript + const isPoorModeActiveImpl = + dependencies.isPoorModeActive ?? isPoorModeActive + const logErrorImpl = dependencies.logError ?? logError + const logForDebuggingImpl = + dependencies.logForDebugging ?? logForDebugging + const runForkedAgentImpl = dependencies.runForkedAgent ?? runForkedAgent + const setTimeoutImpl = dependencies.setTimeout ?? setTimeout + const updateAgentSummaryImpl = + dependencies.updateAgentSummary ?? updateAgentSummary async function runSummary(): Promise { if (stopped) return - if (isPoorModeActive()) { - logForDebugging('[AgentSummary] Skipping summary — poor mode active') + if (isPoorModeActiveImpl()) { + logForDebuggingImpl('[AgentSummary] Skipping summary — poor mode active') scheduleNext() return } - logForDebugging(`[AgentSummary] Timer fired for agent ${agentId}`) + logForDebuggingImpl(`[AgentSummary] Timer fired for agent ${agentId}`) try { // Read current messages from transcript - const transcript = await getAgentTranscript(agentId) + const transcript = await getAgentTranscriptImpl(agentId) if (!transcript || transcript.messages.length < 3) { // Not enough context yet — finally block will schedule next attempt - logForDebugging( + logForDebuggingImpl( `[AgentSummary] Skipping summary for ${taskId}: not enough messages (${transcript?.messages.length ?? 0})`, ) return } - // Filter to clean message state - const cleanMessages = filterIncompleteToolCalls(transcript.messages) - const summaryContext = filterIncompleteToolCalls( - selectSummaryContextMessages(cleanMessages), + const summaryContext = buildSummaryContext( + transcript.messages, + lastHandledTranscriptFingerprint, ) - const transcriptFingerprint = getSummaryContextFingerprint(summaryContext) - if ( - transcriptFingerprint && - transcriptFingerprint === lastHandledTranscriptFingerprint - ) { - logForDebugging( + if (summaryContext.skipReason === 'unchanged') { + logForDebuggingImpl( `[AgentSummary] Skipping summary for ${taskId}: transcript unchanged`, ) return } - if (summaryContext.length < 3) { - logForDebugging( + if (summaryContext.skipReason === 'too_small') { + logForDebuggingImpl( `[AgentSummary] Skipping summary for ${taskId}: no bounded context available`, ) return @@ -111,11 +111,11 @@ export function startAgentSummarization( // Build fork params with current messages const forkParams: CacheSafeParams = { ...baseParams, - forkContextMessages: summaryContext, + forkContextMessages: summaryContext.messages, } - logForDebugging( - `[AgentSummary] Forking for summary, ${summaryContext.length} messages in context`, + logForDebuggingImpl( + `[AgentSummary] Forking for summary, ${summaryContext.messages.length} messages in context`, ) // Create abort controller for this summary @@ -137,9 +137,9 @@ export function startAgentSummarization( // ContentReplacementState is cloned by default in createSubagentContext // from forkParams.toolUseContext (the subagent's LIVE state captured at // onCacheSafeParams time). No explicit override needed. - const result = await runForkedAgent({ + const result = await runForkedAgentImpl({ promptMessages: [ - createUserMessage({ content: buildSummaryPrompt(previousSummary) }), + createSummaryPromptMessage(buildSummaryPrompt(previousSummary)), ], cacheSafeParams: forkParams, canUseTool, @@ -167,18 +167,18 @@ export function startAgentSummarization( const textBlock = contentArr.find(b => b.type === 'text') if (textBlock?.type === 'text' && textBlock.text.trim()) { const summaryText = textBlock.text.trim() - logForDebugging( + logForDebuggingImpl( `[AgentSummary] Summary result for ${taskId}: ${summaryText}`, ) - lastHandledTranscriptFingerprint = transcriptFingerprint + lastHandledTranscriptFingerprint = summaryContext.fingerprint previousSummary = summaryText - updateAgentSummary(taskId, summaryText, setAppState) + updateAgentSummaryImpl(taskId, summaryText, setAppState) break } } } catch (e) { if (!stopped && e instanceof Error) { - logError(e) + logErrorImpl(e) } } finally { summaryAbortController = null @@ -191,14 +191,14 @@ export function startAgentSummarization( function scheduleNext(): void { if (stopped) return - timeoutId = setTimeout(runSummary, SUMMARY_INTERVAL_MS) + timeoutId = setTimeoutImpl(runSummary, SUMMARY_INTERVAL_MS) } function stop(): void { - logForDebugging(`[AgentSummary] Stopping summarization for ${taskId}`) + logForDebuggingImpl(`[AgentSummary] Stopping summarization for ${taskId}`) stopped = true if (timeoutId) { - clearTimeout(timeoutId) + clearTimeoutImpl(timeoutId) timeoutId = null } if (summaryAbortController) { diff --git a/src/services/AgentSummary/summaryContext.ts b/src/services/AgentSummary/summaryContext.ts index 894a21e36..d4c00e1d4 100644 --- a/src/services/AgentSummary/summaryContext.ts +++ b/src/services/AgentSummary/summaryContext.ts @@ -1,4 +1,5 @@ -import { createHash } from 'crypto' +import { createHash } from 'node:crypto' +import { filterIncompleteToolCalls } from '@claude-code-best/builtin-tools/tools/AgentTool/filterIncompleteToolCalls.js' import type { Message } from '../../types/message.js' export const MAX_SUMMARY_CONTEXT_MESSAGES = 120 @@ -178,3 +179,41 @@ export function selectSummaryContextMessages( return selected } + +export type SummaryContextBuildResult = { + messages: Message[] + fingerprint: string | null + skipReason?: 'too_small' | 'unchanged' +} + +export function buildSummaryContext( + messages: Message[], + previousFingerprint: string | null, +): SummaryContextBuildResult { + const cleanMessages = filterIncompleteToolCalls(messages) + const boundedMessages = filterIncompleteToolCalls( + selectSummaryContextMessages(cleanMessages), + ) + const fingerprint = getSummaryContextFingerprint(boundedMessages) + + if (fingerprint && fingerprint === previousFingerprint) { + return { + messages: boundedMessages, + fingerprint, + skipReason: 'unchanged', + } + } + + if (boundedMessages.length < 3) { + return { + messages: boundedMessages, + fingerprint, + skipReason: 'too_small', + } + } + + return { + messages: boundedMessages, + fingerprint, + } +} diff --git a/src/services/AgentSummary/summaryPrompt.ts b/src/services/AgentSummary/summaryPrompt.ts new file mode 100644 index 000000000..ce3138f2a --- /dev/null +++ b/src/services/AgentSummary/summaryPrompt.ts @@ -0,0 +1,32 @@ +import { randomUUID, type UUID } from 'node:crypto' +import type { UserMessage } from '../../types/message.js' + +export function buildSummaryPrompt(previousSummary: string | null): string { + const prevLine = previousSummary + ? `\nPrevious: "${previousSummary}" — say something NEW.\n` + : '' + + return `Describe your most recent action in 3-5 words using present tense (-ing). Name the file or function, not the branch. Do not use tools. +${prevLine} +Good: "Reading runAgent.ts" +Good: "Fixing null check in validate.ts" +Good: "Running auth module tests" +Good: "Adding retry logic to fetchUser" + +Bad (past tense): "Analyzed the branch diff" +Bad (too vague): "Investigating the issue" +Bad (too long): "Reviewing full branch diff and AgentTool.tsx integration" +Bad (branch name): "Analyzed adam/background-summary branch diff"` +} + +export function createSummaryPromptMessage(content: string): UserMessage { + return { + type: 'user', + message: { + role: 'user', + content, + }, + uuid: randomUUID() as UUID, + timestamp: new Date().toISOString(), + } +} diff --git a/src/utils/__tests__/ndjsonFramer.test.ts b/src/utils/__tests__/ndjsonFramer.test.ts index 35174162a..344c1e58c 100644 --- a/src/utils/__tests__/ndjsonFramer.test.ts +++ b/src/utils/__tests__/ndjsonFramer.test.ts @@ -88,4 +88,66 @@ describe('attachNdjsonFramer', () => { expect(errors[0]?.message).toContain('NDJSON frame exceeded') expect(socket.destroyed).toBe(true) }) + + test('lets callers own oversized-frame shutdown when configured', () => { + const socket = createTestSocket() + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + () => undefined, + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + destroyOnFrameError: false, + }, + ) + + socket.emitData(Buffer.from('{"long":true}\n')) + + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(false) + }) + + test('reports malformed non-empty frames without changing default compatibility', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + onInvalidFrame: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{not-json\n')) + + expect(messages).toEqual([]) + expect(errors).toHaveLength(1) + expect(socket.destroyed).toBe(false) + }) + + test('destroys malformed frames when configured by the caller', () => { + const socket = createTestSocket() + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + () => undefined, + text => JSON.parse(text) as unknown, + { + destroyOnInvalidFrame: true, + onInvalidFrame: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{not-json\n')) + + expect(errors).toHaveLength(1) + expect(socket.destroyed).toBe(true) + }) }) diff --git a/src/utils/__tests__/teammateMailbox.test.ts b/src/utils/__tests__/teammateMailbox.test.ts index 7f479ed36..7ca595bf1 100644 --- a/src/utils/__tests__/teammateMailbox.test.ts +++ b/src/utils/__tests__/teammateMailbox.test.ts @@ -3,7 +3,7 @@ import { mkdir, readFile, rm, writeFile } from 'node:fs/promises' import { mkdtempSync } from 'node:fs' import { tmpdir } from 'node:os' import { dirname, join } from 'node:path' -import type { Message } from '../../types/message.js' +import type { Message } from 'src/types/message.js' import { compactMailboxMessages, getLastPeerDmSummary, @@ -13,13 +13,14 @@ import { markMessagesAsRead, markMessagesAsReadByPredicate, MAX_MAILBOX_MESSAGE_TEXT_BYTES, + MAX_MAILBOX_FILE_BYTES, MAX_MAILBOX_MESSAGES, MAX_READ_MAILBOX_MESSAGES, MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES, readMailbox, type TeammateMessage, writeToMailbox, -} from '../teammateMailbox.js' +} from 'src/utils/teammateMailbox.js' let tempHome = '' let previousConfigDir: string | undefined @@ -55,21 +56,6 @@ async function readRawMailbox( return JSON.parse(content) as TeammateMessage[] } -beforeEach(() => { - previousConfigDir = process.env.CLAUDE_CONFIG_DIR - tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) - process.env.CLAUDE_CONFIG_DIR = tempHome -}) - -afterEach(async () => { - if (previousConfigDir === undefined) { - delete process.env.CLAUDE_CONFIG_DIR - } else { - process.env.CLAUDE_CONFIG_DIR = previousConfigDir - } - await rm(tempHome, { recursive: true, force: true }) -}) - describe('compactMailboxMessages', () => { test('prioritizes unread messages and keeps only recent read history', () => { const compacted = compactMailboxMessages( @@ -175,9 +161,46 @@ describe('compactMailboxMessages', () => { expect(compacted.length).toBeLessThan(20) expect(compacted.at(-1)?.text).toContain('msg-19') }) + + test('returns an empty mailbox when even one message exceeds retained budget', () => { + const compacted = compactMailboxMessages([message('too-large', false)], { + maxMessages: 10, + maxReadMessages: 0, + maxRetainedBytes: 1, + }) + + expect(compacted).toEqual([]) + }) + + test('returns an empty mailbox when all retention lanes are disabled', () => { + const compacted = compactMailboxMessages([message('unread', false)], { + maxMessages: 0, + maxReadMessages: 0, + maxUnreadProtocolMessages: 0, + maxRetainedBytes: 1_000, + }) + + expect(compacted).toEqual([]) + }) }) describe('teammate mailbox retention', () => { + beforeEach(() => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) + process.env.CLAUDE_CONFIG_DIR = tempHome + }) + + afterEach(async () => { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + tempHome = '' + }) + test('writeToMailbox compacts oversized unread inbox files', async () => { const existing = Array.from( { length: MAX_MAILBOX_MESSAGES + 20 }, @@ -319,6 +342,23 @@ describe('teammate mailbox retention', () => { expect(await readFile(inboxPath, 'utf-8')).toBe('{not-json') }) + test('writeToMailbox rejects when the inbox path is already a directory', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(inboxPath, { recursive: true }) + + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'new', + timestamp: new Date(5).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow() + }) + test('readMailbox fails closed on corrupt mailbox content', async () => { const inboxPath = getInboxPath('worker', 'alpha') await mkdir(dirname(inboxPath), { recursive: true }) @@ -326,6 +366,76 @@ describe('teammate mailbox retention', () => { await expect(readMailbox('worker', 'alpha')).rejects.toThrow() }) + + test('readMailbox rejects non-array mailbox files', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify({ text: 'not an array' }), 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'expected message array', + ) + }) + + test('readMailbox rejects malformed stored message shapes', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile( + inboxPath, + JSON.stringify([{ from: 'lead', text: 'missing timestamp' }]), + 'utf-8', + ) + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'Invalid mailbox message shape', + ) + }) + + test('readMailbox rejects non-object stored messages', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify(['not an object']), 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'expected object', + ) + }) + + test('readMailbox rejects oversized mailbox files before parsing', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, `[${' '.repeat(MAX_MAILBOX_FILE_BYTES)}]`, 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'Mailbox file exceeds', + ) + }) + + test('markMessageAsReadByIdentity returns false for missing mailbox files', async () => { + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('absent', false)), + ).resolves.toBe(false) + }) + + test('markMessageAsReadByIdentity returns false when the expected message moved out', async () => { + await seedMailbox('worker', 'alpha', [message('other', false)]) + + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('missing', false)), + ).resolves.toBe(false) + + expect((await readRawMailbox('worker', 'alpha'))[0]?.read).toBe(false) + }) + + test('markMessageAsReadByIdentity returns false on corrupt mailbox content', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('missing', false)), + ).resolves.toBe(false) + }) }) describe('getLastPeerDmSummary', () => { diff --git a/src/utils/__tests__/udsMessaging.test.ts b/src/utils/__tests__/udsMessaging.test.ts index ef943cb76..ed7454e0b 100644 --- a/src/utils/__tests__/udsMessaging.test.ts +++ b/src/utils/__tests__/udsMessaging.test.ts @@ -3,24 +3,31 @@ import { chmod, mkdir, mkdtemp, + readdir, rm, stat, symlink, unlink, + writeFile, } from 'node:fs/promises' +import { createHash } from 'node:crypto' import { createConnection, createServer } from 'node:net' import { dirname, join } from 'node:path' import { tmpdir } from 'node:os' import { drainInbox, + getDefaultUdsSocketPath, MAX_UDS_INBOX_ENTRIES, MAX_UDS_INBOX_BYTES, MAX_UDS_FRAME_BYTES, + MAX_UDS_CLIENTS, + formatUdsAddress, parseUdsTarget, sendUdsMessage, setOnEnqueue, startUdsMessaging, stopUdsMessaging, + UDS_AUTH_TIMEOUT_MS, } from '../udsMessaging.js' let previousConfigDir: string | undefined @@ -192,7 +199,7 @@ describe('UDS inbox retention', () => { try { const { isPeerAlive } = await import('../udsClient.js') - expect(await isPeerAlive(path)).toBe(false) + expect(await isPeerAlive(path, 3_000, 'test-token')).toBe(false) } finally { await closeServer(receiver) if (process.platform !== 'win32') { @@ -210,6 +217,29 @@ describe('UDS inbox retention', () => { ) }) + test('udsClient send reports connection failures without leaking token state', async () => { + const path = socketPath('uds-client-connect-error') + const capabilityDir = join(tempConfigDir, 'messaging-capabilities') + const capabilityName = `${createHash('sha256').update(path).digest('hex')}.json` + await mkdir(capabilityDir, { recursive: true, mode: 0o700 }) + await writeFile( + join(capabilityDir, capabilityName), + JSON.stringify({ socketPath: path, authToken: 'test-token' }), + 'utf-8', + ) + const { sendToUdsSocket } = await import('../udsClient.js') + + await expect(sendToUdsSocket(path, 'hello')).rejects.toThrow( + 'Failed to connect to peer', + ) + }) + + test('sendUdsMessage fails closed before connecting without an auth token', async () => { + await expect( + sendUdsMessage(socketPath('no-auth-token'), { type: 'text', data: 'x' }), + ).rejects.toThrow('without auth token') + }) + test('drained entries never expose the UDS auth token', async () => { const path = socketPath('strip-token') await startUdsMessaging(path, { isExplicit: true }) @@ -232,6 +262,7 @@ describe('UDS inbox retention', () => { await startUdsMessaging(path, { isExplicit: true }) const response = await new Promise((resolve, reject) => { + let responseText = '' const conn = createConnection(path, () => { conn.write(`${JSON.stringify({ type: 'text', data: 'bad' })}\n`) }) @@ -242,10 +273,10 @@ describe('UDS inbox retention', () => { conn.on('data', chunk => { const text = chunk.toString('utf-8') if (text.includes('\n')) { - conn.end() - resolve(text) + responseText = text } }) + conn.on('close', () => resolve(responseText)) conn.on('error', reject) }) @@ -253,6 +284,56 @@ describe('UDS inbox retention', () => { expect(drainInbox()).toEqual([]) }) + test('disconnects malformed JSON clients without enqueueing inbox work', async () => { + const path = socketPath('malformed-client') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + let responseText = '' + const conn = createConnection(path, () => { + conn.write('{not-json\n') + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for malformed frame close')) + }) + conn.on('data', chunk => { + responseText += chunk.toString('utf-8') + }) + conn.on('close', () => resolve(responseText)) + conn.on('error', reject) + }) + + const parsed = JSON.parse(response) + expect(parsed.type).toBe('error') + expect(parsed.data).toBe('invalid frame') + expect(drainInbox()).toEqual([]) + }) + + test('disconnects idle unauthenticated clients', async () => { + const path = socketPath('idle-client') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + let responseText = '' + const conn = createConnection(path) + conn.setTimeout(UDS_AUTH_TIMEOUT_MS + 2_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for auth timeout close')) + }) + conn.on('data', chunk => { + responseText += chunk.toString('utf-8') + }) + conn.on('close', () => resolve(responseText)) + conn.on('error', reject) + }) + + const parsed = JSON.parse(response) + expect(parsed.type).toBe('error') + expect(parsed.data).toBe('authentication timeout') + expect(drainInbox()).toEqual([]) + }) + test('destroys oversized frames before enqueueing inbox work', async () => { const path = socketPath('oversized') await startUdsMessaging(path, { isExplicit: true }) @@ -272,6 +353,14 @@ describe('UDS inbox retention', () => { expect(drainInbox()).toEqual([]) }) + test('default socket path is regenerated after stop', async () => { + const firstPath = getDefaultUdsSocketPath() + await startUdsMessaging(firstPath) + await stopUdsMessaging() + + expect(getDefaultUdsSocketPath()).not.toBe(firstPath) + }) + test('rejects oversized receiver responses before retaining them', async () => { const path = socketPath('oversized-response') if (process.platform !== 'win32') { @@ -303,9 +392,71 @@ describe('UDS inbox retention', () => { } }) + test('rejects closed receiver responses without waiting for timeout', async () => { + const path = socketPath('closed-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.end() + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('before response') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects malformed receiver responses without waiting for timeout', async () => { + const path = socketPath('malformed-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('{not-json\n') + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('Invalid UDS response frame') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + test('rejects inline auth token UDS targets instead of parsing them', async () => { const path = socketPath('inline-token') + expect(formatUdsAddress(path)).toBe(`uds:${path}`) + const targetWithToken = `${path}#token=secret` expect(() => parseUdsTarget(targetWithToken)).toThrow('inline auth token') try { @@ -320,6 +471,23 @@ describe('UDS inbox retention', () => { ) }) + test('fails closed and cleans temp files when capability target is occupied', async () => { + const path = socketPath('capability-target-dir') + const capabilityDir = join(tempConfigDir, 'messaging-capabilities') + const capabilityName = `${createHash('sha256').update(path).digest('hex')}.json` + await mkdir(join(capabilityDir, capabilityName), { + recursive: true, + mode: 0o700, + }) + + await expect( + startUdsMessaging(path, { isExplicit: true }), + ).rejects.toThrow() + + expect(process.env.CLAUDE_CODE_MESSAGING_SOCKET).toBeUndefined() + expect(await readdir(capabilityDir)).toEqual([capabilityName]) + }) + if (process.platform !== 'win32') { test('creates the listening socket with owner-only permissions', async () => { const path = socketPath('socket-mode') @@ -341,9 +509,11 @@ describe('UDS inbox retention', () => { await chmod(capabilityDir, 0o755) try { + const path = socketPath('broad-capdir') await expect( - startUdsMessaging(socketPath('broad-capdir'), { isExplicit: true }), + startUdsMessaging(path, { isExplicit: true }), ).rejects.toThrow('permissions are too broad') + await expect(stat(path)).rejects.toThrow() } finally { if (previousConfigDir === undefined) { delete process.env.CLAUDE_CONFIG_DIR @@ -397,5 +567,65 @@ describe('UDS inbox retention', () => { await rm(parent, { recursive: true, force: true }) } }) + + test('fails closed when an explicit socket parent is a file', async () => { + const parentFile = join( + tmpdir(), + `uds-socket-parent-file-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + await writeFile(parentFile, 'not a directory', 'utf-8') + + try { + await expect( + startUdsMessaging(join(parentFile, 'messaging.sock'), { + isExplicit: true, + }), + ).rejects.toThrow('socket parent is not a directory') + } finally { + await rm(parentFile, { force: true }) + } + }) + + test('stop tolerates an already removed socket path', async () => { + const path = socketPath('already-removed') + await startUdsMessaging(path, { isExplicit: true }) + await unlink(path) + + await stopUdsMessaging() + + expect(process.env.CLAUDE_CODE_MESSAGING_SOCKET).toBeUndefined() + }) + + test('rejects clients over the configured connection cap', async () => { + const path = socketPath('client-cap') + await startUdsMessaging(path, { isExplicit: true }) + const sockets: ReturnType[] = [] + + try { + for (let i = 0; i < MAX_UDS_CLIENTS; i++) { + const socket = await new Promise>( + (resolve, reject) => { + const conn = createConnection(path, () => resolve(conn)) + conn.on('error', reject) + }, + ) + sockets.push(socket) + } + + await new Promise((resolve, reject) => { + const extra = createConnection(path) + extra.on('close', () => resolve()) + extra.on('error', reject) + extra.setTimeout(5_000, () => { + extra.destroy() + reject(new Error('Timed out waiting for client cap close')) + }) + }) + } finally { + for (const socket of sockets) { + socket.destroy() + } + } + }) } }) diff --git a/src/utils/__tests__/udsResponseReader.test.ts b/src/utils/__tests__/udsResponseReader.test.ts new file mode 100644 index 000000000..3ec35422e --- /dev/null +++ b/src/utils/__tests__/udsResponseReader.test.ts @@ -0,0 +1,218 @@ +import { describe, expect, test } from 'bun:test' +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' +import { attachUdsResponseReader } from '../udsResponseReader.js' + +class FakeSocket extends EventEmitter { + destroyed = false + ended = false + + destroy(): this { + this.destroyed = true + this.emit('close', true) + return this + } + + end(): this { + this.ended = true + this.emit('close', false) + return this + } + + emitData(chunk: Buffer): void { + this.emit('data', chunk) + } +} + +function asSocket(socket: FakeSocket): Socket { + return socket as unknown as Socket +} + +describe('attachUdsResponseReader', () => { + test('tracks byte limits across split multibyte response chunks', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + const multibyte = String.fromCodePoint(0x20ac) + const frame = Buffer.from( + JSON.stringify({ type: 'response', data: `ok ${multibyte}` }) + '\n', + 'utf8', + ) + const multibyteStart = frame.indexOf(Buffer.from(multibyte, 'utf8')[0]) + + socket.emitData(frame.subarray(0, multibyteStart + 1)) + expect(settled).toBe(false) + + socket.emitData(frame.subarray(multibyteStart + 1)) + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + expect(socket.ended).toBe(true) + }) + + test('rejects malformed response frames immediately', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emitData(Buffer.from('{bad-json}\n')) + + expect(settledError?.message).toBe('Invalid UDS response frame') + expect(socket.destroyed).toBe(true) + }) + + test('skips blank frames before a valid response', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + socket.emitData(Buffer.from('\n \n')) + expect(settled).toBe(false) + + socket.emitData(Buffer.from(`${JSON.stringify({ type: 'response' })}\n`)) + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + expect(socket.ended).toBe(true) + }) + + test('continues scanning when blank and valid frames share one chunk', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + socket.emitData( + Buffer.from(`\n${JSON.stringify({ type: 'response' })}\n`), + ) + + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + expect(socket.ended).toBe(true) + }) + + test('rejects receiver error frames', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emitData( + Buffer.from(`${JSON.stringify({ type: 'error', data: 'denied' })}\n`), + ) + + expect(settledError?.message).toBe('denied') + expect(socket.destroyed).toBe(true) + }) + + test('ignores unrelated receiver frames until a terminal response arrives', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + socket.emitData( + Buffer.from( + `${JSON.stringify({ type: 'notification', data: 'queued' })}\n`, + ), + ) + expect(settled).toBe(false) + + socket.emitData(Buffer.from(`${JSON.stringify({ type: 'response' })}\n`)) + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + }) + + test('uses custom socket error formatting', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + formatSocketError: error => + new Error(`wrapped:${(error as Error).message}`), + }) + + socket.emit('error', new Error('connect failed')) + + expect(settledError?.message).toBe('wrapped:connect failed') + expect(socket.destroyed).toBe(true) + }) + + test('rejects socket end before response', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emit('end') + + expect(settledError?.message).toBe('UDS socket ended before response') + expect(socket.destroyed).toBe(true) + }) + + test('rejects clean socket close before response', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emit('close', false) + + expect(settledError?.message).toBe('UDS socket closed before response') + expect(socket.destroyed).toBe(true) + }) +}) diff --git a/src/utils/messages/systemInit.ts b/src/utils/messages/systemInit.ts index 4585c7817..fcb9e74d1 100644 --- a/src/utils/messages/systemInit.ts +++ b/src/utils/messages/systemInit.ts @@ -87,10 +87,8 @@ export function buildSystemInitMessage(inputs: SystemInitInputs): SDKMessage { // Hidden from public SDK types — ant-only UDS messaging socket path if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ - const udsMessaging = - require('../udsMessaging.js') as typeof import('../udsMessaging.js') ;(initMessage as Record).messaging_socket_path = - udsMessaging.getUdsMessagingSocketPath() + require('../udsMessaging.js').getUdsMessagingSocketPath() /* eslint-enable @typescript-eslint/no-require-imports */ } initMessage.fast_mode_state = getFastModeState(inputs.model, inputs.fastMode) diff --git a/src/utils/ndjsonFramer.ts b/src/utils/ndjsonFramer.ts index 7832e9303..69717fc11 100644 --- a/src/utils/ndjsonFramer.ts +++ b/src/utils/ndjsonFramer.ts @@ -10,11 +10,15 @@ import type { Socket } from 'net' export type NdjsonFramerOptions = { maxFrameBytes?: number onFrameError?: (error: Error) => void + destroyOnFrameError?: boolean + onInvalidFrame?: (error: Error) => void + destroyOnInvalidFrame?: boolean } /** * Attach an NDJSON framer to a socket. Calls `onMessage` for each - * complete JSON line received. Malformed lines are silently skipped. + * complete JSON line received. Malformed lines are skipped by default; + * callers may opt into error callbacks or socket destruction. * * @param parse - Optional custom JSON parser (defaults to JSON.parse). * Useful when the caller uses a wrapped parser like jsonParse @@ -35,15 +39,26 @@ export function attachNdjsonFramer( `NDJSON frame exceeded ${maxFrameBytes} bytes (${bytes})`, ) options.onFrameError?.(error) - socket.destroy(error) + if (options.destroyOnFrameError ?? true) { + socket.destroy(error) + } + } + + const rejectInvalidFrame = (error: unknown): void => { + const frameError = + error instanceof Error ? error : new Error('Invalid NDJSON frame') + options.onInvalidFrame?.(frameError) + if (options.destroyOnInvalidFrame ?? false) { + socket.destroy(frameError) + } } const emitLine = (line: string): void => { if (!line.trim()) return try { onMessage(parse(line)) - } catch { - // Malformed JSON — skip + } catch (error) { + rejectInvalidFrame(error) } } diff --git a/src/utils/swarm/inProcessRunner.ts b/src/utils/swarm/inProcessRunner.ts index eaab58ef7..06fde705a 100644 --- a/src/utils/swarm/inProcessRunner.ts +++ b/src/utils/swarm/inProcessRunner.ts @@ -1246,13 +1246,8 @@ export async function runInProcessTeammate( // Track in-progress tool use IDs for animation in transcript view let inProgressToolUseIDs = task.inProgressToolUseIDs if (message.type === 'assistant') { - for (const block of Array.isArray(message.message!.content) - ? message.message!.content - : []) { - if ( - typeof block !== 'string' && - block.type === 'tool_use' - ) { + for (const block of (Array.isArray(message.message!.content) ? message.message!.content : [])) { + if (typeof block !== 'string' && block.type === 'tool_use') { inProgressToolUseIDs = new Set([ ...(inProgressToolUseIDs ?? []), block.id, @@ -1323,10 +1318,7 @@ export async function runInProcessTeammate( setAppState, ) if (currentAutonomyRunId) { - await markAutonomyRunFailed( - currentAutonomyRunId, - ERROR_MESSAGE_USER_ABORT, - ) + await markAutonomyRunFailed(currentAutonomyRunId, ERROR_MESSAGE_USER_ABORT) currentAutonomyRunId = undefined } } else if (currentAutonomyRunId) { diff --git a/src/utils/udsMessaging.ts b/src/utils/udsMessaging.ts index 94b73dcd6..b30cba137 100644 --- a/src/utils/udsMessaging.ts +++ b/src/utils/udsMessaging.ts @@ -82,6 +82,8 @@ export const MAX_UDS_INBOX_ENTRIES = 1_000 export const MAX_UDS_FRAME_BYTES = 64 * 1024 export const MAX_UDS_INBOX_BYTES = 2 * 1024 * 1024 export const MAX_UDS_CLIENTS = 128 +export const UDS_AUTH_TIMEOUT_MS = 2_000 +export const UDS_IDLE_TIMEOUT_MS = 30_000 // --------------------------------------------------------------------------- // Public API — socket path helpers @@ -339,6 +341,43 @@ function writeSocketMessage(socket: Socket, message: UdsMessage): void { socket.write(jsonStringify(message) + '\n') } +function writeSocketMessageAndDestroy(socket: Socket, message: UdsMessage): void { + if (socket.destroyed) return + socket.write(jsonStringify(message) + '\n', () => { + if (!socket.destroyed) socket.destroy() + }) +} + +function writeSocketErrorAndDestroy(socket: Socket, data: string): void { + writeSocketMessageAndDestroy(socket, { + type: 'error', + data, + ts: new Date().toISOString(), + }) +} + +function unrefTimer(timer: ReturnType): void { + const maybeUnref = (timer as { unref?: () => void }).unref + if (typeof maybeUnref === 'function') { + maybeUnref.call(timer) + } +} + +async function closeServer(serverToClose: Server): Promise { + await new Promise(resolve => { + serverToClose.close(() => resolve()) + }) +} + +async function removeSocketPath(path: string): Promise { + if (process.platform === 'win32') return + try { + await unlink(path) + } catch { + // Already gone. + } +} + function stripAuthToken(message: UdsMessage): UdsMessage { const { authToken: _authToken, ...metaWithoutAuth } = message.meta ?? {} return { @@ -391,10 +430,9 @@ export async function startUdsMessaging( } const token = ensureAuthToken() + let startedServer: Server | null = null + let exportedSocketEnv = false try { - await writeCapabilityFile(path, token) - socketPath = path - await new Promise((resolve, reject) => { const srv = createServer(socket => { if (clients.size >= MAX_UDS_CLIENTS) { @@ -408,6 +446,24 @@ export async function startUdsMessaging( logForDebugging( `[udsMessaging] client connected (total: ${clients.size})`, ) + let authenticated = false + let closing = false + const closeWithError = (data: string): void => { + if (closing || socket.destroyed) return + closing = true + socket.pause() + writeSocketErrorAndDestroy(socket, data) + } + const authTimer = setTimeout(() => { + if (authenticated || socket.destroyed) return + logForDebugging('[udsMessaging] closing unauthenticated idle client') + closeWithError('authentication timeout') + }, UDS_AUTH_TIMEOUT_MS) + unrefTimer(authTimer) + socket.setTimeout(UDS_IDLE_TIMEOUT_MS, () => { + logForDebugging('[udsMessaging] closing idle client') + closeWithError('idle timeout') + }) attachNdjsonFramer( socket, @@ -416,17 +472,13 @@ export async function startUdsMessaging( logForDebugging( `[udsMessaging] rejected unauthenticated message type=${msg.type}`, ) - if (!socket.destroyed) { - socket.write( - jsonStringify({ - type: 'error', - data: 'unauthorized', - ts: new Date().toISOString(), - } satisfies UdsMessage) + '\n', - ) - } + closeWithError('unauthorized') return } + if (!authenticated) { + authenticated = true + clearTimeout(authTimer) + } // Handle ping with automatic pong if (msg.type === 'ping') { @@ -447,11 +499,7 @@ export async function startUdsMessaging( status: 'pending', } if (!enqueueInboxEntry(entry)) { - writeSocketMessage(socket, { - type: 'error', - data: 'inbox full', - ts: new Date().toISOString(), - }) + closeWithError('inbox full') return } logForDebugging( @@ -470,21 +518,40 @@ export async function startUdsMessaging( maxFrameBytes: MAX_UDS_FRAME_BYTES, onFrameError: error => { logForDebugging(`[udsMessaging] ${error.message}`) + closeWithError(error.message) }, + onInvalidFrame: error => { + logForDebugging( + `[udsMessaging] invalid client frame: ${errorMessage(error)}`, + ) + closeWithError('invalid frame') + }, + destroyOnFrameError: false, }, ) socket.on('close', () => { + clearTimeout(authTimer) clients.delete(socket) }) socket.on('error', err => { + clearTimeout(authTimer) clients.delete(socket) logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) }) }) - srv.on('error', reject) + const rejectBeforeListen = (error: Error): void => { + reject(error) + } + const logRuntimeError = (error: Error): void => { + logForDebugging( + `[udsMessaging] server error on ${path}${opts?.isExplicit ? ' (explicit)' : ''}: ${errorMessage(error)}`, + ) + } + + srv.once('error', rejectBeforeListen) srv.listen(path, () => { void (async () => { @@ -492,19 +559,41 @@ export async function startUdsMessaging( if (process.platform !== 'win32') { await chmod(path, 0o600) } + srv.off('error', rejectBeforeListen) + srv.on('error', logRuntimeError) server = srv - // Export so child processes can discover the socket - process.env.CLAUDE_CODE_MESSAGING_SOCKET = path - logForDebugging( - `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, - ) + startedServer = srv resolve() } catch (error) { - srv.close(() => reject(error)) + srv.off('error', rejectBeforeListen) + const closeError = + error instanceof Error ? error : new Error(errorMessage(error)) + let rejected = false + const rejectOnce = (): void => { + if (rejected) return + rejected = true + reject(closeError) + } + const fallback = setTimeout(rejectOnce, 1_000) + unrefTimer(fallback) + srv.close(() => { + clearTimeout(fallback) + rejectOnce() + }) } })() }) }) + + await writeCapabilityFile(path, token) + socketPath = path + // Export so child processes can discover the socket only after the + // capability file exists and the listener is ready. + process.env.CLAUDE_CODE_MESSAGING_SOCKET = path + exportedSocketEnv = true + logForDebugging( + `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, + ) } catch (error) { if (capabilityFilePath) { try { @@ -514,7 +603,18 @@ export async function startUdsMessaging( } capabilityFilePath = null } + if (startedServer) { + await closeServer(startedServer) + } + if (server === startedServer) { + server = null + } + await removeSocketPath(path) + if (exportedSocketEnv) { + delete process.env.CLAUDE_CODE_MESSAGING_SOCKET + } socketPath = null + defaultSocketPath = null authToken = null throw error } @@ -529,6 +629,7 @@ export async function startUdsMessaging( * Stop the UDS messaging server and clean up the socket file. */ export async function stopUdsMessaging(): Promise { + defaultSocketPath = null if (!server) return // Close all connected clients @@ -547,13 +648,7 @@ export async function stopUdsMessaging(): Promise { // Remove socket file (skip on Windows — pipe paths aren't files) if (socketPath) { - if (process.platform !== 'win32') { - try { - await unlink(socketPath) - } catch { - // Already gone - } - } + await removeSocketPath(socketPath) delete process.env.CLAUDE_CODE_MESSAGING_SOCKET logForDebugging( `[udsMessaging] server stopped, socket removed: ${socketPath}`, diff --git a/src/utils/udsResponseReader.ts b/src/utils/udsResponseReader.ts index bb8d21f40..d86328aab 100644 --- a/src/utils/udsResponseReader.ts +++ b/src/utils/udsResponseReader.ts @@ -1,4 +1,5 @@ import type { Socket } from 'net' +import { StringDecoder } from 'node:string_decoder' import { errorMessage } from './errors.js' import { jsonParse } from './slowOperations.js' import type { UdsMessage } from './udsMessaging.js' @@ -16,11 +17,11 @@ export function getChunkBytes(chunk: string | Buffer): number { : chunk.byteLength } -function parseResponseLine(line: string): UdsMessage | null { +function parseResponseLine(line: string): UdsMessage { try { return jsonParse(line) as UdsMessage } catch { - return null + throw new Error('Invalid UDS response frame') } } @@ -29,35 +30,58 @@ export function attachUdsResponseReader( options: UdsResponseReaderOptions, ): void { let buffer = '' + let bufferBytes = 0 let settled = false + const decoder = new StringDecoder('utf8') - const finish = (error?: Error): void => { + function cleanupListeners(): void { + socket.off('data', onData) + socket.off('error', onError) + socket.off('end', onEnd) + socket.off('close', onClose) + } + + function finish(error?: Error): void { if (settled) return settled = true + buffer = '' + bufferBytes = 0 + cleanupListeners() if (error) { - socket.destroy(error) + socket.destroy() } else { socket.end() } options.onSettled(error) } - socket.on('data', chunk => { - if ( - Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > - options.maxFrameBytes - ) { + function onData(chunk: Buffer): void { + const decoded = decoder.write(chunk) + const decodedBytes = Buffer.byteLength(decoded, 'utf8') + if (bufferBytes + decodedBytes > options.maxFrameBytes) { finish(new Error('UDS response frame exceeded size limit')) return } - buffer += chunk.toString() - const lines = buffer.split('\n') - buffer = lines.pop() ?? '' - for (const line of lines) { - if (!line.trim()) continue - const response = parseResponseLine(line) - if (!response) continue + buffer += decoded + bufferBytes += decodedBytes + let newlineIndex = buffer.indexOf('\n') + while (newlineIndex !== -1) { + const line = buffer.slice(0, newlineIndex) + const consumed = buffer.slice(0, newlineIndex + 1) + buffer = buffer.slice(newlineIndex + 1) + bufferBytes -= Buffer.byteLength(consumed, 'utf8') + if (!line.trim()) { + newlineIndex = buffer.indexOf('\n') + continue + } + let response: UdsMessage + try { + response = parseResponseLine(line) + } catch (error) { + finish(error instanceof Error ? error : new Error(errorMessage(error))) + return + } if ( response.type === 'response' || (options.acceptPong === true && response.type === 'pong') @@ -69,13 +93,28 @@ export function attachUdsResponseReader( finish(new Error(response.data ?? 'UDS receiver rejected message')) return } + newlineIndex = buffer.indexOf('\n') } - }) + } - socket.on('error', error => { + function onError(error: Error): void { finish( options.formatSocketError?.(error) ?? (error instanceof Error ? error : new Error(errorMessage(error))), ) - }) + } + + function onEnd(): void { + finish(new Error('UDS socket ended before response')) + } + + function onClose(hadError: boolean): void { + if (hadError) return + finish(new Error('UDS socket closed before response')) + } + + socket.on('data', onData) + socket.on('error', onError) + socket.on('end', onEnd) + socket.on('close', onClose) }