From f353eb056a335760ff2390e1ca6f42c7b031d260 Mon Sep 17 00:00:00 2001 From: unraid Date: Sun, 26 Apr 2026 21:44:42 +0800 Subject: [PATCH] fix: bound agent communication memory growth UDS messaging now uses private local capabilities instead of exposing auth tokens through SDK metadata, environment variables, session registry, peer listing, or tool output. The receive path bounds NDJSON frames, response buffers, active clients, and pending inbox bytes, and strips auth metadata before messages enter the prompt queue. Teammate mailboxes now validate file and message sizes, fail closed on corrupt mutation inputs, compact by count and retained bytes, and use stable message identity for in-process acknowledgements. Agent summaries now fork only a bounded recent context using lazy size estimation and content fingerprints instead of retaining or serializing unbounded histories. Constraint: PR #361 was already merged; this branch is based on upstream/main@c2ac9a74. Rejected: Default-disabling COORDINATOR_MODE/TEAMMEM only | explicit feature enablement still hit unbounded paths. Rejected: Persisting UDS auth in SDK/env/session registry | bridge/remote metadata can leak local capability secrets. Rejected: Inline uds #token addresses | observable/tool/classifier paths can reflect raw addresses outside the UDS request frame. Rejected: Positional mailbox marking after compaction | compaction can shift indices across the lock boundary. Confidence: high Scope-risk: moderate Directive: Do not expose UDS capability tokens through SDK messages, environment variables, session registry, peer-list output, or SendMessage result/classifier surfaces. Directive: Do not reintroduce positional mailbox acknowledgements unless compaction is removed or read+mark is atomic under one lock. Tested: bun test src/utils/__tests__/ndjsonFramer.test.ts src/utils/__tests__/udsMessaging.test.ts packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts Tested: bunx tsc --noEmit --pretty false Tested: bun run lint Tested: bunx biome lint modified src/package files Tested: bun run test:all (3704 pass, 0 fail, 6734 expects) Tested: bun audit (No vulnerabilities found) Tested: bun run build Tested: bun run build:vite Tested: git diff --check Not-tested: End-to-end external UDS client driving a full production headless model turn. --- .../src/tools/ListPeersTool/ListPeersTool.ts | 24 +- .../tools/SendMessageTool/SendMessageTool.ts | 81 ++- .../udsRecipientSanitization.test.ts | 41 ++ src/cli/print.ts | 28 +- src/commands/peers/peers.ts | 15 +- .../__tests__/summaryContext.test.ts | 132 +++++ src/services/AgentSummary/agentSummary.ts | 34 +- src/services/AgentSummary/summaryContext.ts | 175 ++++++ src/utils/__tests__/ndjsonFramer.test.ts | 91 ++++ src/utils/__tests__/teammateMailbox.test.ts | 310 +++++++++++ src/utils/__tests__/udsMessaging.test.ts | 305 +++++++++++ src/utils/messages/systemInit.ts | 4 +- src/utils/ndjsonFramer.ts | 38 ++ src/utils/swarm/inProcessRunner.ts | 28 +- src/utils/teammateMailbox.ts | 303 ++++++++++- src/utils/udsClient.ts | 114 +++- src/utils/udsMessaging.ts | 501 +++++++++++++++--- 17 files changed, 2087 insertions(+), 137 deletions(-) create mode 100644 packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts create mode 100644 src/services/AgentSummary/__tests__/summaryContext.test.ts create mode 100644 src/services/AgentSummary/summaryContext.ts create mode 100644 src/utils/__tests__/ndjsonFramer.test.ts create mode 100644 src/utils/__tests__/teammateMailbox.test.ts create mode 100644 src/utils/__tests__/udsMessaging.test.ts diff --git a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts index e520243a5..48219edc7 100644 --- a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts +++ b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts @@ -85,21 +85,35 @@ Use this tool to discover messaging targets before sending cross-session message // and optionally includes Remote Control bridge peers. const peers: PeerInfo[] = [] - // Discovery is handled by the UDS messaging subsystem initialized in setup.ts. - // Return discovered peers from the app state. - const appState = context.getAppState() - const messagingSocketPath = (appState as Record).messagingSocketPath as string | undefined + /* eslint-disable @typescript-eslint/no-require-imports */ + const udsMessaging = + require('src/utils/udsMessaging.js') as typeof import('src/utils/udsMessaging.js') + const udsClient = + require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') + /* eslint-enable @typescript-eslint/no-require-imports */ + + const messagingSocketPath = udsMessaging.getUdsMessagingSocketPath() if (messagingSocketPath) { // Self entry for reference if (_input.include_self) { peers.push({ - address: `uds:${messagingSocketPath}`, + address: udsMessaging.formatUdsAddress(messagingSocketPath), name: 'self', pid: process.pid, }) } } + for (const peer of await udsClient.listPeers()) { + if (!peer.messagingSocketPath) continue + peers.push({ + address: udsMessaging.formatUdsAddress(peer.messagingSocketPath), + name: peer.name ?? peer.kind, + cwd: peer.cwd, + pid: peer.pid, + }) + } + return { data: { peers }, } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts index 4e9737051..3548544fc 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts @@ -130,6 +130,43 @@ export type SendMessageToolOutput = | RequestOutput | ResponseOutput +const UDS_INLINE_TOKEN_MARKER = '#token=' +const UDS_INLINE_TOKEN_REJECTED_KEY = '__udsInlineTokenRejected' + +function stripInlineUdsToken(target: string): string { + const markerIndex = target.lastIndexOf(UDS_INLINE_TOKEN_MARKER) + return markerIndex === -1 ? target : target.slice(0, markerIndex) +} + +function hasInlineUdsToken(to: string): boolean { + const addr = parseAddress(to) + return ( + addr.scheme === 'uds' && addr.target.includes(UDS_INLINE_TOKEN_MARKER) + ) +} + +function recipientForDisplay(to: string): string { + const addr = parseAddress(to) + if (addr.scheme !== 'uds') return to + return `uds:${stripInlineUdsToken(addr.target)}` +} + +function markAndRedactInlineUdsToken( + input: { to: string } & Record, +): void { + if (!hasInlineUdsToken(input.to)) return + input.to = recipientForDisplay(input.to) + input[UDS_INLINE_TOKEN_REJECTED_KEY] = true +} + +function wasInlineUdsTokenRejected(input: unknown): boolean { + return ( + typeof input === 'object' && + input !== null && + (input as Record)[UDS_INLINE_TOKEN_REJECTED_KEY] === true + ) +} + function findTeammateColor( appState: { teamContext?: { teammates: { [id: string]: { color?: string } } } @@ -541,15 +578,19 @@ export const SendMessageTool: Tool = }, backfillObservableInput(input) { - if ('type' in input) return if (typeof input.to !== 'string') return + markAndRedactInlineUdsToken( + input as { to: string } & Record, + ) + if ('type' in input) return + if (input.to === '*') { input.type = 'broadcast' if (typeof input.message === 'string') input.content = input.message } else if (typeof input.message === 'string') { input.type = 'message' - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) input.content = input.message } else if (typeof input.message === 'object' && input.message !== null) { const msg = input.message as { @@ -560,7 +601,7 @@ export const SendMessageTool: Tool = feedback?: string } input.type = msg.type - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) if (msg.request_id !== undefined) input.request_id = msg.request_id if (msg.approve !== undefined) input.approve = msg.approve const content = msg.reason ?? msg.feedback @@ -569,16 +610,17 @@ export const SendMessageTool: Tool = }, toAutoClassifierInput(input) { + const recipient = recipientForDisplay(input.to) if (typeof input.message === 'string') { - return `to ${input.to}: ${input.message}` + return `to ${recipient}: ${input.message}` } switch (input.message.type) { case 'shutdown_request': - return `shutdown_request to ${input.to}` + return `shutdown_request to ${recipient}` case 'shutdown_response': return `shutdown_response ${input.message.approve ? 'approve' : 'reject'} ${input.message.request_id}` case 'plan_approval_response': - return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${input.to}` + return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${recipient}` } }, @@ -630,6 +672,19 @@ export const SendMessageTool: Tool = errorCode: 9, } } + if (feature('UDS_INBOX')) { + if ( + addr.scheme === 'uds' && + (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) + ) { + return { + result: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + errorCode: 9, + } + } + } if (input.to.includes('@')) { return { result: false, @@ -787,6 +842,16 @@ export const SendMessageTool: Tool = } } if (addr.scheme === 'uds') { + const recipient = recipientForDisplay(input.to) + if (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) { + return { + data: { + success: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + }, + } + } /* eslint-disable @typescript-eslint/no-require-imports */ const { sendToUdsSocket } = require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') @@ -797,14 +862,14 @@ export const SendMessageTool: Tool = return { data: { success: true, - message: `”${preview}” → ${input.to}`, + message: `”${preview}” → ${recipient}`, }, } } catch (e) { return { data: { success: false, - message: `Failed to send to ${input.to}: ${errorMessage(e)}`, + message: `Failed to send to ${recipient}: ${errorMessage(e)}`, }, } } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts new file mode 100644 index 000000000..20124b6c3 --- /dev/null +++ b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts @@ -0,0 +1,41 @@ +import { describe, expect, mock, test } from 'bun:test' + +mock.module('bun:bundle', () => ({ + feature: (name: string) => name === 'UDS_INBOX', +})) + +describe('SendMessageTool UDS recipient handling', () => { + test('redacts inline UDS tokens before classifier and observable paths', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const tokenAddress = 'uds:/tmp/peer.sock#token=secret-token' + + const observableInput = { + to: tokenAddress, + message: 'hello', + } as Record + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(JSON.stringify(observableInput)).not.toContain('secret-token') + expect( + SendMessageTool.toAutoClassifierInput({ + to: tokenAddress, + message: 'hello', + }), + ).toBe('to uds:/tmp/peer.sock: hello') + }) + + test('rejects inline UDS tokens during validation', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const result = await SendMessageTool.validateInput!( + { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + }, + {} as never, + ) + + expect(result.result).toBe(false) + expect(JSON.stringify(result)).not.toContain('secret-token') + }) +}) diff --git a/src/cli/print.ts b/src/cli/print.ts index eb2b543f8..7a291fb50 100644 --- a/src/cli/print.ts +++ b/src/cli/print.ts @@ -2763,13 +2763,37 @@ function runHeadlessStreaming( // when a message arrives via the UDS socket in headless mode. if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ - const { setOnEnqueue } = require('../utils/udsMessaging.js') + const { drainInbox, setOnEnqueue } = + require('../utils/udsMessaging.js') as typeof import('../utils/udsMessaging.js') /* eslint-enable @typescript-eslint/no-require-imports */ + + const enqueueUdsInboxMessages = (): boolean => { + const entries = drainInbox() + for (const entry of entries) { + const value = + typeof entry.message.data === 'string' + ? entry.message.data + : jsonStringify(entry.message) + enqueue({ + mode: 'prompt', + value, + uuid: randomUUID(), + }) + } + return entries.length > 0 + } + setOnEnqueue(() => { if (!inputClosed) { - void run() + if (enqueueUdsInboxMessages()) { + void run() + } } }) + + if (enqueueUdsInboxMessages()) { + void run() + } } // Cron scheduler: runs scheduled_tasks.json tasks in SDK/-p mode. diff --git a/src/commands/peers/peers.ts b/src/commands/peers/peers.ts index aed37d327..fcb7e17a7 100644 --- a/src/commands/peers/peers.ts +++ b/src/commands/peers/peers.ts @@ -1,6 +1,9 @@ import type { LocalCommandCall } from '../../types/command.js' import { listPeers, isPeerAlive } from '../../utils/udsClient.js' -import { getUdsMessagingSocketPath } from '../../utils/udsMessaging.js' +import { + formatUdsAddress, + getUdsMessagingSocketPath, +} from '../../utils/udsMessaging.js' export const call: LocalCommandCall = async (_args, _context) => { const mySocket = getUdsMessagingSocketPath() @@ -29,11 +32,11 @@ export const call: LocalCommandCall = async (_args, _context) => { ? ` started: ${formatAge(peer.startedAt)}` : '' - lines.push( - ` [${status}] PID ${peer.pid} (${label})${cwd}${age}`, - ) + lines.push(` [${status}] PID ${peer.pid} (${label})${cwd}${age}`) if (peer.messagingSocketPath) { - lines.push(` socket: ${peer.messagingSocketPath}`) + lines.push( + ` socket: ${formatUdsAddress(peer.messagingSocketPath)}`, + ) } if (peer.sessionId) { lines.push(` session: ${peer.sessionId}`) @@ -43,7 +46,7 @@ export const call: LocalCommandCall = async (_args, _context) => { lines.push('') lines.push( - 'To message a peer: use SendMessage with to="uds:"', + 'To message a peer: use SendMessage with the shown uds: address', ) return { type: 'text', value: lines.join('\n') } diff --git a/src/services/AgentSummary/__tests__/summaryContext.test.ts b/src/services/AgentSummary/__tests__/summaryContext.test.ts new file mode 100644 index 000000000..3ffa55964 --- /dev/null +++ b/src/services/AgentSummary/__tests__/summaryContext.test.ts @@ -0,0 +1,132 @@ +import { describe, expect, test } from 'bun:test' +import type { Message } from '../../../types/message.js' +import { + getSummaryContextFingerprint, + selectSummaryContextMessages, +} from '../summaryContext.js' + +function makeMessage( + type: 'user' | 'assistant', + uuid: string, + content: string, +): Message { + return { + type, + uuid, + message: { + role: type, + content, + }, + } as unknown as Message +} + +describe('selectSummaryContextMessages', () => { + test('keeps a bounded recent suffix that starts with a user message', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + makeMessage('user', 'u2', 'second prompt'), + makeMessage('assistant', 'a2', 'second response'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2', 'a2']) + }) + + test('returns no context when the newest message exceeds the byte budget', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'x'.repeat(100)), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 10, + }) + + expect(selected).toEqual([]) + }) + + test('uses serialized message size for nested content budgets', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + { + ...makeMessage('assistant', 'a1', 'short'), + nested: { + payload: Array.from({ length: 50 }, (_value, index) => ({ + index, + text: 'x'.repeat(20), + })), + }, + } as unknown as Message, + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 200, + }) + + expect(selected).toEqual([]) + }) + + test('drops leading orphan tool results after bounding', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'tool_result', tool_use_id: 'tool-1', content: 'ok' }, + ], + }, + } as unknown as Message, + makeMessage('assistant', 'a1', 'after orphan'), + makeMessage('user', 'u2', 'next prompt'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2']) + }) +}) + +describe('getSummaryContextFingerprint', () => { + test('changes when the transcript grows', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ] + + const first = getSummaryContextFingerprint(messages) + const second = getSummaryContextFingerprint([ + ...messages, + makeMessage('user', 'u2', 'next prompt'), + ]) + expect(first?.startsWith('2:a1:')).toBe(true) + expect(second?.startsWith('3:u2:')).toBe(true) + expect(first).not.toBe(second) + }) + + test('changes when message content changes under the same uuid', () => { + const first = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ]) + const second = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'updated response'), + ]) + + expect(first).not.toBe(second) + }) +}) diff --git a/src/services/AgentSummary/agentSummary.ts b/src/services/AgentSummary/agentSummary.ts index 50146b3c7..2232e839d 100644 --- a/src/services/AgentSummary/agentSummary.ts +++ b/src/services/AgentSummary/agentSummary.ts @@ -23,6 +23,10 @@ import { import { logError } from '../../utils/log.js' import { createUserMessage } from '../../utils/messages.js' import { getAgentTranscript } from '../../utils/sessionStorage.js' +import { + getSummaryContextFingerprint, + selectSummaryContextMessages, +} from './summaryContext.js' const SUMMARY_INTERVAL_MS = 30_000 @@ -58,6 +62,7 @@ export function startAgentSummarization( let timeoutId: ReturnType | null = null let stopped = false let previousSummary: string | null = null + let lastHandledTranscriptFingerprint: string | null = null async function runSummary(): Promise { if (stopped) return @@ -82,15 +87,35 @@ export function startAgentSummarization( // Filter to clean message state const cleanMessages = filterIncompleteToolCalls(transcript.messages) + const summaryContext = filterIncompleteToolCalls( + selectSummaryContextMessages(cleanMessages), + ) + const transcriptFingerprint = getSummaryContextFingerprint(summaryContext) + if ( + transcriptFingerprint && + transcriptFingerprint === lastHandledTranscriptFingerprint + ) { + logForDebugging( + `[AgentSummary] Skipping summary for ${taskId}: transcript unchanged`, + ) + return + } + + if (summaryContext.length < 3) { + logForDebugging( + `[AgentSummary] Skipping summary for ${taskId}: no bounded context available`, + ) + return + } // Build fork params with current messages const forkParams: CacheSafeParams = { ...baseParams, - forkContextMessages: cleanMessages, + forkContextMessages: summaryContext, } logForDebugging( - `[AgentSummary] Forking for summary, ${cleanMessages.length} messages in context`, + `[AgentSummary] Forking for summary, ${summaryContext.length} messages in context`, ) // Create abort controller for this summary @@ -136,13 +161,16 @@ export function startAgentSummarization( ) continue } - const contentArr = Array.isArray(msg.message!.content) ? msg.message!.content : [] + const contentArr = Array.isArray(msg.message!.content) + ? msg.message!.content + : [] const textBlock = contentArr.find(b => b.type === 'text') if (textBlock?.type === 'text' && textBlock.text.trim()) { const summaryText = textBlock.text.trim() logForDebugging( `[AgentSummary] Summary result for ${taskId}: ${summaryText}`, ) + lastHandledTranscriptFingerprint = transcriptFingerprint previousSummary = summaryText updateAgentSummary(taskId, summaryText, setAppState) break diff --git a/src/services/AgentSummary/summaryContext.ts b/src/services/AgentSummary/summaryContext.ts new file mode 100644 index 000000000..4d9f6a6ce --- /dev/null +++ b/src/services/AgentSummary/summaryContext.ts @@ -0,0 +1,175 @@ +import { createHash } from 'crypto' +import type { Message } from '../../types/message.js' + +export const MAX_SUMMARY_CONTEXT_MESSAGES = 120 +export const MAX_SUMMARY_CONTEXT_CHARS = 200_000 + +function estimateJsonChars( + value: unknown, + limit: number, + seen = new Set(), +): number { + if (value === null) return 4 + switch (typeof value) { + case 'string': + return value.length + 2 + case 'number': + case 'boolean': + return String(value).length + case 'undefined': + case 'function': + case 'symbol': + return 0 + case 'object': { + if (seen.has(value)) return Number.POSITIVE_INFINITY + seen.add(value) + let total = 2 + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + total += String(index).length + 3 + total += estimateJsonChars(value[index], limit - total, seen) + if (total > limit) return total + } + } else { + const record = value as Record + for (const key in record) { + if (!Object.hasOwn(record, key)) continue + total += key.length + 3 + total += estimateJsonChars(record[key], limit - total, seen) + if (total > limit) return total + } + } + seen.delete(value) + return total + } + } + return 0 +} + +function updateFingerprintHash( + hash: ReturnType, + value: unknown, + limit: { remaining: number }, + seen = new Set(), +): void { + if (limit.remaining <= 0) return + if (value === null || typeof value !== 'object') { + const text = String(value) + hash.update(typeof value) + hash.update(':') + hash.update(text.slice(0, limit.remaining)) + limit.remaining -= text.length + return + } + if (seen.has(value)) { + hash.update('[Circular]') + return + } + seen.add(value) + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + if (limit.remaining <= 0) break + const key = String(index) + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, value[index], limit, seen) + } + } else { + const record = value as Record + for (const key in record) { + if (limit.remaining <= 0) break + if (!Object.hasOwn(record, key)) continue + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, record[key], limit, seen) + } + } + seen.delete(value) +} + +export function estimateMessageChars( + message: Message, + limit = Number.POSITIVE_INFINITY, +): number { + const estimated = estimateJsonChars(message, limit) + if (!Number.isFinite(estimated)) { + return Number.POSITIVE_INFINITY + } + return estimated +} + +function hasToolResultBlock(message: Message): boolean { + if (message.type !== 'user') return false + const content = message.message?.content + return ( + Array.isArray(content) && + content.some(block => { + return Boolean( + block && + typeof block === 'object' && + 'type' in block && + block.type === 'tool_result', + ) + }) + ) +} + +export function getSummaryContextFingerprint( + messages: Message[], +): string | null { + const lastMessage = messages.at(-1) + if (!lastMessage) return null + const hash = createHash('sha256') + updateFingerprintHash(hash, messages, { + remaining: MAX_SUMMARY_CONTEXT_CHARS, + }) + return `${messages.length}:${lastMessage.uuid}:${hash.digest('hex').slice(0, 16)}` +} + +export function selectSummaryContextMessages( + messages: Message[], + limits: { + maxMessages?: number + maxChars?: number + } = {}, +): Message[] { + const maxMessages = limits.maxMessages ?? MAX_SUMMARY_CONTEXT_MESSAGES + const maxChars = limits.maxChars ?? MAX_SUMMARY_CONTEXT_CHARS + if (maxMessages <= 0 || maxChars <= 0) return [] + + const selected: Message[] = [] + let selectedChars = 0 + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i] + if (!message) continue + + const messageChars = estimateMessageChars(message, maxChars - selectedChars) + if (messageChars > maxChars) { + if (selected.length === 0) return [] + break + } + + if ( + selected.length >= maxMessages || + selectedChars + messageChars > maxChars + ) { + break + } + + selected.unshift(message) + selectedChars += messageChars + } + + while (selected.length > 0) { + const first = selected[0] + if (!first) break + if (first.type !== 'user' || hasToolResultBlock(first)) { + selected.shift() + continue + } + break + } + + return selected +} diff --git a/src/utils/__tests__/ndjsonFramer.test.ts b/src/utils/__tests__/ndjsonFramer.test.ts new file mode 100644 index 000000000..35174162a --- /dev/null +++ b/src/utils/__tests__/ndjsonFramer.test.ts @@ -0,0 +1,91 @@ +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' +import { describe, expect, test } from 'bun:test' +import { attachNdjsonFramer } from '../ndjsonFramer.js' + +type TestSocket = Socket & { + destroyed: boolean + emitData: (chunk: Buffer) => void +} + +function createTestSocket(): TestSocket { + const emitter = new EventEmitter() as TestSocket + emitter.destroyed = false + emitter.destroy = ((_error?: Error) => { + emitter.destroyed = true + emitter.emit('close') + return emitter + }) as TestSocket['destroy'] + emitter.emitData = (chunk: Buffer) => { + emitter.emit('data', chunk) + } + return emitter +} + +describe('attachNdjsonFramer', () => { + test('accepts a complete frame at the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: Buffer.byteLength('{"a":1}', 'utf8'), + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"a":1}\n')) + + expect(messages).toEqual([{ a: 1 }]) + expect(errors).toEqual([]) + expect(socket.destroyed).toBe(false) + }) + + test('destroys a complete frame over the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"long":true}\n')) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) + + test('destroys oversized no-newline input before a frame can form', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('x'.repeat(9))) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) +}) diff --git a/src/utils/__tests__/teammateMailbox.test.ts b/src/utils/__tests__/teammateMailbox.test.ts new file mode 100644 index 000000000..577c4331f --- /dev/null +++ b/src/utils/__tests__/teammateMailbox.test.ts @@ -0,0 +1,310 @@ +import { afterEach, beforeEach, describe, expect, test } from 'bun:test' +import { mkdir, readFile, rm, writeFile } from 'node:fs/promises' +import { mkdtempSync } from 'node:fs' +import { tmpdir } from 'node:os' +import { dirname, join } from 'node:path' +import { + compactMailboxMessages, + getInboxPath, + markMessageAsReadByIndex, + markMessageAsReadByIdentity, + markMessagesAsRead, + markMessagesAsReadByPredicate, + MAX_MAILBOX_MESSAGE_TEXT_BYTES, + MAX_MAILBOX_MESSAGES, + MAX_READ_MAILBOX_MESSAGES, + MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES, + readMailbox, + type TeammateMessage, + writeToMailbox, +} from '../teammateMailbox.js' + +let tempHome = '' +let previousConfigDir: string | undefined + +function message( + text: string, + read: boolean, + timestamp = new Date(0).toISOString(), +): TeammateMessage { + return { + from: 'team-lead', + text, + timestamp, + read, + } +} + +async function seedMailbox( + agentName: string, + teamName: string, + messages: TeammateMessage[], +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify(messages, null, 2), 'utf-8') +} + +async function readRawMailbox( + agentName: string, + teamName: string, +): Promise { + const content = await readFile(getInboxPath(agentName, teamName), 'utf-8') + return JSON.parse(content) as TeammateMessage[] +} + +beforeEach(() => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) + process.env.CLAUDE_CONFIG_DIR = tempHome +}) + +afterEach(async () => { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) +}) + +describe('compactMailboxMessages', () => { + test('prioritizes unread messages and keeps only recent read history', () => { + const compacted = compactMailboxMessages( + [ + message('read-1', true), + message('read-2', true), + message('unread-1', false), + message('read-3', true), + message('unread-2', false), + message('read-4', true), + message('read-5', true), + message('unread-3', false), + ], + { maxMessages: 5, maxReadMessages: 2 }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + 'unread-1', + 'unread-2', + 'read-4', + 'read-5', + 'unread-3', + ]) + }) + + test('retains unread protocol messages separately from regular cap', () => { + const protocol = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + const compacted = compactMailboxMessages( + [ + protocol, + ...Array.from({ length: 5 }, (_value, index) => + message(`regular-${index}`, false), + ), + ], + { + maxMessages: 2, + maxReadMessages: 0, + maxUnreadProtocolMessages: 1, + }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + protocol.text, + 'regular-3', + 'regular-4', + ]) + }) + + test('caps unread protocol messages with an independent bound', () => { + const compacted = compactMailboxMessages( + Array.from( + { length: MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES + 1 }, + (_value, index) => + message( + JSON.stringify({ + type: 'permission_response', + request_id: `req-${index}`, + }), + false, + ), + ), + ) + + expect(compacted).toHaveLength(MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES) + expect(compacted[0]?.text).toContain('req-1') + }) + + test('keeps retained mailbox bytes under an explicit budget', () => { + const compacted = compactMailboxMessages( + Array.from({ length: 20 }, (_value, index) => + message(`msg-${index}-${'x'.repeat(200)}`, false), + ), + { + maxMessages: 20, + maxReadMessages: 0, + maxRetainedBytes: 1_000, + }, + ) + + expect( + Buffer.byteLength(JSON.stringify(compacted), 'utf8'), + ).toBeLessThanOrEqual(1_000) + expect(compacted.length).toBeLessThan(20) + expect(compacted.at(-1)?.text).toContain('msg-19') + }) +}) + +describe('teammate mailbox retention', () => { + test('writeToMailbox compacts oversized unread inbox files', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`old-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(1).toISOString(), + }, + 'alpha', + ) + + const after = await readMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after[0]?.text).toBe('old-21') + expect(after.at(-1)?.text).toBe('newest') + }) + + test('markMessagesAsRead compacts read history after consumption', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessagesAsRead('worker', 'alpha') + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_READ_MAILBOX_MESSAGES) + expect(after.every(m => m.read)).toBe(true) + expect(after[0]?.text).toBe( + `msg-${MAX_MAILBOX_MESSAGES + 20 - MAX_READ_MAILBOX_MESSAGES}`, + ) + }) + + test('markMessagesAsReadByPredicate leaves structured messages unread', async () => { + await seedMailbox('worker', 'alpha', [ + message('plain', false), + message(JSON.stringify({ type: 'permission_request' }), false), + ]) + + await markMessagesAsReadByPredicate( + 'worker', + m => !m.text.includes('permission_request'), + 'alpha', + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(after.map(m => m.read)).toEqual([true, false]) + }) + + test('markMessageAsReadByIdentity survives compaction shifting indexes', async () => { + const permissionResponse = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + await seedMailbox('worker', 'alpha', [ + permissionResponse, + ...Array.from({ length: MAX_MAILBOX_MESSAGES + 20 }, (_value, index) => + message(`regular-${index}`, false), + ), + ]) + + await writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(2).toISOString(), + }, + 'alpha', + ) + const marked = await markMessageAsReadByIdentity( + 'worker', + 'alpha', + permissionResponse, + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(marked).toBe(true) + expect(after.some(m => m.text === permissionResponse.text && !m.read)).toBe( + false, + ) + }) + + test('markMessageAsReadByIndex also compacts through the compatibility path', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 10 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessageAsReadByIndex('worker', 'alpha', existing.length - 1) + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after.some(m => m.text === `msg-${existing.length - 1}`)).toBe(false) + expect(after.at(-1)?.text).toBe(`msg-${existing.length - 2}`) + }) + + test('writeToMailbox rejects oversized message text instead of storing it', async () => { + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'x'.repeat(MAX_MAILBOX_MESSAGE_TEXT_BYTES + 1), + timestamp: new Date(3).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow('Mailbox message text exceeds') + + expect(await readRawMailbox('worker', 'alpha')).toEqual([]) + }) + + test('writeToMailbox fails closed when an existing mailbox is corrupt', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'new', + timestamp: new Date(4).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow() + + expect(await readFile(inboxPath, 'utf-8')).toBe('{not-json') + }) + + test('readMailbox fails closed on corrupt mailbox content', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow() + }) +}) diff --git a/src/utils/__tests__/udsMessaging.test.ts b/src/utils/__tests__/udsMessaging.test.ts new file mode 100644 index 000000000..52cb57abc --- /dev/null +++ b/src/utils/__tests__/udsMessaging.test.ts @@ -0,0 +1,305 @@ +import { afterEach, describe, expect, test } from 'bun:test' +import { chmod, mkdir, rm, stat, symlink, unlink } from 'node:fs/promises' +import { createConnection, createServer } from 'node:net' +import { dirname, join } from 'node:path' +import { tmpdir } from 'node:os' +import { + drainInbox, + MAX_UDS_INBOX_ENTRIES, + MAX_UDS_INBOX_BYTES, + MAX_UDS_FRAME_BYTES, + parseUdsTarget, + sendUdsMessage, + setOnEnqueue, + startUdsMessaging, + stopUdsMessaging, +} from '../udsMessaging.js' + +function socketPath(label: string): string { + const suffix = `${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}-${label}` + if (process.platform === 'win32') { + return `\\\\.\\pipe\\claude-code-test-${suffix}` + } + return join(tmpdir(), 'claude-code-test', `${suffix}.sock`) +} + +function sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)) +} + +async function waitForEnqueues( + expected: number, + sendMessages: () => Promise, +): Promise { + let count = 0 + let resolveDone: (() => void) | undefined + const done = new Promise(resolve => { + resolveDone = resolve + }) + + setOnEnqueue(() => { + count++ + if (count >= expected) resolveDone?.() + }) + + await sendMessages() + await Promise.race([ + done, + sleep(5_000).then(() => { + throw new Error(`Timed out waiting for ${expected} UDS enqueues`) + }), + ]) + setOnEnqueue(null) +} + +afterEach(async () => { + setOnEnqueue(null) + drainInbox() + await stopUdsMessaging() +}) + +async function closeServer(server: ReturnType): Promise { + await new Promise(resolve => { + server.close(() => resolve()) + }) +} + +describe('UDS inbox retention', () => { + test('drainInbox returns each pending socket message once', async () => { + const path = socketPath('drain') + await startUdsMessaging(path, { isExplicit: true }) + expect(process.env.CLAUDE_CODE_MESSAGING_TOKEN).toBeUndefined() + + await waitForEnqueues(2, async () => { + await sendUdsMessage(path, { type: 'text', data: 'one' }) + await sendUdsMessage(path, { type: 'text', data: 'two' }) + }) + + const drained = drainInbox() + expect(drained.map(entry => entry.message.data)).toEqual(['one', 'two']) + expect(drained.every(entry => entry.status === 'processed')).toBe(true) + expect(drainInbox()).toEqual([]) + }) + + test('inbox is capped when messages arrive faster than they are drained', async () => { + const path = socketPath('cap') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(MAX_UDS_INBOX_ENTRIES, async () => { + for (let i = 0; i < MAX_UDS_INBOX_ENTRIES; i++) { + await sendUdsMessage(path, { type: 'text', data: String(i) }) + } + }) + await expect( + sendUdsMessage(path, { type: 'text', data: 'overflow' }), + ).rejects.toThrow('inbox full') + + const drained = drainInbox() + expect(drained).toHaveLength(MAX_UDS_INBOX_ENTRIES) + expect(drained[0]?.message.data).toBe('0') + expect(drained.at(-1)?.message.data).toBe(String(MAX_UDS_INBOX_ENTRIES - 1)) + }) + + test('inbox is capped by retained bytes before entry count', async () => { + const path = socketPath('byte-cap') + await startUdsMessaging(path, { isExplicit: true }) + + const payload = 'x'.repeat(32 * 1024) + let accepted = 0 + for (;;) { + try { + await sendUdsMessage(path, { type: 'text', data: payload }) + accepted++ + if (accepted > MAX_UDS_INBOX_BYTES / payload.length + 20) { + throw new Error('byte cap was not enforced') + } + } catch (error) { + expect(error).toBeInstanceOf(Error) + expect((error as Error).message).toContain('inbox full') + break + } + } + + const drained = drainInbox() + expect(drained.length).toBe(accepted) + expect(drained.length).toBeLessThan(MAX_UDS_INBOX_ENTRIES) + }) + + test('ping replies with pong without enqueueing inbox work', async () => { + const path = socketPath('ping') + await startUdsMessaging(path, { isExplicit: true }) + + await sendUdsMessage(path, { type: 'ping' }) + expect(drainInbox()).toEqual([]) + }) + + test('drained entries never expose the UDS auth token', async () => { + const path = socketPath('strip-token') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(1, async () => { + await sendUdsMessage(path, { + type: 'notification', + meta: { keep: 'visible' }, + }) + }) + + const drained = drainInbox() + expect(drained).toHaveLength(1) + expect(drained[0]?.message.meta).toEqual({ keep: 'visible' }) + expect(drained[0]?.message.meta).not.toHaveProperty('authToken') + }) + + test('rejects unauthenticated socket messages', async () => { + const path = socketPath('auth') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + const conn = createConnection(path, () => { + conn.write(`${JSON.stringify({ type: 'text', data: 'bad' })}\n`) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for auth rejection')) + }) + conn.on('data', chunk => { + const text = chunk.toString('utf-8') + if (text.includes('\n')) { + conn.end() + resolve(text) + } + }) + conn.on('error', reject) + }) + + expect(JSON.parse(response).type).toBe('error') + expect(drainInbox()).toEqual([]) + }) + + test('destroys oversized frames before enqueueing inbox work', async () => { + const path = socketPath('oversized') + await startUdsMessaging(path, { isExplicit: true }) + + await new Promise((resolve, reject) => { + const conn = createConnection(path, () => { + conn.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for oversized frame close')) + }) + conn.on('close', () => resolve()) + conn.on('error', () => resolve()) + }) + + expect(drainInbox()).toEqual([]) + }) + + test('rejects oversized receiver responses before retaining them', async () => { + const path = socketPath('oversized-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('UDS response frame exceeded size limit') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects inline auth token UDS targets instead of parsing them', async () => { + const path = socketPath('inline-token') + + const targetWithToken = `${path}#token=secret` + expect(() => parseUdsTarget(targetWithToken)).toThrow('inline auth token') + try { + parseUdsTarget(targetWithToken) + } catch (error) { + expect((error as Error).message).not.toContain('secret') + } + + const { sendToUdsSocket } = await import('../udsClient.js') + await expect(sendToUdsSocket(targetWithToken, 'hello')).rejects.toThrow( + 'inline auth token', + ) + }) + + if (process.platform !== 'win32') { + test('creates the listening socket with owner-only permissions', async () => { + const path = socketPath('socket-mode') + await startUdsMessaging(path, { isExplicit: true }) + + const mode = (await stat(path)).mode & 0o777 + expect(mode).toBe(0o600) + }) + + test('fails closed when the capability directory is not private', async () => { + const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + process.env.CLAUDE_CONFIG_DIR = tempHome + const capabilityDir = join(tempHome, 'messaging-capabilities') + await mkdir(capabilityDir, { recursive: true, mode: 0o755 }) + await chmod(capabilityDir, 0o755) + + try { + await expect( + startUdsMessaging(socketPath('broad-capdir'), { isExplicit: true }), + ).rejects.toThrow('permissions are too broad') + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + + test('fails closed when the capability directory is a symlink', async () => { + const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-link-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + const target = join(tempHome, 'target') + process.env.CLAUDE_CONFIG_DIR = tempHome + await mkdir(target, { recursive: true, mode: 0o700 }) + await symlink(target, join(tempHome, 'messaging-capabilities'), 'dir') + + try { + await expect( + startUdsMessaging(socketPath('symlink-capdir'), { isExplicit: true }), + ).rejects.toThrow('not a private directory') + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + } +}) diff --git a/src/utils/messages/systemInit.ts b/src/utils/messages/systemInit.ts index fcb9e74d1..4585c7817 100644 --- a/src/utils/messages/systemInit.ts +++ b/src/utils/messages/systemInit.ts @@ -87,8 +87,10 @@ export function buildSystemInitMessage(inputs: SystemInitInputs): SDKMessage { // Hidden from public SDK types — ant-only UDS messaging socket path if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ + const udsMessaging = + require('../udsMessaging.js') as typeof import('../udsMessaging.js') ;(initMessage as Record).messaging_socket_path = - require('../udsMessaging.js').getUdsMessagingSocketPath() + udsMessaging.getUdsMessagingSocketPath() /* eslint-enable @typescript-eslint/no-require-imports */ } initMessage.fast_mode_state = getFastModeState(inputs.model, inputs.fastMode) diff --git a/src/utils/ndjsonFramer.ts b/src/utils/ndjsonFramer.ts index 968ee5217..ecaa04dc8 100644 --- a/src/utils/ndjsonFramer.ts +++ b/src/utils/ndjsonFramer.ts @@ -7,6 +7,11 @@ */ import type { Socket } from 'net' +export type NdjsonFramerOptions = { + maxFrameBytes?: number + onFrameError?: (error: Error) => void +} + /** * Attach an NDJSON framer to a socket. Calls `onMessage` for each * complete JSON line received. Malformed lines are silently skipped. @@ -19,21 +24,54 @@ export function attachNdjsonFramer( socket: Socket, onMessage: (msg: T) => void, parse: (text: string) => T = text => JSON.parse(text) as T, + options: NdjsonFramerOptions = {}, ): void { let buffer = '' + const maxFrameBytes = options.maxFrameBytes ?? Number.POSITIVE_INFINITY + + const rejectOversizedFrame = (bytes: number): void => { + const error = new Error( + `NDJSON frame exceeded ${maxFrameBytes} bytes (${bytes})`, + ) + options.onFrameError?.(error) + socket.destroy(error) + } socket.on('data', (chunk: Buffer) => { + if ( + Number.isFinite(maxFrameBytes) && + !chunk.includes(0x0a) && + Buffer.byteLength(buffer, 'utf8') + chunk.byteLength > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8') + chunk.byteLength) + return + } + buffer += chunk.toString() const lines = buffer.split('\n') buffer = lines.pop() ?? '' for (const line of lines) { if (!line.trim()) continue + if ( + Number.isFinite(maxFrameBytes) && + Buffer.byteLength(line, 'utf8') > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(line, 'utf8')) + return + } try { onMessage(parse(line)) } catch { // Malformed JSON — skip } } + + if ( + Number.isFinite(maxFrameBytes) && + Buffer.byteLength(buffer, 'utf8') > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8')) + } }) } diff --git a/src/utils/swarm/inProcessRunner.ts b/src/utils/swarm/inProcessRunner.ts index 1735500b4..eaab58ef7 100644 --- a/src/utils/swarm/inProcessRunner.ts +++ b/src/utils/swarm/inProcessRunner.ts @@ -97,7 +97,7 @@ import { getLastPeerDmSummary, isPermissionResponse, isShutdownRequest, - markMessageAsReadByIndex, + markMessageAsReadByIdentity, readMailbox, writeToMailbox, } from '../teammateMailbox.js' @@ -405,10 +405,10 @@ function createInProcessCanUseTool( if (msg && !msg.read) { const parsed = isPermissionResponse(msg.text) if (parsed && parsed.request_id === request.id) { - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - i, + msg, ) if (parsed.subtype === 'success') { processMailboxPermissionResponse({ @@ -801,10 +801,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received shutdown request from ${shutdownParsed?.from} (prioritized over ${skippedUnread} unread messages)`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - shutdownIndex, + msg, ) return { type: 'shutdown_request', @@ -839,10 +839,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received new message from ${msg.from} (index ${selectedIndex})`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - selectedIndex, + msg, ) return { type: 'new_message', @@ -1246,8 +1246,13 @@ export async function runInProcessTeammate( // Track in-progress tool use IDs for animation in transcript view let inProgressToolUseIDs = task.inProgressToolUseIDs if (message.type === 'assistant') { - for (const block of (Array.isArray(message.message!.content) ? message.message!.content : [])) { - if (typeof block !== 'string' && block.type === 'tool_use') { + for (const block of Array.isArray(message.message!.content) + ? message.message!.content + : []) { + if ( + typeof block !== 'string' && + block.type === 'tool_use' + ) { inProgressToolUseIDs = new Set([ ...(inProgressToolUseIDs ?? []), block.id, @@ -1318,7 +1323,10 @@ export async function runInProcessTeammate( setAppState, ) if (currentAutonomyRunId) { - await markAutonomyRunFailed(currentAutonomyRunId, ERROR_MESSAGE_USER_ABORT) + await markAutonomyRunFailed( + currentAutonomyRunId, + ERROR_MESSAGE_USER_ABORT, + ) currentAutonomyRunId = undefined } } else if (currentAutonomyRunId) { diff --git a/src/utils/teammateMailbox.ts b/src/utils/teammateMailbox.ts index eb72fcc21..6c18fd721 100644 --- a/src/utils/teammateMailbox.ts +++ b/src/utils/teammateMailbox.ts @@ -7,7 +7,8 @@ * Note: Inboxes are keyed by agent name within a team. */ -import { mkdir, readFile, writeFile } from 'fs/promises' +import { randomBytes } from 'crypto' +import { mkdir, readFile, rename, stat, writeFile } from 'fs/promises' import { join } from 'path' import { z } from 'zod/v4' import { TEAMMATE_MESSAGE_TAG } from '../constants/xml.js' @@ -40,6 +41,13 @@ const LOCK_OPTIONS = { }, } +export const MAX_MAILBOX_MESSAGES = 1_000 +export const MAX_READ_MAILBOX_MESSAGES = 200 +export const MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES = 2_000 +export const MAX_MAILBOX_MESSAGE_TEXT_BYTES = 64 * 1024 +export const MAX_MAILBOX_RETAINED_BYTES = 2 * 1024 * 1024 +export const MAX_MAILBOX_FILE_BYTES = 4 * 1024 * 1024 + export type TeammateMessage = { from: string text: string @@ -49,6 +57,218 @@ export type TeammateMessage = { summary?: string // 5-10 word summary shown as preview in the UI } +function isJsonLikeMessage(text: string): boolean { + const trimmed = text.trimStart() + return trimmed.startsWith('{') || trimmed.startsWith('[') +} + +function shouldRetainUnreadAsProtocolMessage( + message: TeammateMessage, +): boolean { + if (message.read) return false + if (isStructuredProtocolMessage(message.text)) return true + if (!isJsonLikeMessage(message.text)) return false + + try { + const parsed = jsonParse(message.text) + return Boolean( + parsed && + typeof parsed === 'object' && + 'type' in (parsed as Record), + ) + } catch { + return true + } +} + +function sameMailboxMessage(a: TeammateMessage, b: TeammateMessage): boolean { + return a.from === b.from && a.timestamp === b.timestamp && a.text === b.text +} + +function mailboxMessageStorageBytes(message: TeammateMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function assertMailboxMessageSize(message: TeammateMessage): void { + const textBytes = Buffer.byteLength(message.text, 'utf8') + if (textBytes > MAX_MAILBOX_MESSAGE_TEXT_BYTES) { + throw new Error( + `Mailbox message text exceeds ${MAX_MAILBOX_MESSAGE_TEXT_BYTES} bytes`, + ) + } +} + +function toMailboxMessage(value: unknown): TeammateMessage { + if (!value || typeof value !== 'object') { + throw new Error('Invalid mailbox message: expected object') + } + const record = value as Record + if ( + typeof record.from !== 'string' || + typeof record.text !== 'string' || + typeof record.timestamp !== 'string' || + typeof record.read !== 'boolean' + ) { + throw new Error('Invalid mailbox message shape') + } + const message: TeammateMessage = { + from: record.from, + text: record.text, + timestamp: record.timestamp, + read: record.read, + ...(typeof record.color === 'string' ? { color: record.color } : {}), + ...(typeof record.summary === 'string' ? { summary: record.summary } : {}), + } + assertMailboxMessageSize(message) + return message +} + +function parseMailboxMessages(content: string): TeammateMessage[] { + const parsed = jsonParse(content) + if (!Array.isArray(parsed)) { + throw new Error('Invalid mailbox file: expected message array') + } + return parsed.map(toMailboxMessage) +} + +async function readMailboxFile(inboxPath: string): Promise { + const info = await stat(inboxPath) + if (info.size > MAX_MAILBOX_FILE_BYTES) { + throw new Error( + `Mailbox file exceeds ${MAX_MAILBOX_FILE_BYTES} bytes: ${inboxPath}`, + ) + } + return readFile(inboxPath, 'utf-8') +} + +async function readMailboxForMutation( + agentName: string, + teamName?: string, +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + return parseMailboxMessages(await readMailboxFile(inboxPath)) +} + +async function writeMailboxAtomic( + inboxPath: string, + content: string, +): Promise { + const bytes = Buffer.byteLength(content, 'utf8') + if (bytes > MAX_MAILBOX_FILE_BYTES) { + throw new Error( + `Compacted mailbox still exceeds ${MAX_MAILBOX_FILE_BYTES} bytes`, + ) + } + const tempPath = `${inboxPath}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` + await writeFile(tempPath, content, 'utf-8') + await rename(tempPath, inboxPath) +} + +export function compactMailboxMessages( + messages: TeammateMessage[], + limits: { + maxMessages?: number + maxReadMessages?: number + maxUnreadProtocolMessages?: number + maxRetainedBytes?: number + } = {}, +): TeammateMessage[] { + const maxMessages = limits.maxMessages ?? MAX_MAILBOX_MESSAGES + const maxReadMessages = limits.maxReadMessages ?? MAX_READ_MAILBOX_MESSAGES + const maxUnreadProtocolMessages = + limits.maxUnreadProtocolMessages ?? MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES + const maxRetainedBytes = limits.maxRetainedBytes ?? MAX_MAILBOX_RETAINED_BYTES + + if ( + maxRetainedBytes <= 0 || + (maxMessages <= 0 && maxUnreadProtocolMessages <= 0) + ) { + return [] + } + + const keepIndexes = new Set() + let retainedBytes = 0 + let keptUnreadProtocolMessages = 0 + const tryKeep = (index: number): boolean => { + if (keepIndexes.has(index)) return true + const message = messages[index] + if (!message) return false + const bytes = mailboxMessageStorageBytes(message) + if (bytes > maxRetainedBytes || retainedBytes + bytes > maxRetainedBytes) { + return false + } + keepIndexes.add(index) + retainedBytes += bytes + return true + } + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i] + if (!message || !shouldRetainUnreadAsProtocolMessage(message)) continue + if (keptUnreadProtocolMessages >= maxUnreadProtocolMessages) continue + if (tryKeep(i)) keptUnreadProtocolMessages++ + } + + let keptNonProtocolMessages = 0 + for (let i = messages.length - 1; i >= 0; i--) { + if (keptNonProtocolMessages >= maxMessages) break + const message = messages[i] + if ( + message && + !message.read && + !shouldRetainUnreadAsProtocolMessage(message) + ) { + if (tryKeep(i)) keptNonProtocolMessages++ + } + } + + let keptReadMessages = 0 + for (let i = messages.length - 1; i >= 0; i--) { + if (keptNonProtocolMessages >= maxMessages) break + if (keptReadMessages >= maxReadMessages) break + const message = messages[i] + if (message?.read) { + if (tryKeep(i)) { + keptReadMessages++ + keptNonProtocolMessages++ + } + } + } + + return messages.filter((_message, index) => keepIndexes.has(index)) +} + +function logUnreadMailboxEvictions( + original: TeammateMessage[], + compacted: TeammateMessage[], + context: string, +): void { + const kept = new Set(compacted) + const unreadEvicted = original.filter(message => { + return !message.read && !kept.has(message) + }) + if (unreadEvicted.length === 0) return + + const protocolEvicted = count(unreadEvicted, message => + shouldRetainUnreadAsProtocolMessage(message), + ) + logError( + new Error( + `[TeammateMailbox] Compacted ${unreadEvicted.length} unread message(s) in ${context}; protocol_or_unknown=${protocolEvicted}`, + ), + ) +} + +async function writeCompactedMailbox( + inboxPath: string, + messages: TeammateMessage[], + context: string, +): Promise { + const compacted = compactMailboxMessages(messages) + logUnreadMailboxEvictions(messages, compacted, context) + await writeMailboxAtomic(inboxPath, jsonStringify(compacted, null, 2)) +} + /** * Get the path to a teammate's inbox file * Structure: ~/.claude/teams/{team_name}/inboxes/{agent_name}.json @@ -89,8 +309,7 @@ export async function readMailbox( logForDebugging(`[TeammateMailbox] readMailbox: path=${inboxPath}`) try { - const content = await readFile(inboxPath, 'utf-8') - const messages = jsonParse(content) as TeammateMessage[] + const messages = parseMailboxMessages(await readMailboxFile(inboxPath)) logForDebugging( `[TeammateMailbox] readMailbox: read ${messages.length} message(s)`, ) @@ -103,7 +322,7 @@ export async function readMailbox( } logForDebugging(`Failed to read inbox for ${agentName}: ${error}`) logError(error) - return [] + throw error } } @@ -156,7 +375,7 @@ export async function writeToMailbox( `[TeammateMailbox] writeToMailbox: failed to create inbox file: ${error}`, ) logError(error) - return + throw error } } @@ -168,22 +387,23 @@ export async function writeToMailbox( }) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(recipientName, teamName) + const messages = await readMailboxForMutation(recipientName, teamName) - const newMessage: TeammateMessage = { + const newMessage = toMailboxMessage({ ...message, read: false, - } + }) messages.push(newMessage) - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'writeToMailbox') logForDebugging( `[TeammateMailbox] Wrote message to ${recipientName}'s inbox from ${message.from}`, ) } catch (error) { logForDebugging(`Failed to write to inbox for ${recipientName}: ${error}`) logError(error) + throw error } finally { if (release) { await release() @@ -222,7 +442,7 @@ export async function markMessageAsReadByIndex( logForDebugging(`[TeammateMailbox] markMessageAsReadByIndex: lock acquired`) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) logForDebugging( `[TeammateMailbox] markMessageAsReadByIndex: read ${messages.length} messages after lock`, ) @@ -244,7 +464,7 @@ export async function markMessageAsReadByIndex( messages[messageIndex] = { ...message, read: true } - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'markMessageAsReadByIndex') logForDebugging( `[TeammateMailbox] markMessageAsReadByIndex: marked message at index ${messageIndex} as read`, ) @@ -270,6 +490,46 @@ export async function markMessageAsReadByIndex( } } +export async function markMessageAsReadByIdentity( + agentName: string, + teamName: string | undefined, + expectedMessage: TeammateMessage, +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + const lockFilePath = `${inboxPath}.lock` + + let release: (() => Promise) | undefined + try { + release = await lockfile.lock(inboxPath, { + lockfilePath: lockFilePath, + ...LOCK_OPTIONS, + }) + + const messages = await readMailboxForMutation(agentName, teamName) + const messageIndex = messages.findIndex(message => { + return !message.read && sameMailboxMessage(message, expectedMessage) + }) + if (messageIndex < 0) return false + + messages[messageIndex] = { ...messages[messageIndex]!, read: true } + await writeCompactedMailbox( + inboxPath, + messages, + 'markMessageAsReadByIdentity', + ) + return true + } catch (error) { + const code = getErrnoCode(error) + if (code === 'ENOENT') return false + logError(error) + return false + } finally { + if (release) { + await release() + } + } +} + /** * Mark all messages in a teammate's inbox as read * Uses file locking to prevent race conditions @@ -297,7 +557,7 @@ export async function markMessagesAsRead( logForDebugging(`[TeammateMailbox] markMessagesAsRead: lock acquired`) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) logForDebugging( `[TeammateMailbox] markMessagesAsRead: read ${messages.length} messages after lock`, ) @@ -317,7 +577,7 @@ export async function markMessagesAsRead( // messages comes from jsonParse — fresh, unshared objects safe to mutate for (const m of messages) m.read = true - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'markMessagesAsRead') logForDebugging( `[TeammateMailbox] markMessagesAsRead: WROTE ${unreadCount} message(s) as read to ${inboxPath}`, ) @@ -1114,7 +1374,7 @@ export async function markMessagesAsReadByPredicate( ...LOCK_OPTIONS, }) - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) if (messages.length === 0) { return } @@ -1123,7 +1383,11 @@ export async function markMessagesAsReadByPredicate( !m.read && predicate(m) ? { ...m, read: true } : m, ) - await writeFile(inboxPath, jsonStringify(updatedMessages, null, 2), 'utf-8') + await writeCompactedMailbox( + inboxPath, + updatedMessages, + 'markMessagesAsReadByPredicate', + ) } catch (error) { const code = getErrnoCode(error) if (code === 'ENOENT') { @@ -1161,7 +1425,12 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { if (!Array.isArray(content)) continue for (const block of content) { if (typeof block === 'string') continue - const b = block as unknown as { type: string; name?: string; input?: Record; [key: string]: unknown } + const b = block as unknown as { + type: string + name?: string + input?: Record + [key: string]: unknown + } if ( b.type === 'tool_use' && b.name === SEND_MESSAGE_TOOL_NAME && @@ -1177,7 +1446,7 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { const to = b.input.to as string const summary = 'summary' in b.input && typeof b.input.summary === 'string' - ? b.input.summary as string + ? (b.input.summary as string) : (b.input.message as string).slice(0, 80) return `[to ${to}] ${summary}` } diff --git a/src/utils/udsClient.ts b/src/utils/udsClient.ts index 781f3ddd1..f08bee696 100644 --- a/src/utils/udsClient.ts +++ b/src/utils/udsClient.ts @@ -16,7 +16,7 @@ import { errorMessage, isFsInaccessible } from './errors.js' import { isProcessRunning } from './genericProcessUtils.js' import { jsonParse, jsonStringify } from './slowOperations.js' import type { SessionKind } from './concurrentSessions.js' -import type { UdsMessage } from './udsMessaging.js' +import { MAX_UDS_FRAME_BYTES, type UdsMessage } from './udsMessaging.js' // --------------------------------------------------------------------------- // Types @@ -43,6 +43,12 @@ function getSessionsDir(): string { return join(getClaudeConfigHomeDir(), 'sessions') } +function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength +} + // --------------------------------------------------------------------------- // Discovery // --------------------------------------------------------------------------- @@ -104,9 +110,14 @@ export async function listAllLiveSessions(): Promise { */ export async function listPeers(): Promise { const all = await listAllLiveSessions() - return all.filter( - s => s.pid !== process.pid && s.messagingSocketPath != null, - ) + return all.filter(s => s.pid !== process.pid && s.messagingSocketPath != null) +} + +async function findAuthTokenForSocketPath( + socketPath: string, +): Promise { + const { readUdsCapabilityToken } = await import('./udsMessaging.js') + return readUdsCapabilityToken(socketPath) } // --------------------------------------------------------------------------- @@ -117,10 +128,21 @@ export async function listPeers(): Promise { * Probe a UDS socket to check if a server is listening (ping/pong). * Returns true if the peer responds within the timeout. */ -export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise { - return new Promise((resolve) => { +export async function isPeerAlive( + socketPath: string, + timeoutMs = 3000, + authToken?: string, +): Promise { + const token = authToken ?? (await findAuthTokenForSocketPath(socketPath)) + if (!token) return false + + return new Promise(resolve => { const conn = createConnection(socketPath, () => { - const ping: UdsMessage = { type: 'ping', ts: new Date().toISOString() } + const ping: UdsMessage = { + type: 'ping', + ts: new Date().toISOString(), + meta: { authToken: token }, + } conn.write(jsonStringify(ping) + '\n') }) @@ -135,7 +157,19 @@ export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise }, timeoutMs) let buffer = '' - conn.on('data', (chunk) => { + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + if (!resolved) { + resolved = true + clearTimeout(timer) + conn.destroy() + resolve(false) + } + return + } buffer += chunk.toString() if (buffer.includes('"pong"')) { if (!resolved) { @@ -165,6 +199,13 @@ export async function sendToUdsSocket( targetSocketPath: string, message: string | Record, ): Promise { + const { parseUdsTarget } = await import('./udsMessaging.js') + const target = parseUdsTarget(targetSocketPath) + const authToken = await findAuthTokenForSocketPath(target.socketPath) + if (!authToken) { + throw new Error(`No auth token found for peer at ${target.socketPath}`) + } + const data = typeof message === 'string' ? message : jsonStringify(message) const udsMsg: UdsMessage = { type: 'text', @@ -177,18 +218,59 @@ export async function sendToUdsSocket( udsMsg.from = getUdsMessagingSocketPath() return new Promise((resolve, reject) => { - const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(udsMsg) + '\n', (err) => { - conn.end() - if (err) reject(err) - else resolve() + let buffer = '' + let settled = false + const finish = (error?: Error): void => { + if (settled) return + settled = true + conn.end() + if (error) reject(error) + else resolve() + } + const conn = createConnection(target.socketPath, () => { + udsMsg.meta = { ...udsMsg.meta, authToken } + conn.write(jsonStringify(udsMsg) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', (err) => { - reject(new Error(`Failed to connect to peer at ${targetSocketPath}: ${errorMessage(err)}`)) + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + buffer += chunk.toString() + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + for (const line of lines) { + if (!line.trim()) continue + let response: UdsMessage + try { + response = jsonParse(line) as UdsMessage + } catch { + continue + } + if (response.type === 'response') { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 'UDS receiver rejected message')) + return + } + } + }) + conn.on('error', err => { + finish( + new Error( + `Failed to connect to peer at ${target.socketPath}: ${errorMessage(err)}`, + ), + ) }) conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) } diff --git a/src/utils/udsMessaging.ts b/src/utils/udsMessaging.ts index 1c95ab63c..7efa7fbf6 100644 --- a/src/utils/udsMessaging.ts +++ b/src/utils/udsMessaging.ts @@ -8,14 +8,25 @@ * but can be overridden via --messaging-socket-path. */ +import { createHash, randomBytes } from 'crypto' import { createServer, type Server, type Socket } from 'net' -import { mkdir, unlink } from 'fs/promises' +import { + chmod, + lstat, + mkdir, + open, + readFile, + rename, + unlink, +} from 'fs/promises' import { dirname, join } from 'path' import { tmpdir } from 'os' import { registerCleanup } from './cleanupRegistry.js' import { logForDebugging } from './debug.js' import { errorMessage } from './errors.js' +import { getClaudeConfigHomeDir } from './envUtils.js' import { attachNdjsonFramer } from './ndjsonFramer.js' +import { logError } from './log.js' import { jsonParse, jsonStringify } from './slowOperations.js' // --------------------------------------------------------------------------- @@ -27,6 +38,7 @@ export type UdsMessageType = | 'notification' | 'query' | 'response' + | 'error' | 'ping' | 'pong' @@ -60,6 +72,15 @@ let onEnqueueCb: (() => void) | null = null const clients = new Set() const inbox: UdsInboxEntry[] = [] let nextId = 1 +let defaultSocketPath: string | null = null +let authToken: string | null = null +let capabilityFilePath: string | null = null +let inboxBytes = 0 + +export const MAX_UDS_INBOX_ENTRIES = 1_000 +export const MAX_UDS_FRAME_BYTES = 64 * 1024 +export const MAX_UDS_INBOX_BYTES = 2 * 1024 * 1024 +export const MAX_UDS_CLIENTS = 128 // --------------------------------------------------------------------------- // Public API — socket path helpers @@ -74,10 +95,19 @@ let nextId = 1 * transparently, but we use the pipe format on Windows for Node.js compat. */ export function getDefaultUdsSocketPath(): string { + if (defaultSocketPath) return defaultSocketPath + const nonce = randomBytes(16).toString('hex') if (process.platform === 'win32') { - return `\\\\.\\pipe\\claude-code-${process.pid}` + defaultSocketPath = `\\\\.\\pipe\\claude-code-${process.pid}-${nonce}` + return defaultSocketPath } - return join(tmpdir(), 'claude-code-socks', `${process.pid}.sock`) + defaultSocketPath = join( + tmpdir(), + 'claude-code-socks', + `${process.pid}-${nonce}`, + 'messaging.sock', + ) + return defaultSocketPath } /** @@ -88,6 +118,142 @@ export function getUdsMessagingSocketPath(): string | undefined { return socketPath ?? undefined } +export function formatUdsAddress(socket: string): string { + return `uds:${socket}` +} + +export function parseUdsTarget(target: string): { + socketPath: string +} { + if (target.includes('#token=')) { + throw new Error( + 'UDS target must not include an inline auth token; use the ListPeers address', + ) + } + return { socketPath: target } +} + +function getCapabilityDir(): string { + return join(getClaudeConfigHomeDir(), 'messaging-capabilities') +} + +function getCapabilityPath(socket: string): string { + const digest = createHash('sha256').update(socket).digest('hex') + return join(getCapabilityDir(), `${digest}.json`) +} + +function isNotFound(error: unknown): boolean { + return ( + typeof error === 'object' && + error !== null && + (error as NodeJS.ErrnoException).code === 'ENOENT' + ) +} + +async function assertPrivateCapabilityDir(dir: string): Promise { + let stat: Awaited> + try { + stat = await lstat(dir) + } catch (error) { + if (!isNotFound(error)) throw error + await mkdir(dir, { recursive: true, mode: 0o700 }) + stat = await lstat(dir) + } + + if (!stat.isDirectory() || stat.isSymbolicLink()) { + throw new Error( + `[udsMessaging] capability directory is not a private directory: ${dir}`, + ) + } + if (process.platform !== 'win32') { + const broadMode = stat.mode & 0o077 + if (broadMode !== 0) { + throw new Error( + `[udsMessaging] capability directory permissions are too broad: ${dir}`, + ) + } + if (typeof process.getuid === 'function' && stat.uid !== process.getuid()) { + throw new Error( + `[udsMessaging] capability directory owner does not match current user: ${dir}`, + ) + } + } + + await chmod(dir, 0o700) +} + +async function writePrivateFileExclusive( + path: string, + content: string, +): Promise { + const handle = await open(path, 'wx', 0o600) + try { + await handle.writeFile(content, 'utf-8') + } finally { + await handle.close() + } + await chmod(path, 0o600) +} + +async function ensureSocketParent(path: string): Promise { + const dir = dirname(path) + try { + const stat = await lstat(dir) + if (!stat.isDirectory() || stat.isSymbolicLink()) { + throw new Error( + `[udsMessaging] socket parent is not a directory: ${dir}`, + ) + } + return + } catch (error) { + if (!isNotFound(error)) throw error + } + + await mkdir(dir, { recursive: true, mode: 0o700 }) + await chmod(dir, 0o700) +} + +async function writeCapabilityFile( + socket: string, + token: string, +): Promise { + const dir = getCapabilityDir() + await assertPrivateCapabilityDir(dir) + const target = getCapabilityPath(socket) + const temp = `${target}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` + try { + await writePrivateFileExclusive( + temp, + jsonStringify({ socketPath: socket, authToken: token }), + ) + await rename(temp, target) + } catch (error) { + try { + await unlink(temp) + } catch { + // Temp file may not exist if exclusive creation failed. + } + throw error + } + capabilityFilePath = target +} + +export async function readUdsCapabilityToken( + socket: string, +): Promise { + try { + const parsed = jsonParse( + await readFile(getCapabilityPath(socket), 'utf-8'), + ) as Record + if (parsed.socketPath === socket && typeof parsed.authToken === 'string') { + return parsed.authToken + } + } catch { + // Missing or unreadable capability file means the peer is not addressable. + } + return undefined +} + // --------------------------------------------------------------------------- // Inbox // --------------------------------------------------------------------------- @@ -101,16 +267,79 @@ export function setOnEnqueue(cb: (() => void) | null): void { } /** - * Drain all pending inbox messages, marking them processed. + * Drain all pending inbox messages and release retained history. */ export function drainInbox(): UdsInboxEntry[] { - const pending = inbox.filter(e => e.status === 'pending') + const pending = inbox.splice(0, inbox.length) + inboxBytes = 0 for (const entry of pending) { entry.status = 'processed' } return pending } +function getMessageBytes(message: UdsMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function enqueueInboxEntry(entry: UdsInboxEntry): boolean { + const entryBytes = getMessageBytes(entry.message) + if ( + entryBytes > MAX_UDS_FRAME_BYTES || + inbox.length >= MAX_UDS_INBOX_ENTRIES || + inboxBytes + entryBytes > MAX_UDS_INBOX_BYTES + ) { + logError( + new Error( + `[udsMessaging] inbox full (${inbox.length}/${MAX_UDS_INBOX_ENTRIES}, ${inboxBytes}/${MAX_UDS_INBOX_BYTES} bytes); dropping message type=${entry.message.type}`, + ), + ) + return false + } + inbox.push(entry) + inboxBytes += entryBytes + return true +} + +function ensureAuthToken(): string { + if (!authToken) { + authToken = randomBytes(32).toString('hex') + } + return authToken +} + +function getMessageAuthToken(message: UdsMessage): string | undefined { + const token = message.meta?.authToken + return typeof token === 'string' ? token : undefined +} + +function isAuthorizedMessage(message: UdsMessage): boolean { + return getMessageAuthToken(message) === authToken +} + +function writeSocketMessage(socket: Socket, message: UdsMessage): void { + if (socket.destroyed) return + socket.write(jsonStringify(message) + '\n') +} + +function stripAuthToken(message: UdsMessage): UdsMessage { + const { authToken: _authToken, ...metaWithoutAuth } = message.meta ?? {} + return { + ...message, + meta: Object.keys(metaWithoutAuth).length > 0 ? metaWithoutAuth : undefined, + } +} + +function withRequestAuthToken(message: UdsMessage, token: string): UdsMessage { + return { + ...message, + meta: { + ...message.meta, + authToken: token, + }, + } +} + // --------------------------------------------------------------------------- // Server // --------------------------------------------------------------------------- @@ -132,7 +361,7 @@ export async function startUdsMessaging( // Ensure parent directory exists (skip on Windows — pipe paths aren't files) if (process.platform !== 'win32') { - await mkdir(dirname(path), { recursive: true }) + await ensureSocketParent(path) } // Clean up stale socket file (skip on Windows — pipe paths aren't files) @@ -144,69 +373,134 @@ export async function startUdsMessaging( } } - socketPath = path + const token = ensureAuthToken() + try { + await writeCapabilityFile(path, token) + socketPath = path - await new Promise((resolve, reject) => { - const srv = createServer(socket => { - clients.add(socket) - logForDebugging( - `[udsMessaging] client connected (total: ${clients.size})`, - ) - - attachNdjsonFramer( - socket, - msg => { - // Handle ping with automatic pong - if (msg.type === 'ping') { - const pong: UdsMessage = { - type: 'pong', - from: socketPath ?? undefined, - ts: new Date().toISOString(), - } - if (!socket.destroyed) { - socket.write(jsonStringify(pong) + '\n') - } - return - } - - // Enqueue into inbox - const entry: UdsInboxEntry = { - id: `uds-${nextId++}`, - message: msg, - receivedAt: Date.now(), - status: 'pending', - } - inbox.push(entry) + await new Promise((resolve, reject) => { + const srv = createServer(socket => { + if (clients.size >= MAX_UDS_CLIENTS) { logForDebugging( - `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 'unknown'}`, + `[udsMessaging] rejected client: ${clients.size}/${MAX_UDS_CLIENTS} clients already connected`, ) - onEnqueueCb?.() - }, - text => jsonParse(text) as UdsMessage, - ) + socket.destroy() + return + } + clients.add(socket) + logForDebugging( + `[udsMessaging] client connected (total: ${clients.size})`, + ) - socket.on('close', () => { - clients.delete(socket) + attachNdjsonFramer( + socket, + msg => { + if (!isAuthorizedMessage(msg)) { + logForDebugging( + `[udsMessaging] rejected unauthenticated message type=${msg.type}`, + ) + if (!socket.destroyed) { + socket.write( + jsonStringify({ + type: 'error', + data: 'unauthorized', + ts: new Date().toISOString(), + } satisfies UdsMessage) + '\n', + ) + } + return + } + + // Handle ping with automatic pong + if (msg.type === 'ping') { + writeSocketMessage(socket, { + type: 'pong', + from: socketPath ?? undefined, + ts: new Date().toISOString(), + }) + return + } + + // Enqueue into inbox + const sanitizedMessage = stripAuthToken(msg) + const entry: UdsInboxEntry = { + id: `uds-${nextId++}`, + message: sanitizedMessage, + receivedAt: Date.now(), + status: 'pending', + } + if (!enqueueInboxEntry(entry)) { + writeSocketMessage(socket, { + type: 'error', + data: 'inbox full', + ts: new Date().toISOString(), + }) + return + } + logForDebugging( + `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 'unknown'}`, + ) + writeSocketMessage(socket, { + type: 'response', + data: 'ok', + ts: new Date().toISOString(), + meta: { id: entry.id }, + }) + onEnqueueCb?.() + }, + text => jsonParse(text) as UdsMessage, + { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + onFrameError: error => { + logForDebugging(`[udsMessaging] ${error.message}`) + }, + }, + ) + + socket.on('close', () => { + clients.delete(socket) + }) + + socket.on('error', err => { + clients.delete(socket) + logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + }) }) - socket.on('error', err => { - clients.delete(socket) - logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + srv.on('error', reject) + + srv.listen(path, () => { + void (async () => { + try { + if (process.platform !== 'win32') { + await chmod(path, 0o600) + } + server = srv + // Export so child processes can discover the socket + process.env.CLAUDE_CODE_MESSAGING_SOCKET = path + logForDebugging( + `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, + ) + resolve() + } catch (error) { + srv.close(() => reject(error)) + } + })() }) }) - - srv.on('error', reject) - - srv.listen(path, () => { - server = srv - // Export so child processes can discover the socket - process.env.CLAUDE_CODE_MESSAGING_SOCKET = path - logForDebugging( - `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, - ) - resolve() - }) - }) + } catch (error) { + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone. + } + capabilityFilePath = null + } + socketPath = null + authToken = null + throw error + } // Register cleanup so the socket file is removed on exit registerCleanup(async () => { @@ -230,6 +524,9 @@ export async function stopUdsMessaging(): Promise { server!.close(() => resolve()) }) server = null + inbox.length = 0 + inboxBytes = 0 + onEnqueueCb = null // Remove socket file (skip on Windows — pipe paths aren't files) if (socketPath) { @@ -245,7 +542,30 @@ export async function stopUdsMessaging(): Promise { `[udsMessaging] server stopped, socket removed: ${socketPath}`, ) socketPath = null + authToken = null } + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone + } + capabilityFilePath = null + } +} + +function parseResponseLine(line: string): UdsMessage | null { + try { + return jsonParse(line) as UdsMessage + } catch { + return null + } +} + +function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength } /** @@ -255,23 +575,66 @@ export async function stopUdsMessaging(): Promise { export async function sendUdsMessage( targetSocketPath: string, message: UdsMessage, + opts: { authToken?: string } = {}, ): Promise { const { createConnection } = await import('net') - message.from = message.from ?? socketPath ?? undefined - message.ts = message.ts ?? new Date().toISOString() + const token = opts.authToken ?? authToken + if (!token) { + throw new Error('Cannot send UDS message without auth token') + } + const outbound = withRequestAuthToken( + { + ...message, + from: message.from ?? socketPath ?? undefined, + ts: message.ts ?? new Date().toISOString(), + }, + token, + ) return new Promise((resolve, reject) => { + let buffer = '' + let settled = false + const finish = (error?: Error): void => { + if (settled) return + settled = true + conn.end() + if (error) reject(error) + else resolve() + } const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(message) + '\n', err => { - conn.end() - if (err) reject(err) - else resolve() + conn.write(jsonStringify(outbound) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', reject) + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + buffer += chunk.toString() + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + for (const line of lines) { + if (!line.trim()) continue + const response = parseResponseLine(line) + if (!response) continue + if (response.type === 'response' || response.type === 'pong') { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 'UDS receiver rejected message')) + return + } + } + }) + conn.on('error', err => finish(err)) // Timeout so we don't hang on unreachable sockets conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) }