diff --git a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts index e520243a5..48219edc7 100644 --- a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts +++ b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts @@ -85,21 +85,35 @@ Use this tool to discover messaging targets before sending cross-session message // and optionally includes Remote Control bridge peers. const peers: PeerInfo[] = [] - // Discovery is handled by the UDS messaging subsystem initialized in setup.ts. - // Return discovered peers from the app state. - const appState = context.getAppState() - const messagingSocketPath = (appState as Record).messagingSocketPath as string | undefined + /* eslint-disable @typescript-eslint/no-require-imports */ + const udsMessaging = + require('src/utils/udsMessaging.js') as typeof import('src/utils/udsMessaging.js') + const udsClient = + require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') + /* eslint-enable @typescript-eslint/no-require-imports */ + + const messagingSocketPath = udsMessaging.getUdsMessagingSocketPath() if (messagingSocketPath) { // Self entry for reference if (_input.include_self) { peers.push({ - address: `uds:${messagingSocketPath}`, + address: udsMessaging.formatUdsAddress(messagingSocketPath), name: 'self', pid: process.pid, }) } } + for (const peer of await udsClient.listPeers()) { + if (!peer.messagingSocketPath) continue + peers.push({ + address: udsMessaging.formatUdsAddress(peer.messagingSocketPath), + name: peer.name ?? peer.kind, + cwd: peer.cwd, + pid: peer.pid, + }) + } + return { data: { peers }, } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts index 4e9737051..3548544fc 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts @@ -130,6 +130,43 @@ export type SendMessageToolOutput = | RequestOutput | ResponseOutput +const UDS_INLINE_TOKEN_MARKER = '#token=' +const UDS_INLINE_TOKEN_REJECTED_KEY = '__udsInlineTokenRejected' + +function stripInlineUdsToken(target: string): string { + const markerIndex = target.lastIndexOf(UDS_INLINE_TOKEN_MARKER) + return markerIndex === -1 ? target : target.slice(0, markerIndex) +} + +function hasInlineUdsToken(to: string): boolean { + const addr = parseAddress(to) + return ( + addr.scheme === 'uds' && addr.target.includes(UDS_INLINE_TOKEN_MARKER) + ) +} + +function recipientForDisplay(to: string): string { + const addr = parseAddress(to) + if (addr.scheme !== 'uds') return to + return `uds:${stripInlineUdsToken(addr.target)}` +} + +function markAndRedactInlineUdsToken( + input: { to: string } & Record, +): void { + if (!hasInlineUdsToken(input.to)) return + input.to = recipientForDisplay(input.to) + input[UDS_INLINE_TOKEN_REJECTED_KEY] = true +} + +function wasInlineUdsTokenRejected(input: unknown): boolean { + return ( + typeof input === 'object' && + input !== null && + (input as Record)[UDS_INLINE_TOKEN_REJECTED_KEY] === true + ) +} + function findTeammateColor( appState: { teamContext?: { teammates: { [id: string]: { color?: string } } } @@ -541,15 +578,19 @@ export const SendMessageTool: Tool = }, backfillObservableInput(input) { - if ('type' in input) return if (typeof input.to !== 'string') return + markAndRedactInlineUdsToken( + input as { to: string } & Record, + ) + if ('type' in input) return + if (input.to === '*') { input.type = 'broadcast' if (typeof input.message === 'string') input.content = input.message } else if (typeof input.message === 'string') { input.type = 'message' - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) input.content = input.message } else if (typeof input.message === 'object' && input.message !== null) { const msg = input.message as { @@ -560,7 +601,7 @@ export const SendMessageTool: Tool = feedback?: string } input.type = msg.type - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) if (msg.request_id !== undefined) input.request_id = msg.request_id if (msg.approve !== undefined) input.approve = msg.approve const content = msg.reason ?? msg.feedback @@ -569,16 +610,17 @@ export const SendMessageTool: Tool = }, toAutoClassifierInput(input) { + const recipient = recipientForDisplay(input.to) if (typeof input.message === 'string') { - return `to ${input.to}: ${input.message}` + return `to ${recipient}: ${input.message}` } switch (input.message.type) { case 'shutdown_request': - return `shutdown_request to ${input.to}` + return `shutdown_request to ${recipient}` case 'shutdown_response': return `shutdown_response ${input.message.approve ? 'approve' : 'reject'} ${input.message.request_id}` case 'plan_approval_response': - return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${input.to}` + return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${recipient}` } }, @@ -630,6 +672,19 @@ export const SendMessageTool: Tool = errorCode: 9, } } + if (feature('UDS_INBOX')) { + if ( + addr.scheme === 'uds' && + (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) + ) { + return { + result: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + errorCode: 9, + } + } + } if (input.to.includes('@')) { return { result: false, @@ -787,6 +842,16 @@ export const SendMessageTool: Tool = } } if (addr.scheme === 'uds') { + const recipient = recipientForDisplay(input.to) + if (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) { + return { + data: { + success: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + }, + } + } /* eslint-disable @typescript-eslint/no-require-imports */ const { sendToUdsSocket } = require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') @@ -797,14 +862,14 @@ export const SendMessageTool: Tool = return { data: { success: true, - message: `”${preview}” → ${input.to}`, + message: `”${preview}” → ${recipient}`, }, } } catch (e) { return { data: { success: false, - message: `Failed to send to ${input.to}: ${errorMessage(e)}`, + message: `Failed to send to ${recipient}: ${errorMessage(e)}`, }, } } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts new file mode 100644 index 000000000..20124b6c3 --- /dev/null +++ b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts @@ -0,0 +1,41 @@ +import { describe, expect, mock, test } from 'bun:test' + +mock.module('bun:bundle', () => ({ + feature: (name: string) => name === 'UDS_INBOX', +})) + +describe('SendMessageTool UDS recipient handling', () => { + test('redacts inline UDS tokens before classifier and observable paths', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const tokenAddress = 'uds:/tmp/peer.sock#token=secret-token' + + const observableInput = { + to: tokenAddress, + message: 'hello', + } as Record + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(JSON.stringify(observableInput)).not.toContain('secret-token') + expect( + SendMessageTool.toAutoClassifierInput({ + to: tokenAddress, + message: 'hello', + }), + ).toBe('to uds:/tmp/peer.sock: hello') + }) + + test('rejects inline UDS tokens during validation', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const result = await SendMessageTool.validateInput!( + { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + }, + {} as never, + ) + + expect(result.result).toBe(false) + expect(JSON.stringify(result)).not.toContain('secret-token') + }) +}) diff --git a/src/cli/print.ts b/src/cli/print.ts index eb2b543f8..7a291fb50 100644 --- a/src/cli/print.ts +++ b/src/cli/print.ts @@ -2763,13 +2763,37 @@ function runHeadlessStreaming( // when a message arrives via the UDS socket in headless mode. if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ - const { setOnEnqueue } = require('../utils/udsMessaging.js') + const { drainInbox, setOnEnqueue } = + require('../utils/udsMessaging.js') as typeof import('../utils/udsMessaging.js') /* eslint-enable @typescript-eslint/no-require-imports */ + + const enqueueUdsInboxMessages = (): boolean => { + const entries = drainInbox() + for (const entry of entries) { + const value = + typeof entry.message.data === 'string' + ? entry.message.data + : jsonStringify(entry.message) + enqueue({ + mode: 'prompt', + value, + uuid: randomUUID(), + }) + } + return entries.length > 0 + } + setOnEnqueue(() => { if (!inputClosed) { - void run() + if (enqueueUdsInboxMessages()) { + void run() + } } }) + + if (enqueueUdsInboxMessages()) { + void run() + } } // Cron scheduler: runs scheduled_tasks.json tasks in SDK/-p mode. diff --git a/src/commands/peers/peers.ts b/src/commands/peers/peers.ts index aed37d327..fcb7e17a7 100644 --- a/src/commands/peers/peers.ts +++ b/src/commands/peers/peers.ts @@ -1,6 +1,9 @@ import type { LocalCommandCall } from '../../types/command.js' import { listPeers, isPeerAlive } from '../../utils/udsClient.js' -import { getUdsMessagingSocketPath } from '../../utils/udsMessaging.js' +import { + formatUdsAddress, + getUdsMessagingSocketPath, +} from '../../utils/udsMessaging.js' export const call: LocalCommandCall = async (_args, _context) => { const mySocket = getUdsMessagingSocketPath() @@ -29,11 +32,11 @@ export const call: LocalCommandCall = async (_args, _context) => { ? ` started: ${formatAge(peer.startedAt)}` : '' - lines.push( - ` [${status}] PID ${peer.pid} (${label})${cwd}${age}`, - ) + lines.push(` [${status}] PID ${peer.pid} (${label})${cwd}${age}`) if (peer.messagingSocketPath) { - lines.push(` socket: ${peer.messagingSocketPath}`) + lines.push( + ` socket: ${formatUdsAddress(peer.messagingSocketPath)}`, + ) } if (peer.sessionId) { lines.push(` session: ${peer.sessionId}`) @@ -43,7 +46,7 @@ export const call: LocalCommandCall = async (_args, _context) => { lines.push('') lines.push( - 'To message a peer: use SendMessage with to="uds:"', + 'To message a peer: use SendMessage with the shown uds: address', ) return { type: 'text', value: lines.join('\n') } diff --git a/src/services/AgentSummary/__tests__/summaryContext.test.ts b/src/services/AgentSummary/__tests__/summaryContext.test.ts new file mode 100644 index 000000000..3ffa55964 --- /dev/null +++ b/src/services/AgentSummary/__tests__/summaryContext.test.ts @@ -0,0 +1,132 @@ +import { describe, expect, test } from 'bun:test' +import type { Message } from '../../../types/message.js' +import { + getSummaryContextFingerprint, + selectSummaryContextMessages, +} from '../summaryContext.js' + +function makeMessage( + type: 'user' | 'assistant', + uuid: string, + content: string, +): Message { + return { + type, + uuid, + message: { + role: type, + content, + }, + } as unknown as Message +} + +describe('selectSummaryContextMessages', () => { + test('keeps a bounded recent suffix that starts with a user message', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + makeMessage('user', 'u2', 'second prompt'), + makeMessage('assistant', 'a2', 'second response'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2', 'a2']) + }) + + test('returns no context when the newest message exceeds the byte budget', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'x'.repeat(100)), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 10, + }) + + expect(selected).toEqual([]) + }) + + test('uses serialized message size for nested content budgets', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + { + ...makeMessage('assistant', 'a1', 'short'), + nested: { + payload: Array.from({ length: 50 }, (_value, index) => ({ + index, + text: 'x'.repeat(20), + })), + }, + } as unknown as Message, + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 200, + }) + + expect(selected).toEqual([]) + }) + + test('drops leading orphan tool results after bounding', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'tool_result', tool_use_id: 'tool-1', content: 'ok' }, + ], + }, + } as unknown as Message, + makeMessage('assistant', 'a1', 'after orphan'), + makeMessage('user', 'u2', 'next prompt'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2']) + }) +}) + +describe('getSummaryContextFingerprint', () => { + test('changes when the transcript grows', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ] + + const first = getSummaryContextFingerprint(messages) + const second = getSummaryContextFingerprint([ + ...messages, + makeMessage('user', 'u2', 'next prompt'), + ]) + expect(first?.startsWith('2:a1:')).toBe(true) + expect(second?.startsWith('3:u2:')).toBe(true) + expect(first).not.toBe(second) + }) + + test('changes when message content changes under the same uuid', () => { + const first = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ]) + const second = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'updated response'), + ]) + + expect(first).not.toBe(second) + }) +}) diff --git a/src/services/AgentSummary/agentSummary.ts b/src/services/AgentSummary/agentSummary.ts index 50146b3c7..2232e839d 100644 --- a/src/services/AgentSummary/agentSummary.ts +++ b/src/services/AgentSummary/agentSummary.ts @@ -23,6 +23,10 @@ import { import { logError } from '../../utils/log.js' import { createUserMessage } from '../../utils/messages.js' import { getAgentTranscript } from '../../utils/sessionStorage.js' +import { + getSummaryContextFingerprint, + selectSummaryContextMessages, +} from './summaryContext.js' const SUMMARY_INTERVAL_MS = 30_000 @@ -58,6 +62,7 @@ export function startAgentSummarization( let timeoutId: ReturnType | null = null let stopped = false let previousSummary: string | null = null + let lastHandledTranscriptFingerprint: string | null = null async function runSummary(): Promise { if (stopped) return @@ -82,15 +87,35 @@ export function startAgentSummarization( // Filter to clean message state const cleanMessages = filterIncompleteToolCalls(transcript.messages) + const summaryContext = filterIncompleteToolCalls( + selectSummaryContextMessages(cleanMessages), + ) + const transcriptFingerprint = getSummaryContextFingerprint(summaryContext) + if ( + transcriptFingerprint && + transcriptFingerprint === lastHandledTranscriptFingerprint + ) { + logForDebugging( + `[AgentSummary] Skipping summary for ${taskId}: transcript unchanged`, + ) + return + } + + if (summaryContext.length < 3) { + logForDebugging( + `[AgentSummary] Skipping summary for ${taskId}: no bounded context available`, + ) + return + } // Build fork params with current messages const forkParams: CacheSafeParams = { ...baseParams, - forkContextMessages: cleanMessages, + forkContextMessages: summaryContext, } logForDebugging( - `[AgentSummary] Forking for summary, ${cleanMessages.length} messages in context`, + `[AgentSummary] Forking for summary, ${summaryContext.length} messages in context`, ) // Create abort controller for this summary @@ -136,13 +161,16 @@ export function startAgentSummarization( ) continue } - const contentArr = Array.isArray(msg.message!.content) ? msg.message!.content : [] + const contentArr = Array.isArray(msg.message!.content) + ? msg.message!.content + : [] const textBlock = contentArr.find(b => b.type === 'text') if (textBlock?.type === 'text' && textBlock.text.trim()) { const summaryText = textBlock.text.trim() logForDebugging( `[AgentSummary] Summary result for ${taskId}: ${summaryText}`, ) + lastHandledTranscriptFingerprint = transcriptFingerprint previousSummary = summaryText updateAgentSummary(taskId, summaryText, setAppState) break diff --git a/src/services/AgentSummary/summaryContext.ts b/src/services/AgentSummary/summaryContext.ts new file mode 100644 index 000000000..4d9f6a6ce --- /dev/null +++ b/src/services/AgentSummary/summaryContext.ts @@ -0,0 +1,175 @@ +import { createHash } from 'crypto' +import type { Message } from '../../types/message.js' + +export const MAX_SUMMARY_CONTEXT_MESSAGES = 120 +export const MAX_SUMMARY_CONTEXT_CHARS = 200_000 + +function estimateJsonChars( + value: unknown, + limit: number, + seen = new Set(), +): number { + if (value === null) return 4 + switch (typeof value) { + case 'string': + return value.length + 2 + case 'number': + case 'boolean': + return String(value).length + case 'undefined': + case 'function': + case 'symbol': + return 0 + case 'object': { + if (seen.has(value)) return Number.POSITIVE_INFINITY + seen.add(value) + let total = 2 + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + total += String(index).length + 3 + total += estimateJsonChars(value[index], limit - total, seen) + if (total > limit) return total + } + } else { + const record = value as Record + for (const key in record) { + if (!Object.hasOwn(record, key)) continue + total += key.length + 3 + total += estimateJsonChars(record[key], limit - total, seen) + if (total > limit) return total + } + } + seen.delete(value) + return total + } + } + return 0 +} + +function updateFingerprintHash( + hash: ReturnType, + value: unknown, + limit: { remaining: number }, + seen = new Set(), +): void { + if (limit.remaining <= 0) return + if (value === null || typeof value !== 'object') { + const text = String(value) + hash.update(typeof value) + hash.update(':') + hash.update(text.slice(0, limit.remaining)) + limit.remaining -= text.length + return + } + if (seen.has(value)) { + hash.update('[Circular]') + return + } + seen.add(value) + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + if (limit.remaining <= 0) break + const key = String(index) + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, value[index], limit, seen) + } + } else { + const record = value as Record + for (const key in record) { + if (limit.remaining <= 0) break + if (!Object.hasOwn(record, key)) continue + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, record[key], limit, seen) + } + } + seen.delete(value) +} + +export function estimateMessageChars( + message: Message, + limit = Number.POSITIVE_INFINITY, +): number { + const estimated = estimateJsonChars(message, limit) + if (!Number.isFinite(estimated)) { + return Number.POSITIVE_INFINITY + } + return estimated +} + +function hasToolResultBlock(message: Message): boolean { + if (message.type !== 'user') return false + const content = message.message?.content + return ( + Array.isArray(content) && + content.some(block => { + return Boolean( + block && + typeof block === 'object' && + 'type' in block && + block.type === 'tool_result', + ) + }) + ) +} + +export function getSummaryContextFingerprint( + messages: Message[], +): string | null { + const lastMessage = messages.at(-1) + if (!lastMessage) return null + const hash = createHash('sha256') + updateFingerprintHash(hash, messages, { + remaining: MAX_SUMMARY_CONTEXT_CHARS, + }) + return `${messages.length}:${lastMessage.uuid}:${hash.digest('hex').slice(0, 16)}` +} + +export function selectSummaryContextMessages( + messages: Message[], + limits: { + maxMessages?: number + maxChars?: number + } = {}, +): Message[] { + const maxMessages = limits.maxMessages ?? MAX_SUMMARY_CONTEXT_MESSAGES + const maxChars = limits.maxChars ?? MAX_SUMMARY_CONTEXT_CHARS + if (maxMessages <= 0 || maxChars <= 0) return [] + + const selected: Message[] = [] + let selectedChars = 0 + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i] + if (!message) continue + + const messageChars = estimateMessageChars(message, maxChars - selectedChars) + if (messageChars > maxChars) { + if (selected.length === 0) return [] + break + } + + if ( + selected.length >= maxMessages || + selectedChars + messageChars > maxChars + ) { + break + } + + selected.unshift(message) + selectedChars += messageChars + } + + while (selected.length > 0) { + const first = selected[0] + if (!first) break + if (first.type !== 'user' || hasToolResultBlock(first)) { + selected.shift() + continue + } + break + } + + return selected +} diff --git a/src/utils/__tests__/ndjsonFramer.test.ts b/src/utils/__tests__/ndjsonFramer.test.ts new file mode 100644 index 000000000..35174162a --- /dev/null +++ b/src/utils/__tests__/ndjsonFramer.test.ts @@ -0,0 +1,91 @@ +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' +import { describe, expect, test } from 'bun:test' +import { attachNdjsonFramer } from '../ndjsonFramer.js' + +type TestSocket = Socket & { + destroyed: boolean + emitData: (chunk: Buffer) => void +} + +function createTestSocket(): TestSocket { + const emitter = new EventEmitter() as TestSocket + emitter.destroyed = false + emitter.destroy = ((_error?: Error) => { + emitter.destroyed = true + emitter.emit('close') + return emitter + }) as TestSocket['destroy'] + emitter.emitData = (chunk: Buffer) => { + emitter.emit('data', chunk) + } + return emitter +} + +describe('attachNdjsonFramer', () => { + test('accepts a complete frame at the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: Buffer.byteLength('{"a":1}', 'utf8'), + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"a":1}\n')) + + expect(messages).toEqual([{ a: 1 }]) + expect(errors).toEqual([]) + expect(socket.destroyed).toBe(false) + }) + + test('destroys a complete frame over the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"long":true}\n')) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) + + test('destroys oversized no-newline input before a frame can form', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('x'.repeat(9))) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) +}) diff --git a/src/utils/__tests__/teammateMailbox.test.ts b/src/utils/__tests__/teammateMailbox.test.ts new file mode 100644 index 000000000..577c4331f --- /dev/null +++ b/src/utils/__tests__/teammateMailbox.test.ts @@ -0,0 +1,310 @@ +import { afterEach, beforeEach, describe, expect, test } from 'bun:test' +import { mkdir, readFile, rm, writeFile } from 'node:fs/promises' +import { mkdtempSync } from 'node:fs' +import { tmpdir } from 'node:os' +import { dirname, join } from 'node:path' +import { + compactMailboxMessages, + getInboxPath, + markMessageAsReadByIndex, + markMessageAsReadByIdentity, + markMessagesAsRead, + markMessagesAsReadByPredicate, + MAX_MAILBOX_MESSAGE_TEXT_BYTES, + MAX_MAILBOX_MESSAGES, + MAX_READ_MAILBOX_MESSAGES, + MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES, + readMailbox, + type TeammateMessage, + writeToMailbox, +} from '../teammateMailbox.js' + +let tempHome = '' +let previousConfigDir: string | undefined + +function message( + text: string, + read: boolean, + timestamp = new Date(0).toISOString(), +): TeammateMessage { + return { + from: 'team-lead', + text, + timestamp, + read, + } +} + +async function seedMailbox( + agentName: string, + teamName: string, + messages: TeammateMessage[], +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify(messages, null, 2), 'utf-8') +} + +async function readRawMailbox( + agentName: string, + teamName: string, +): Promise { + const content = await readFile(getInboxPath(agentName, teamName), 'utf-8') + return JSON.parse(content) as TeammateMessage[] +} + +beforeEach(() => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) + process.env.CLAUDE_CONFIG_DIR = tempHome +}) + +afterEach(async () => { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) +}) + +describe('compactMailboxMessages', () => { + test('prioritizes unread messages and keeps only recent read history', () => { + const compacted = compactMailboxMessages( + [ + message('read-1', true), + message('read-2', true), + message('unread-1', false), + message('read-3', true), + message('unread-2', false), + message('read-4', true), + message('read-5', true), + message('unread-3', false), + ], + { maxMessages: 5, maxReadMessages: 2 }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + 'unread-1', + 'unread-2', + 'read-4', + 'read-5', + 'unread-3', + ]) + }) + + test('retains unread protocol messages separately from regular cap', () => { + const protocol = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + const compacted = compactMailboxMessages( + [ + protocol, + ...Array.from({ length: 5 }, (_value, index) => + message(`regular-${index}`, false), + ), + ], + { + maxMessages: 2, + maxReadMessages: 0, + maxUnreadProtocolMessages: 1, + }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + protocol.text, + 'regular-3', + 'regular-4', + ]) + }) + + test('caps unread protocol messages with an independent bound', () => { + const compacted = compactMailboxMessages( + Array.from( + { length: MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES + 1 }, + (_value, index) => + message( + JSON.stringify({ + type: 'permission_response', + request_id: `req-${index}`, + }), + false, + ), + ), + ) + + expect(compacted).toHaveLength(MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES) + expect(compacted[0]?.text).toContain('req-1') + }) + + test('keeps retained mailbox bytes under an explicit budget', () => { + const compacted = compactMailboxMessages( + Array.from({ length: 20 }, (_value, index) => + message(`msg-${index}-${'x'.repeat(200)}`, false), + ), + { + maxMessages: 20, + maxReadMessages: 0, + maxRetainedBytes: 1_000, + }, + ) + + expect( + Buffer.byteLength(JSON.stringify(compacted), 'utf8'), + ).toBeLessThanOrEqual(1_000) + expect(compacted.length).toBeLessThan(20) + expect(compacted.at(-1)?.text).toContain('msg-19') + }) +}) + +describe('teammate mailbox retention', () => { + test('writeToMailbox compacts oversized unread inbox files', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`old-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(1).toISOString(), + }, + 'alpha', + ) + + const after = await readMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after[0]?.text).toBe('old-21') + expect(after.at(-1)?.text).toBe('newest') + }) + + test('markMessagesAsRead compacts read history after consumption', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessagesAsRead('worker', 'alpha') + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_READ_MAILBOX_MESSAGES) + expect(after.every(m => m.read)).toBe(true) + expect(after[0]?.text).toBe( + `msg-${MAX_MAILBOX_MESSAGES + 20 - MAX_READ_MAILBOX_MESSAGES}`, + ) + }) + + test('markMessagesAsReadByPredicate leaves structured messages unread', async () => { + await seedMailbox('worker', 'alpha', [ + message('plain', false), + message(JSON.stringify({ type: 'permission_request' }), false), + ]) + + await markMessagesAsReadByPredicate( + 'worker', + m => !m.text.includes('permission_request'), + 'alpha', + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(after.map(m => m.read)).toEqual([true, false]) + }) + + test('markMessageAsReadByIdentity survives compaction shifting indexes', async () => { + const permissionResponse = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + await seedMailbox('worker', 'alpha', [ + permissionResponse, + ...Array.from({ length: MAX_MAILBOX_MESSAGES + 20 }, (_value, index) => + message(`regular-${index}`, false), + ), + ]) + + await writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(2).toISOString(), + }, + 'alpha', + ) + const marked = await markMessageAsReadByIdentity( + 'worker', + 'alpha', + permissionResponse, + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(marked).toBe(true) + expect(after.some(m => m.text === permissionResponse.text && !m.read)).toBe( + false, + ) + }) + + test('markMessageAsReadByIndex also compacts through the compatibility path', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 10 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessageAsReadByIndex('worker', 'alpha', existing.length - 1) + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after.some(m => m.text === `msg-${existing.length - 1}`)).toBe(false) + expect(after.at(-1)?.text).toBe(`msg-${existing.length - 2}`) + }) + + test('writeToMailbox rejects oversized message text instead of storing it', async () => { + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'x'.repeat(MAX_MAILBOX_MESSAGE_TEXT_BYTES + 1), + timestamp: new Date(3).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow('Mailbox message text exceeds') + + expect(await readRawMailbox('worker', 'alpha')).toEqual([]) + }) + + test('writeToMailbox fails closed when an existing mailbox is corrupt', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'new', + timestamp: new Date(4).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow() + + expect(await readFile(inboxPath, 'utf-8')).toBe('{not-json') + }) + + test('readMailbox fails closed on corrupt mailbox content', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow() + }) +}) diff --git a/src/utils/__tests__/udsMessaging.test.ts b/src/utils/__tests__/udsMessaging.test.ts new file mode 100644 index 000000000..52cb57abc --- /dev/null +++ b/src/utils/__tests__/udsMessaging.test.ts @@ -0,0 +1,305 @@ +import { afterEach, describe, expect, test } from 'bun:test' +import { chmod, mkdir, rm, stat, symlink, unlink } from 'node:fs/promises' +import { createConnection, createServer } from 'node:net' +import { dirname, join } from 'node:path' +import { tmpdir } from 'node:os' +import { + drainInbox, + MAX_UDS_INBOX_ENTRIES, + MAX_UDS_INBOX_BYTES, + MAX_UDS_FRAME_BYTES, + parseUdsTarget, + sendUdsMessage, + setOnEnqueue, + startUdsMessaging, + stopUdsMessaging, +} from '../udsMessaging.js' + +function socketPath(label: string): string { + const suffix = `${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}-${label}` + if (process.platform === 'win32') { + return `\\\\.\\pipe\\claude-code-test-${suffix}` + } + return join(tmpdir(), 'claude-code-test', `${suffix}.sock`) +} + +function sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)) +} + +async function waitForEnqueues( + expected: number, + sendMessages: () => Promise, +): Promise { + let count = 0 + let resolveDone: (() => void) | undefined + const done = new Promise(resolve => { + resolveDone = resolve + }) + + setOnEnqueue(() => { + count++ + if (count >= expected) resolveDone?.() + }) + + await sendMessages() + await Promise.race([ + done, + sleep(5_000).then(() => { + throw new Error(`Timed out waiting for ${expected} UDS enqueues`) + }), + ]) + setOnEnqueue(null) +} + +afterEach(async () => { + setOnEnqueue(null) + drainInbox() + await stopUdsMessaging() +}) + +async function closeServer(server: ReturnType): Promise { + await new Promise(resolve => { + server.close(() => resolve()) + }) +} + +describe('UDS inbox retention', () => { + test('drainInbox returns each pending socket message once', async () => { + const path = socketPath('drain') + await startUdsMessaging(path, { isExplicit: true }) + expect(process.env.CLAUDE_CODE_MESSAGING_TOKEN).toBeUndefined() + + await waitForEnqueues(2, async () => { + await sendUdsMessage(path, { type: 'text', data: 'one' }) + await sendUdsMessage(path, { type: 'text', data: 'two' }) + }) + + const drained = drainInbox() + expect(drained.map(entry => entry.message.data)).toEqual(['one', 'two']) + expect(drained.every(entry => entry.status === 'processed')).toBe(true) + expect(drainInbox()).toEqual([]) + }) + + test('inbox is capped when messages arrive faster than they are drained', async () => { + const path = socketPath('cap') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(MAX_UDS_INBOX_ENTRIES, async () => { + for (let i = 0; i < MAX_UDS_INBOX_ENTRIES; i++) { + await sendUdsMessage(path, { type: 'text', data: String(i) }) + } + }) + await expect( + sendUdsMessage(path, { type: 'text', data: 'overflow' }), + ).rejects.toThrow('inbox full') + + const drained = drainInbox() + expect(drained).toHaveLength(MAX_UDS_INBOX_ENTRIES) + expect(drained[0]?.message.data).toBe('0') + expect(drained.at(-1)?.message.data).toBe(String(MAX_UDS_INBOX_ENTRIES - 1)) + }) + + test('inbox is capped by retained bytes before entry count', async () => { + const path = socketPath('byte-cap') + await startUdsMessaging(path, { isExplicit: true }) + + const payload = 'x'.repeat(32 * 1024) + let accepted = 0 + for (;;) { + try { + await sendUdsMessage(path, { type: 'text', data: payload }) + accepted++ + if (accepted > MAX_UDS_INBOX_BYTES / payload.length + 20) { + throw new Error('byte cap was not enforced') + } + } catch (error) { + expect(error).toBeInstanceOf(Error) + expect((error as Error).message).toContain('inbox full') + break + } + } + + const drained = drainInbox() + expect(drained.length).toBe(accepted) + expect(drained.length).toBeLessThan(MAX_UDS_INBOX_ENTRIES) + }) + + test('ping replies with pong without enqueueing inbox work', async () => { + const path = socketPath('ping') + await startUdsMessaging(path, { isExplicit: true }) + + await sendUdsMessage(path, { type: 'ping' }) + expect(drainInbox()).toEqual([]) + }) + + test('drained entries never expose the UDS auth token', async () => { + const path = socketPath('strip-token') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(1, async () => { + await sendUdsMessage(path, { + type: 'notification', + meta: { keep: 'visible' }, + }) + }) + + const drained = drainInbox() + expect(drained).toHaveLength(1) + expect(drained[0]?.message.meta).toEqual({ keep: 'visible' }) + expect(drained[0]?.message.meta).not.toHaveProperty('authToken') + }) + + test('rejects unauthenticated socket messages', async () => { + const path = socketPath('auth') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + const conn = createConnection(path, () => { + conn.write(`${JSON.stringify({ type: 'text', data: 'bad' })}\n`) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for auth rejection')) + }) + conn.on('data', chunk => { + const text = chunk.toString('utf-8') + if (text.includes('\n')) { + conn.end() + resolve(text) + } + }) + conn.on('error', reject) + }) + + expect(JSON.parse(response).type).toBe('error') + expect(drainInbox()).toEqual([]) + }) + + test('destroys oversized frames before enqueueing inbox work', async () => { + const path = socketPath('oversized') + await startUdsMessaging(path, { isExplicit: true }) + + await new Promise((resolve, reject) => { + const conn = createConnection(path, () => { + conn.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for oversized frame close')) + }) + conn.on('close', () => resolve()) + conn.on('error', () => resolve()) + }) + + expect(drainInbox()).toEqual([]) + }) + + test('rejects oversized receiver responses before retaining them', async () => { + const path = socketPath('oversized-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('UDS response frame exceeded size limit') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects inline auth token UDS targets instead of parsing them', async () => { + const path = socketPath('inline-token') + + const targetWithToken = `${path}#token=secret` + expect(() => parseUdsTarget(targetWithToken)).toThrow('inline auth token') + try { + parseUdsTarget(targetWithToken) + } catch (error) { + expect((error as Error).message).not.toContain('secret') + } + + const { sendToUdsSocket } = await import('../udsClient.js') + await expect(sendToUdsSocket(targetWithToken, 'hello')).rejects.toThrow( + 'inline auth token', + ) + }) + + if (process.platform !== 'win32') { + test('creates the listening socket with owner-only permissions', async () => { + const path = socketPath('socket-mode') + await startUdsMessaging(path, { isExplicit: true }) + + const mode = (await stat(path)).mode & 0o777 + expect(mode).toBe(0o600) + }) + + test('fails closed when the capability directory is not private', async () => { + const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + process.env.CLAUDE_CONFIG_DIR = tempHome + const capabilityDir = join(tempHome, 'messaging-capabilities') + await mkdir(capabilityDir, { recursive: true, mode: 0o755 }) + await chmod(capabilityDir, 0o755) + + try { + await expect( + startUdsMessaging(socketPath('broad-capdir'), { isExplicit: true }), + ).rejects.toThrow('permissions are too broad') + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + + test('fails closed when the capability directory is a symlink', async () => { + const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-link-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + const target = join(tempHome, 'target') + process.env.CLAUDE_CONFIG_DIR = tempHome + await mkdir(target, { recursive: true, mode: 0o700 }) + await symlink(target, join(tempHome, 'messaging-capabilities'), 'dir') + + try { + await expect( + startUdsMessaging(socketPath('symlink-capdir'), { isExplicit: true }), + ).rejects.toThrow('not a private directory') + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + } +}) diff --git a/src/utils/messages/systemInit.ts b/src/utils/messages/systemInit.ts index fcb9e74d1..4585c7817 100644 --- a/src/utils/messages/systemInit.ts +++ b/src/utils/messages/systemInit.ts @@ -87,8 +87,10 @@ export function buildSystemInitMessage(inputs: SystemInitInputs): SDKMessage { // Hidden from public SDK types — ant-only UDS messaging socket path if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ + const udsMessaging = + require('../udsMessaging.js') as typeof import('../udsMessaging.js') ;(initMessage as Record).messaging_socket_path = - require('../udsMessaging.js').getUdsMessagingSocketPath() + udsMessaging.getUdsMessagingSocketPath() /* eslint-enable @typescript-eslint/no-require-imports */ } initMessage.fast_mode_state = getFastModeState(inputs.model, inputs.fastMode) diff --git a/src/utils/ndjsonFramer.ts b/src/utils/ndjsonFramer.ts index 968ee5217..ecaa04dc8 100644 --- a/src/utils/ndjsonFramer.ts +++ b/src/utils/ndjsonFramer.ts @@ -7,6 +7,11 @@ */ import type { Socket } from 'net' +export type NdjsonFramerOptions = { + maxFrameBytes?: number + onFrameError?: (error: Error) => void +} + /** * Attach an NDJSON framer to a socket. Calls `onMessage` for each * complete JSON line received. Malformed lines are silently skipped. @@ -19,21 +24,54 @@ export function attachNdjsonFramer( socket: Socket, onMessage: (msg: T) => void, parse: (text: string) => T = text => JSON.parse(text) as T, + options: NdjsonFramerOptions = {}, ): void { let buffer = '' + const maxFrameBytes = options.maxFrameBytes ?? Number.POSITIVE_INFINITY + + const rejectOversizedFrame = (bytes: number): void => { + const error = new Error( + `NDJSON frame exceeded ${maxFrameBytes} bytes (${bytes})`, + ) + options.onFrameError?.(error) + socket.destroy(error) + } socket.on('data', (chunk: Buffer) => { + if ( + Number.isFinite(maxFrameBytes) && + !chunk.includes(0x0a) && + Buffer.byteLength(buffer, 'utf8') + chunk.byteLength > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8') + chunk.byteLength) + return + } + buffer += chunk.toString() const lines = buffer.split('\n') buffer = lines.pop() ?? '' for (const line of lines) { if (!line.trim()) continue + if ( + Number.isFinite(maxFrameBytes) && + Buffer.byteLength(line, 'utf8') > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(line, 'utf8')) + return + } try { onMessage(parse(line)) } catch { // Malformed JSON — skip } } + + if ( + Number.isFinite(maxFrameBytes) && + Buffer.byteLength(buffer, 'utf8') > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8')) + } }) } diff --git a/src/utils/swarm/inProcessRunner.ts b/src/utils/swarm/inProcessRunner.ts index 1735500b4..eaab58ef7 100644 --- a/src/utils/swarm/inProcessRunner.ts +++ b/src/utils/swarm/inProcessRunner.ts @@ -97,7 +97,7 @@ import { getLastPeerDmSummary, isPermissionResponse, isShutdownRequest, - markMessageAsReadByIndex, + markMessageAsReadByIdentity, readMailbox, writeToMailbox, } from '../teammateMailbox.js' @@ -405,10 +405,10 @@ function createInProcessCanUseTool( if (msg && !msg.read) { const parsed = isPermissionResponse(msg.text) if (parsed && parsed.request_id === request.id) { - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - i, + msg, ) if (parsed.subtype === 'success') { processMailboxPermissionResponse({ @@ -801,10 +801,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received shutdown request from ${shutdownParsed?.from} (prioritized over ${skippedUnread} unread messages)`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - shutdownIndex, + msg, ) return { type: 'shutdown_request', @@ -839,10 +839,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received new message from ${msg.from} (index ${selectedIndex})`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - selectedIndex, + msg, ) return { type: 'new_message', @@ -1246,8 +1246,13 @@ export async function runInProcessTeammate( // Track in-progress tool use IDs for animation in transcript view let inProgressToolUseIDs = task.inProgressToolUseIDs if (message.type === 'assistant') { - for (const block of (Array.isArray(message.message!.content) ? message.message!.content : [])) { - if (typeof block !== 'string' && block.type === 'tool_use') { + for (const block of Array.isArray(message.message!.content) + ? message.message!.content + : []) { + if ( + typeof block !== 'string' && + block.type === 'tool_use' + ) { inProgressToolUseIDs = new Set([ ...(inProgressToolUseIDs ?? []), block.id, @@ -1318,7 +1323,10 @@ export async function runInProcessTeammate( setAppState, ) if (currentAutonomyRunId) { - await markAutonomyRunFailed(currentAutonomyRunId, ERROR_MESSAGE_USER_ABORT) + await markAutonomyRunFailed( + currentAutonomyRunId, + ERROR_MESSAGE_USER_ABORT, + ) currentAutonomyRunId = undefined } } else if (currentAutonomyRunId) { diff --git a/src/utils/teammateMailbox.ts b/src/utils/teammateMailbox.ts index eb72fcc21..6c18fd721 100644 --- a/src/utils/teammateMailbox.ts +++ b/src/utils/teammateMailbox.ts @@ -7,7 +7,8 @@ * Note: Inboxes are keyed by agent name within a team. */ -import { mkdir, readFile, writeFile } from 'fs/promises' +import { randomBytes } from 'crypto' +import { mkdir, readFile, rename, stat, writeFile } from 'fs/promises' import { join } from 'path' import { z } from 'zod/v4' import { TEAMMATE_MESSAGE_TAG } from '../constants/xml.js' @@ -40,6 +41,13 @@ const LOCK_OPTIONS = { }, } +export const MAX_MAILBOX_MESSAGES = 1_000 +export const MAX_READ_MAILBOX_MESSAGES = 200 +export const MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES = 2_000 +export const MAX_MAILBOX_MESSAGE_TEXT_BYTES = 64 * 1024 +export const MAX_MAILBOX_RETAINED_BYTES = 2 * 1024 * 1024 +export const MAX_MAILBOX_FILE_BYTES = 4 * 1024 * 1024 + export type TeammateMessage = { from: string text: string @@ -49,6 +57,218 @@ export type TeammateMessage = { summary?: string // 5-10 word summary shown as preview in the UI } +function isJsonLikeMessage(text: string): boolean { + const trimmed = text.trimStart() + return trimmed.startsWith('{') || trimmed.startsWith('[') +} + +function shouldRetainUnreadAsProtocolMessage( + message: TeammateMessage, +): boolean { + if (message.read) return false + if (isStructuredProtocolMessage(message.text)) return true + if (!isJsonLikeMessage(message.text)) return false + + try { + const parsed = jsonParse(message.text) + return Boolean( + parsed && + typeof parsed === 'object' && + 'type' in (parsed as Record), + ) + } catch { + return true + } +} + +function sameMailboxMessage(a: TeammateMessage, b: TeammateMessage): boolean { + return a.from === b.from && a.timestamp === b.timestamp && a.text === b.text +} + +function mailboxMessageStorageBytes(message: TeammateMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function assertMailboxMessageSize(message: TeammateMessage): void { + const textBytes = Buffer.byteLength(message.text, 'utf8') + if (textBytes > MAX_MAILBOX_MESSAGE_TEXT_BYTES) { + throw new Error( + `Mailbox message text exceeds ${MAX_MAILBOX_MESSAGE_TEXT_BYTES} bytes`, + ) + } +} + +function toMailboxMessage(value: unknown): TeammateMessage { + if (!value || typeof value !== 'object') { + throw new Error('Invalid mailbox message: expected object') + } + const record = value as Record + if ( + typeof record.from !== 'string' || + typeof record.text !== 'string' || + typeof record.timestamp !== 'string' || + typeof record.read !== 'boolean' + ) { + throw new Error('Invalid mailbox message shape') + } + const message: TeammateMessage = { + from: record.from, + text: record.text, + timestamp: record.timestamp, + read: record.read, + ...(typeof record.color === 'string' ? { color: record.color } : {}), + ...(typeof record.summary === 'string' ? { summary: record.summary } : {}), + } + assertMailboxMessageSize(message) + return message +} + +function parseMailboxMessages(content: string): TeammateMessage[] { + const parsed = jsonParse(content) + if (!Array.isArray(parsed)) { + throw new Error('Invalid mailbox file: expected message array') + } + return parsed.map(toMailboxMessage) +} + +async function readMailboxFile(inboxPath: string): Promise { + const info = await stat(inboxPath) + if (info.size > MAX_MAILBOX_FILE_BYTES) { + throw new Error( + `Mailbox file exceeds ${MAX_MAILBOX_FILE_BYTES} bytes: ${inboxPath}`, + ) + } + return readFile(inboxPath, 'utf-8') +} + +async function readMailboxForMutation( + agentName: string, + teamName?: string, +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + return parseMailboxMessages(await readMailboxFile(inboxPath)) +} + +async function writeMailboxAtomic( + inboxPath: string, + content: string, +): Promise { + const bytes = Buffer.byteLength(content, 'utf8') + if (bytes > MAX_MAILBOX_FILE_BYTES) { + throw new Error( + `Compacted mailbox still exceeds ${MAX_MAILBOX_FILE_BYTES} bytes`, + ) + } + const tempPath = `${inboxPath}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` + await writeFile(tempPath, content, 'utf-8') + await rename(tempPath, inboxPath) +} + +export function compactMailboxMessages( + messages: TeammateMessage[], + limits: { + maxMessages?: number + maxReadMessages?: number + maxUnreadProtocolMessages?: number + maxRetainedBytes?: number + } = {}, +): TeammateMessage[] { + const maxMessages = limits.maxMessages ?? MAX_MAILBOX_MESSAGES + const maxReadMessages = limits.maxReadMessages ?? MAX_READ_MAILBOX_MESSAGES + const maxUnreadProtocolMessages = + limits.maxUnreadProtocolMessages ?? MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES + const maxRetainedBytes = limits.maxRetainedBytes ?? MAX_MAILBOX_RETAINED_BYTES + + if ( + maxRetainedBytes <= 0 || + (maxMessages <= 0 && maxUnreadProtocolMessages <= 0) + ) { + return [] + } + + const keepIndexes = new Set() + let retainedBytes = 0 + let keptUnreadProtocolMessages = 0 + const tryKeep = (index: number): boolean => { + if (keepIndexes.has(index)) return true + const message = messages[index] + if (!message) return false + const bytes = mailboxMessageStorageBytes(message) + if (bytes > maxRetainedBytes || retainedBytes + bytes > maxRetainedBytes) { + return false + } + keepIndexes.add(index) + retainedBytes += bytes + return true + } + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i] + if (!message || !shouldRetainUnreadAsProtocolMessage(message)) continue + if (keptUnreadProtocolMessages >= maxUnreadProtocolMessages) continue + if (tryKeep(i)) keptUnreadProtocolMessages++ + } + + let keptNonProtocolMessages = 0 + for (let i = messages.length - 1; i >= 0; i--) { + if (keptNonProtocolMessages >= maxMessages) break + const message = messages[i] + if ( + message && + !message.read && + !shouldRetainUnreadAsProtocolMessage(message) + ) { + if (tryKeep(i)) keptNonProtocolMessages++ + } + } + + let keptReadMessages = 0 + for (let i = messages.length - 1; i >= 0; i--) { + if (keptNonProtocolMessages >= maxMessages) break + if (keptReadMessages >= maxReadMessages) break + const message = messages[i] + if (message?.read) { + if (tryKeep(i)) { + keptReadMessages++ + keptNonProtocolMessages++ + } + } + } + + return messages.filter((_message, index) => keepIndexes.has(index)) +} + +function logUnreadMailboxEvictions( + original: TeammateMessage[], + compacted: TeammateMessage[], + context: string, +): void { + const kept = new Set(compacted) + const unreadEvicted = original.filter(message => { + return !message.read && !kept.has(message) + }) + if (unreadEvicted.length === 0) return + + const protocolEvicted = count(unreadEvicted, message => + shouldRetainUnreadAsProtocolMessage(message), + ) + logError( + new Error( + `[TeammateMailbox] Compacted ${unreadEvicted.length} unread message(s) in ${context}; protocol_or_unknown=${protocolEvicted}`, + ), + ) +} + +async function writeCompactedMailbox( + inboxPath: string, + messages: TeammateMessage[], + context: string, +): Promise { + const compacted = compactMailboxMessages(messages) + logUnreadMailboxEvictions(messages, compacted, context) + await writeMailboxAtomic(inboxPath, jsonStringify(compacted, null, 2)) +} + /** * Get the path to a teammate's inbox file * Structure: ~/.claude/teams/{team_name}/inboxes/{agent_name}.json @@ -89,8 +309,7 @@ export async function readMailbox( logForDebugging(`[TeammateMailbox] readMailbox: path=${inboxPath}`) try { - const content = await readFile(inboxPath, 'utf-8') - const messages = jsonParse(content) as TeammateMessage[] + const messages = parseMailboxMessages(await readMailboxFile(inboxPath)) logForDebugging( `[TeammateMailbox] readMailbox: read ${messages.length} message(s)`, ) @@ -103,7 +322,7 @@ export async function readMailbox( } logForDebugging(`Failed to read inbox for ${agentName}: ${error}`) logError(error) - return [] + throw error } } @@ -156,7 +375,7 @@ export async function writeToMailbox( `[TeammateMailbox] writeToMailbox: failed to create inbox file: ${error}`, ) logError(error) - return + throw error } } @@ -168,22 +387,23 @@ export async function writeToMailbox( }) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(recipientName, teamName) + const messages = await readMailboxForMutation(recipientName, teamName) - const newMessage: TeammateMessage = { + const newMessage = toMailboxMessage({ ...message, read: false, - } + }) messages.push(newMessage) - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'writeToMailbox') logForDebugging( `[TeammateMailbox] Wrote message to ${recipientName}'s inbox from ${message.from}`, ) } catch (error) { logForDebugging(`Failed to write to inbox for ${recipientName}: ${error}`) logError(error) + throw error } finally { if (release) { await release() @@ -222,7 +442,7 @@ export async function markMessageAsReadByIndex( logForDebugging(`[TeammateMailbox] markMessageAsReadByIndex: lock acquired`) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) logForDebugging( `[TeammateMailbox] markMessageAsReadByIndex: read ${messages.length} messages after lock`, ) @@ -244,7 +464,7 @@ export async function markMessageAsReadByIndex( messages[messageIndex] = { ...message, read: true } - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'markMessageAsReadByIndex') logForDebugging( `[TeammateMailbox] markMessageAsReadByIndex: marked message at index ${messageIndex} as read`, ) @@ -270,6 +490,46 @@ export async function markMessageAsReadByIndex( } } +export async function markMessageAsReadByIdentity( + agentName: string, + teamName: string | undefined, + expectedMessage: TeammateMessage, +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + const lockFilePath = `${inboxPath}.lock` + + let release: (() => Promise) | undefined + try { + release = await lockfile.lock(inboxPath, { + lockfilePath: lockFilePath, + ...LOCK_OPTIONS, + }) + + const messages = await readMailboxForMutation(agentName, teamName) + const messageIndex = messages.findIndex(message => { + return !message.read && sameMailboxMessage(message, expectedMessage) + }) + if (messageIndex < 0) return false + + messages[messageIndex] = { ...messages[messageIndex]!, read: true } + await writeCompactedMailbox( + inboxPath, + messages, + 'markMessageAsReadByIdentity', + ) + return true + } catch (error) { + const code = getErrnoCode(error) + if (code === 'ENOENT') return false + logError(error) + return false + } finally { + if (release) { + await release() + } + } +} + /** * Mark all messages in a teammate's inbox as read * Uses file locking to prevent race conditions @@ -297,7 +557,7 @@ export async function markMessagesAsRead( logForDebugging(`[TeammateMailbox] markMessagesAsRead: lock acquired`) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) logForDebugging( `[TeammateMailbox] markMessagesAsRead: read ${messages.length} messages after lock`, ) @@ -317,7 +577,7 @@ export async function markMessagesAsRead( // messages comes from jsonParse — fresh, unshared objects safe to mutate for (const m of messages) m.read = true - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'markMessagesAsRead') logForDebugging( `[TeammateMailbox] markMessagesAsRead: WROTE ${unreadCount} message(s) as read to ${inboxPath}`, ) @@ -1114,7 +1374,7 @@ export async function markMessagesAsReadByPredicate( ...LOCK_OPTIONS, }) - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) if (messages.length === 0) { return } @@ -1123,7 +1383,11 @@ export async function markMessagesAsReadByPredicate( !m.read && predicate(m) ? { ...m, read: true } : m, ) - await writeFile(inboxPath, jsonStringify(updatedMessages, null, 2), 'utf-8') + await writeCompactedMailbox( + inboxPath, + updatedMessages, + 'markMessagesAsReadByPredicate', + ) } catch (error) { const code = getErrnoCode(error) if (code === 'ENOENT') { @@ -1161,7 +1425,12 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { if (!Array.isArray(content)) continue for (const block of content) { if (typeof block === 'string') continue - const b = block as unknown as { type: string; name?: string; input?: Record; [key: string]: unknown } + const b = block as unknown as { + type: string + name?: string + input?: Record + [key: string]: unknown + } if ( b.type === 'tool_use' && b.name === SEND_MESSAGE_TOOL_NAME && @@ -1177,7 +1446,7 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { const to = b.input.to as string const summary = 'summary' in b.input && typeof b.input.summary === 'string' - ? b.input.summary as string + ? (b.input.summary as string) : (b.input.message as string).slice(0, 80) return `[to ${to}] ${summary}` } diff --git a/src/utils/udsClient.ts b/src/utils/udsClient.ts index 781f3ddd1..f08bee696 100644 --- a/src/utils/udsClient.ts +++ b/src/utils/udsClient.ts @@ -16,7 +16,7 @@ import { errorMessage, isFsInaccessible } from './errors.js' import { isProcessRunning } from './genericProcessUtils.js' import { jsonParse, jsonStringify } from './slowOperations.js' import type { SessionKind } from './concurrentSessions.js' -import type { UdsMessage } from './udsMessaging.js' +import { MAX_UDS_FRAME_BYTES, type UdsMessage } from './udsMessaging.js' // --------------------------------------------------------------------------- // Types @@ -43,6 +43,12 @@ function getSessionsDir(): string { return join(getClaudeConfigHomeDir(), 'sessions') } +function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength +} + // --------------------------------------------------------------------------- // Discovery // --------------------------------------------------------------------------- @@ -104,9 +110,14 @@ export async function listAllLiveSessions(): Promise { */ export async function listPeers(): Promise { const all = await listAllLiveSessions() - return all.filter( - s => s.pid !== process.pid && s.messagingSocketPath != null, - ) + return all.filter(s => s.pid !== process.pid && s.messagingSocketPath != null) +} + +async function findAuthTokenForSocketPath( + socketPath: string, +): Promise { + const { readUdsCapabilityToken } = await import('./udsMessaging.js') + return readUdsCapabilityToken(socketPath) } // --------------------------------------------------------------------------- @@ -117,10 +128,21 @@ export async function listPeers(): Promise { * Probe a UDS socket to check if a server is listening (ping/pong). * Returns true if the peer responds within the timeout. */ -export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise { - return new Promise((resolve) => { +export async function isPeerAlive( + socketPath: string, + timeoutMs = 3000, + authToken?: string, +): Promise { + const token = authToken ?? (await findAuthTokenForSocketPath(socketPath)) + if (!token) return false + + return new Promise(resolve => { const conn = createConnection(socketPath, () => { - const ping: UdsMessage = { type: 'ping', ts: new Date().toISOString() } + const ping: UdsMessage = { + type: 'ping', + ts: new Date().toISOString(), + meta: { authToken: token }, + } conn.write(jsonStringify(ping) + '\n') }) @@ -135,7 +157,19 @@ export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise }, timeoutMs) let buffer = '' - conn.on('data', (chunk) => { + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + if (!resolved) { + resolved = true + clearTimeout(timer) + conn.destroy() + resolve(false) + } + return + } buffer += chunk.toString() if (buffer.includes('"pong"')) { if (!resolved) { @@ -165,6 +199,13 @@ export async function sendToUdsSocket( targetSocketPath: string, message: string | Record, ): Promise { + const { parseUdsTarget } = await import('./udsMessaging.js') + const target = parseUdsTarget(targetSocketPath) + const authToken = await findAuthTokenForSocketPath(target.socketPath) + if (!authToken) { + throw new Error(`No auth token found for peer at ${target.socketPath}`) + } + const data = typeof message === 'string' ? message : jsonStringify(message) const udsMsg: UdsMessage = { type: 'text', @@ -177,18 +218,59 @@ export async function sendToUdsSocket( udsMsg.from = getUdsMessagingSocketPath() return new Promise((resolve, reject) => { - const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(udsMsg) + '\n', (err) => { - conn.end() - if (err) reject(err) - else resolve() + let buffer = '' + let settled = false + const finish = (error?: Error): void => { + if (settled) return + settled = true + conn.end() + if (error) reject(error) + else resolve() + } + const conn = createConnection(target.socketPath, () => { + udsMsg.meta = { ...udsMsg.meta, authToken } + conn.write(jsonStringify(udsMsg) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', (err) => { - reject(new Error(`Failed to connect to peer at ${targetSocketPath}: ${errorMessage(err)}`)) + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + buffer += chunk.toString() + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + for (const line of lines) { + if (!line.trim()) continue + let response: UdsMessage + try { + response = jsonParse(line) as UdsMessage + } catch { + continue + } + if (response.type === 'response') { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 'UDS receiver rejected message')) + return + } + } + }) + conn.on('error', err => { + finish( + new Error( + `Failed to connect to peer at ${target.socketPath}: ${errorMessage(err)}`, + ), + ) }) conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) } diff --git a/src/utils/udsMessaging.ts b/src/utils/udsMessaging.ts index 1c95ab63c..7efa7fbf6 100644 --- a/src/utils/udsMessaging.ts +++ b/src/utils/udsMessaging.ts @@ -8,14 +8,25 @@ * but can be overridden via --messaging-socket-path. */ +import { createHash, randomBytes } from 'crypto' import { createServer, type Server, type Socket } from 'net' -import { mkdir, unlink } from 'fs/promises' +import { + chmod, + lstat, + mkdir, + open, + readFile, + rename, + unlink, +} from 'fs/promises' import { dirname, join } from 'path' import { tmpdir } from 'os' import { registerCleanup } from './cleanupRegistry.js' import { logForDebugging } from './debug.js' import { errorMessage } from './errors.js' +import { getClaudeConfigHomeDir } from './envUtils.js' import { attachNdjsonFramer } from './ndjsonFramer.js' +import { logError } from './log.js' import { jsonParse, jsonStringify } from './slowOperations.js' // --------------------------------------------------------------------------- @@ -27,6 +38,7 @@ export type UdsMessageType = | 'notification' | 'query' | 'response' + | 'error' | 'ping' | 'pong' @@ -60,6 +72,15 @@ let onEnqueueCb: (() => void) | null = null const clients = new Set() const inbox: UdsInboxEntry[] = [] let nextId = 1 +let defaultSocketPath: string | null = null +let authToken: string | null = null +let capabilityFilePath: string | null = null +let inboxBytes = 0 + +export const MAX_UDS_INBOX_ENTRIES = 1_000 +export const MAX_UDS_FRAME_BYTES = 64 * 1024 +export const MAX_UDS_INBOX_BYTES = 2 * 1024 * 1024 +export const MAX_UDS_CLIENTS = 128 // --------------------------------------------------------------------------- // Public API — socket path helpers @@ -74,10 +95,19 @@ let nextId = 1 * transparently, but we use the pipe format on Windows for Node.js compat. */ export function getDefaultUdsSocketPath(): string { + if (defaultSocketPath) return defaultSocketPath + const nonce = randomBytes(16).toString('hex') if (process.platform === 'win32') { - return `\\\\.\\pipe\\claude-code-${process.pid}` + defaultSocketPath = `\\\\.\\pipe\\claude-code-${process.pid}-${nonce}` + return defaultSocketPath } - return join(tmpdir(), 'claude-code-socks', `${process.pid}.sock`) + defaultSocketPath = join( + tmpdir(), + 'claude-code-socks', + `${process.pid}-${nonce}`, + 'messaging.sock', + ) + return defaultSocketPath } /** @@ -88,6 +118,142 @@ export function getUdsMessagingSocketPath(): string | undefined { return socketPath ?? undefined } +export function formatUdsAddress(socket: string): string { + return `uds:${socket}` +} + +export function parseUdsTarget(target: string): { + socketPath: string +} { + if (target.includes('#token=')) { + throw new Error( + 'UDS target must not include an inline auth token; use the ListPeers address', + ) + } + return { socketPath: target } +} + +function getCapabilityDir(): string { + return join(getClaudeConfigHomeDir(), 'messaging-capabilities') +} + +function getCapabilityPath(socket: string): string { + const digest = createHash('sha256').update(socket).digest('hex') + return join(getCapabilityDir(), `${digest}.json`) +} + +function isNotFound(error: unknown): boolean { + return ( + typeof error === 'object' && + error !== null && + (error as NodeJS.ErrnoException).code === 'ENOENT' + ) +} + +async function assertPrivateCapabilityDir(dir: string): Promise { + let stat: Awaited> + try { + stat = await lstat(dir) + } catch (error) { + if (!isNotFound(error)) throw error + await mkdir(dir, { recursive: true, mode: 0o700 }) + stat = await lstat(dir) + } + + if (!stat.isDirectory() || stat.isSymbolicLink()) { + throw new Error( + `[udsMessaging] capability directory is not a private directory: ${dir}`, + ) + } + if (process.platform !== 'win32') { + const broadMode = stat.mode & 0o077 + if (broadMode !== 0) { + throw new Error( + `[udsMessaging] capability directory permissions are too broad: ${dir}`, + ) + } + if (typeof process.getuid === 'function' && stat.uid !== process.getuid()) { + throw new Error( + `[udsMessaging] capability directory owner does not match current user: ${dir}`, + ) + } + } + + await chmod(dir, 0o700) +} + +async function writePrivateFileExclusive( + path: string, + content: string, +): Promise { + const handle = await open(path, 'wx', 0o600) + try { + await handle.writeFile(content, 'utf-8') + } finally { + await handle.close() + } + await chmod(path, 0o600) +} + +async function ensureSocketParent(path: string): Promise { + const dir = dirname(path) + try { + const stat = await lstat(dir) + if (!stat.isDirectory() || stat.isSymbolicLink()) { + throw new Error( + `[udsMessaging] socket parent is not a directory: ${dir}`, + ) + } + return + } catch (error) { + if (!isNotFound(error)) throw error + } + + await mkdir(dir, { recursive: true, mode: 0o700 }) + await chmod(dir, 0o700) +} + +async function writeCapabilityFile( + socket: string, + token: string, +): Promise { + const dir = getCapabilityDir() + await assertPrivateCapabilityDir(dir) + const target = getCapabilityPath(socket) + const temp = `${target}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` + try { + await writePrivateFileExclusive( + temp, + jsonStringify({ socketPath: socket, authToken: token }), + ) + await rename(temp, target) + } catch (error) { + try { + await unlink(temp) + } catch { + // Temp file may not exist if exclusive creation failed. + } + throw error + } + capabilityFilePath = target +} + +export async function readUdsCapabilityToken( + socket: string, +): Promise { + try { + const parsed = jsonParse( + await readFile(getCapabilityPath(socket), 'utf-8'), + ) as Record + if (parsed.socketPath === socket && typeof parsed.authToken === 'string') { + return parsed.authToken + } + } catch { + // Missing or unreadable capability file means the peer is not addressable. + } + return undefined +} + // --------------------------------------------------------------------------- // Inbox // --------------------------------------------------------------------------- @@ -101,16 +267,79 @@ export function setOnEnqueue(cb: (() => void) | null): void { } /** - * Drain all pending inbox messages, marking them processed. + * Drain all pending inbox messages and release retained history. */ export function drainInbox(): UdsInboxEntry[] { - const pending = inbox.filter(e => e.status === 'pending') + const pending = inbox.splice(0, inbox.length) + inboxBytes = 0 for (const entry of pending) { entry.status = 'processed' } return pending } +function getMessageBytes(message: UdsMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function enqueueInboxEntry(entry: UdsInboxEntry): boolean { + const entryBytes = getMessageBytes(entry.message) + if ( + entryBytes > MAX_UDS_FRAME_BYTES || + inbox.length >= MAX_UDS_INBOX_ENTRIES || + inboxBytes + entryBytes > MAX_UDS_INBOX_BYTES + ) { + logError( + new Error( + `[udsMessaging] inbox full (${inbox.length}/${MAX_UDS_INBOX_ENTRIES}, ${inboxBytes}/${MAX_UDS_INBOX_BYTES} bytes); dropping message type=${entry.message.type}`, + ), + ) + return false + } + inbox.push(entry) + inboxBytes += entryBytes + return true +} + +function ensureAuthToken(): string { + if (!authToken) { + authToken = randomBytes(32).toString('hex') + } + return authToken +} + +function getMessageAuthToken(message: UdsMessage): string | undefined { + const token = message.meta?.authToken + return typeof token === 'string' ? token : undefined +} + +function isAuthorizedMessage(message: UdsMessage): boolean { + return getMessageAuthToken(message) === authToken +} + +function writeSocketMessage(socket: Socket, message: UdsMessage): void { + if (socket.destroyed) return + socket.write(jsonStringify(message) + '\n') +} + +function stripAuthToken(message: UdsMessage): UdsMessage { + const { authToken: _authToken, ...metaWithoutAuth } = message.meta ?? {} + return { + ...message, + meta: Object.keys(metaWithoutAuth).length > 0 ? metaWithoutAuth : undefined, + } +} + +function withRequestAuthToken(message: UdsMessage, token: string): UdsMessage { + return { + ...message, + meta: { + ...message.meta, + authToken: token, + }, + } +} + // --------------------------------------------------------------------------- // Server // --------------------------------------------------------------------------- @@ -132,7 +361,7 @@ export async function startUdsMessaging( // Ensure parent directory exists (skip on Windows — pipe paths aren't files) if (process.platform !== 'win32') { - await mkdir(dirname(path), { recursive: true }) + await ensureSocketParent(path) } // Clean up stale socket file (skip on Windows — pipe paths aren't files) @@ -144,69 +373,134 @@ export async function startUdsMessaging( } } - socketPath = path + const token = ensureAuthToken() + try { + await writeCapabilityFile(path, token) + socketPath = path - await new Promise((resolve, reject) => { - const srv = createServer(socket => { - clients.add(socket) - logForDebugging( - `[udsMessaging] client connected (total: ${clients.size})`, - ) - - attachNdjsonFramer( - socket, - msg => { - // Handle ping with automatic pong - if (msg.type === 'ping') { - const pong: UdsMessage = { - type: 'pong', - from: socketPath ?? undefined, - ts: new Date().toISOString(), - } - if (!socket.destroyed) { - socket.write(jsonStringify(pong) + '\n') - } - return - } - - // Enqueue into inbox - const entry: UdsInboxEntry = { - id: `uds-${nextId++}`, - message: msg, - receivedAt: Date.now(), - status: 'pending', - } - inbox.push(entry) + await new Promise((resolve, reject) => { + const srv = createServer(socket => { + if (clients.size >= MAX_UDS_CLIENTS) { logForDebugging( - `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 'unknown'}`, + `[udsMessaging] rejected client: ${clients.size}/${MAX_UDS_CLIENTS} clients already connected`, ) - onEnqueueCb?.() - }, - text => jsonParse(text) as UdsMessage, - ) + socket.destroy() + return + } + clients.add(socket) + logForDebugging( + `[udsMessaging] client connected (total: ${clients.size})`, + ) - socket.on('close', () => { - clients.delete(socket) + attachNdjsonFramer( + socket, + msg => { + if (!isAuthorizedMessage(msg)) { + logForDebugging( + `[udsMessaging] rejected unauthenticated message type=${msg.type}`, + ) + if (!socket.destroyed) { + socket.write( + jsonStringify({ + type: 'error', + data: 'unauthorized', + ts: new Date().toISOString(), + } satisfies UdsMessage) + '\n', + ) + } + return + } + + // Handle ping with automatic pong + if (msg.type === 'ping') { + writeSocketMessage(socket, { + type: 'pong', + from: socketPath ?? undefined, + ts: new Date().toISOString(), + }) + return + } + + // Enqueue into inbox + const sanitizedMessage = stripAuthToken(msg) + const entry: UdsInboxEntry = { + id: `uds-${nextId++}`, + message: sanitizedMessage, + receivedAt: Date.now(), + status: 'pending', + } + if (!enqueueInboxEntry(entry)) { + writeSocketMessage(socket, { + type: 'error', + data: 'inbox full', + ts: new Date().toISOString(), + }) + return + } + logForDebugging( + `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 'unknown'}`, + ) + writeSocketMessage(socket, { + type: 'response', + data: 'ok', + ts: new Date().toISOString(), + meta: { id: entry.id }, + }) + onEnqueueCb?.() + }, + text => jsonParse(text) as UdsMessage, + { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + onFrameError: error => { + logForDebugging(`[udsMessaging] ${error.message}`) + }, + }, + ) + + socket.on('close', () => { + clients.delete(socket) + }) + + socket.on('error', err => { + clients.delete(socket) + logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + }) }) - socket.on('error', err => { - clients.delete(socket) - logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + srv.on('error', reject) + + srv.listen(path, () => { + void (async () => { + try { + if (process.platform !== 'win32') { + await chmod(path, 0o600) + } + server = srv + // Export so child processes can discover the socket + process.env.CLAUDE_CODE_MESSAGING_SOCKET = path + logForDebugging( + `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, + ) + resolve() + } catch (error) { + srv.close(() => reject(error)) + } + })() }) }) - - srv.on('error', reject) - - srv.listen(path, () => { - server = srv - // Export so child processes can discover the socket - process.env.CLAUDE_CODE_MESSAGING_SOCKET = path - logForDebugging( - `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, - ) - resolve() - }) - }) + } catch (error) { + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone. + } + capabilityFilePath = null + } + socketPath = null + authToken = null + throw error + } // Register cleanup so the socket file is removed on exit registerCleanup(async () => { @@ -230,6 +524,9 @@ export async function stopUdsMessaging(): Promise { server!.close(() => resolve()) }) server = null + inbox.length = 0 + inboxBytes = 0 + onEnqueueCb = null // Remove socket file (skip on Windows — pipe paths aren't files) if (socketPath) { @@ -245,7 +542,30 @@ export async function stopUdsMessaging(): Promise { `[udsMessaging] server stopped, socket removed: ${socketPath}`, ) socketPath = null + authToken = null } + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone + } + capabilityFilePath = null + } +} + +function parseResponseLine(line: string): UdsMessage | null { + try { + return jsonParse(line) as UdsMessage + } catch { + return null + } +} + +function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength } /** @@ -255,23 +575,66 @@ export async function stopUdsMessaging(): Promise { export async function sendUdsMessage( targetSocketPath: string, message: UdsMessage, + opts: { authToken?: string } = {}, ): Promise { const { createConnection } = await import('net') - message.from = message.from ?? socketPath ?? undefined - message.ts = message.ts ?? new Date().toISOString() + const token = opts.authToken ?? authToken + if (!token) { + throw new Error('Cannot send UDS message without auth token') + } + const outbound = withRequestAuthToken( + { + ...message, + from: message.from ?? socketPath ?? undefined, + ts: message.ts ?? new Date().toISOString(), + }, + token, + ) return new Promise((resolve, reject) => { + let buffer = '' + let settled = false + const finish = (error?: Error): void => { + if (settled) return + settled = true + conn.end() + if (error) reject(error) + else resolve() + } const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(message) + '\n', err => { - conn.end() - if (err) reject(err) - else resolve() + conn.write(jsonStringify(outbound) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', reject) + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + buffer += chunk.toString() + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + for (const line of lines) { + if (!line.trim()) continue + const response = parseResponseLine(line) + if (!response) continue + if (response.type === 'response' || response.type === 'pong') { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 'UDS receiver rejected message')) + return + } + } + }) + conn.on('error', err => finish(err)) // Timeout so we don't hang on unreachable sockets conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) }