mirror of
https://github.com/claude-code-best/claude-code.git
synced 2026-06-17 22:05:50 +00:00
feat: 实现 SSH Remote — 本地 REPL + 远端工具执行
SSH Remote 允许在本地运行交互式 REPL,同时将工具调用(Bash、文件读写等) 通过 SSH 隧道转发到远程主机执行。 核心模块: - SSHSessionManager: NDJSON 双向通信、权限转发、指数退避重连 - SSHAuthProxy: 本地认证代理 + SSH -R 反向端口转发,nonce 验证 - SSHProbe: 远端主机平台/架构/已有二进制探测 - SSHDeploy: 远端二进制部署(scp) - createSSHSession: 会话编排(probe → deploy → spawn → attach) 新增选项: - --remote-bin: 跳过 probe/deploy,使用自定义远端二进制 - ANTHROPIC_AUTH_NONCE: API 请求认证 nonce header 包含 17 个单元测试和完整文档。
This commit is contained in:
165
src/ssh/SSHAuthProxy.ts
Normal file
165
src/ssh/SSHAuthProxy.ts
Normal file
@@ -0,0 +1,165 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { unlinkSync } from 'fs'
|
||||
import { getClaudeAIOAuthTokens } from 'src/utils/auth.js'
|
||||
import { getOauthConfig } from 'src/constants/oauth.js'
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
|
||||
export interface SSHAuthProxy {
|
||||
stop(): void
|
||||
}
|
||||
|
||||
export interface AuthProxyInfo {
|
||||
proxy: SSHAuthProxy
|
||||
/** Unix socket path or 127.0.0.1:<port> */
|
||||
localAddress: string
|
||||
/** Environment variables to inject into the remote/child CLI process */
|
||||
authEnv: Record<string, string>
|
||||
}
|
||||
|
||||
const isWindows = process.platform === 'win32'
|
||||
|
||||
function resolveAuthHeaders(): Record<string, string> {
|
||||
const apiKey = process.env.ANTHROPIC_API_KEY
|
||||
if (apiKey) {
|
||||
return { 'x-api-key': apiKey }
|
||||
}
|
||||
|
||||
const oauthTokens = getClaudeAIOAuthTokens()
|
||||
if (oauthTokens?.accessToken) {
|
||||
return { Authorization: `Bearer ${oauthTokens.accessToken}` }
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
function resolveUpstreamBaseUrl(): string {
|
||||
return process.env.ANTHROPIC_BASE_URL || getOauthConfig().BASE_API_URL
|
||||
}
|
||||
|
||||
async function proxyFetch(
|
||||
req: Request,
|
||||
nonce: string | null,
|
||||
): Promise<Response> {
|
||||
if (nonce && req.headers.get('x-auth-nonce') !== nonce) {
|
||||
return new Response('Forbidden', { status: 403 })
|
||||
}
|
||||
|
||||
const upstreamBase = resolveUpstreamBaseUrl()
|
||||
const url = new URL(req.url)
|
||||
const upstreamUrl = `${upstreamBase}${url.pathname}${url.search}`
|
||||
|
||||
const authHeaders = resolveAuthHeaders()
|
||||
if (Object.keys(authHeaders).length === 0) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
error: 'No API credentials available on local machine',
|
||||
}),
|
||||
{ status: 401, headers: { 'content-type': 'application/json' } },
|
||||
)
|
||||
}
|
||||
|
||||
const forwardHeaders = new Headers(req.headers)
|
||||
for (const [k, v] of Object.entries(authHeaders)) {
|
||||
forwardHeaders.set(k, v)
|
||||
}
|
||||
forwardHeaders.delete('host')
|
||||
forwardHeaders.delete('x-auth-nonce')
|
||||
|
||||
logForDebugging(
|
||||
`[SSHAuthProxy] ${req.method} ${url.pathname} -> ${upstreamUrl}`,
|
||||
)
|
||||
|
||||
try {
|
||||
const upstreamRes = await fetch(upstreamUrl, {
|
||||
method: req.method,
|
||||
headers: forwardHeaders,
|
||||
body: req.body,
|
||||
// @ts-expect-error Bun supports duplex for streaming request bodies
|
||||
duplex: 'half',
|
||||
})
|
||||
|
||||
const responseHeaders = new Headers(upstreamRes.headers)
|
||||
responseHeaders.delete('content-encoding')
|
||||
responseHeaders.delete('content-length')
|
||||
|
||||
return new Response(upstreamRes.body, {
|
||||
status: upstreamRes.status,
|
||||
statusText: upstreamRes.statusText,
|
||||
headers: responseHeaders,
|
||||
})
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err)
|
||||
logForDebugging(`[SSHAuthProxy] upstream error: ${message}`)
|
||||
return new Response(
|
||||
JSON.stringify({ error: `Proxy upstream error: ${message}` }),
|
||||
{ status: 502, headers: { 'content-type': 'application/json' } },
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
export async function createAuthProxy(): Promise<AuthProxyInfo> {
|
||||
const id = randomUUID()
|
||||
|
||||
if (isWindows) {
|
||||
return createTcpAuthProxy(id)
|
||||
}
|
||||
return createUnixSocketAuthProxy(id)
|
||||
}
|
||||
|
||||
async function createUnixSocketAuthProxy(id: string): Promise<AuthProxyInfo> {
|
||||
const socketPath = `/tmp/claude-ssh-auth-${id}.sock`
|
||||
|
||||
const server = Bun.serve({
|
||||
unix: socketPath,
|
||||
fetch: req => proxyFetch(req, null),
|
||||
})
|
||||
|
||||
logForDebugging(`[SSHAuthProxy] listening on unix:${socketPath}`)
|
||||
|
||||
const proxy: SSHAuthProxy = {
|
||||
stop() {
|
||||
server.stop(true)
|
||||
try {
|
||||
unlinkSync(socketPath)
|
||||
} catch {
|
||||
// Socket file may already be cleaned up
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
proxy,
|
||||
localAddress: socketPath,
|
||||
authEnv: { ANTHROPIC_AUTH_SOCKET: socketPath },
|
||||
}
|
||||
}
|
||||
|
||||
async function createTcpAuthProxy(id: string): Promise<AuthProxyInfo> {
|
||||
const nonce = randomUUID()
|
||||
|
||||
const server = Bun.serve({
|
||||
port: 0,
|
||||
hostname: '127.0.0.1',
|
||||
fetch: req => proxyFetch(req, nonce),
|
||||
})
|
||||
|
||||
const port = server.port
|
||||
logForDebugging(
|
||||
`[SSHAuthProxy] listening on TCP 127.0.0.1:${port} (nonce-protected)`,
|
||||
)
|
||||
|
||||
const proxy: SSHAuthProxy = {
|
||||
stop() {
|
||||
server.stop(true)
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
proxy,
|
||||
localAddress: `127.0.0.1:${port}`,
|
||||
authEnv: {
|
||||
ANTHROPIC_BASE_URL: `http://127.0.0.1:${port}`,
|
||||
ANTHROPIC_AUTH_NONCE: nonce,
|
||||
},
|
||||
}
|
||||
}
|
||||
123
src/ssh/SSHDeploy.ts
Normal file
123
src/ssh/SSHDeploy.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
import { existsSync } from 'fs'
|
||||
import { resolve } from 'path'
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
|
||||
const SSH_TIMEOUT_MS = 60_000
|
||||
const REMOTE_BIN_DIR = '~/.local/bin'
|
||||
const REMOTE_CLI_FILE = 'claude-code-cli.js'
|
||||
const REMOTE_WRAPPER = 'claude'
|
||||
|
||||
export interface DeployOptions {
|
||||
host: string
|
||||
remotePlatform: string
|
||||
remoteArch: string
|
||||
localVersion: string
|
||||
onProgress?: (msg: string) => void
|
||||
}
|
||||
|
||||
async function runSshCommand(
|
||||
host: string,
|
||||
command: string,
|
||||
timeoutMs = SSH_TIMEOUT_MS,
|
||||
): Promise<{ stdout: string; stderr: string; exitCode: number }> {
|
||||
const proc = Bun.spawn(['ssh', '-o', 'ConnectTimeout=10', host, command], {
|
||||
stdout: 'pipe',
|
||||
stderr: 'pipe',
|
||||
})
|
||||
|
||||
const timer = setTimeout(() => proc.kill(), timeoutMs)
|
||||
|
||||
try {
|
||||
const [stdout, stderr] = await Promise.all([
|
||||
new Response(proc.stdout).text(),
|
||||
new Response(proc.stderr).text(),
|
||||
])
|
||||
const exitCode = await proc.exited
|
||||
return { stdout: stdout.trim(), stderr: stderr.trim(), exitCode }
|
||||
} finally {
|
||||
clearTimeout(timer)
|
||||
}
|
||||
}
|
||||
|
||||
function findLocalBinary(): string {
|
||||
const projectRoot = resolve(import.meta.dir, '../..')
|
||||
const distPath = resolve(projectRoot, 'dist/cli.js')
|
||||
if (existsSync(distPath)) return distPath
|
||||
|
||||
const devPath = resolve(projectRoot, 'src/entrypoints/cli.tsx')
|
||||
if (existsSync(devPath)) return devPath
|
||||
|
||||
throw new Error(
|
||||
'Cannot find local CLI binary to deploy. Run `bun run build` first.',
|
||||
)
|
||||
}
|
||||
|
||||
export async function deployBinary(options: DeployOptions): Promise<string> {
|
||||
const { host, remotePlatform, remoteArch, localVersion, onProgress } = options
|
||||
|
||||
if (remotePlatform !== 'linux' && remotePlatform !== 'darwin') {
|
||||
throw new Error(
|
||||
`Remote platform "${remotePlatform}" is not supported. Only linux and darwin are supported.`,
|
||||
)
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[SSHDeploy] deploying to ${host} (${remotePlatform}/${remoteArch}, v${localVersion})`,
|
||||
)
|
||||
|
||||
const localBinary = findLocalBinary()
|
||||
logForDebugging(`[SSHDeploy] local binary: ${localBinary}`)
|
||||
|
||||
onProgress?.('Creating remote directory...')
|
||||
const mkdirResult = await runSshCommand(host, `mkdir -p ${REMOTE_BIN_DIR}`)
|
||||
if (mkdirResult.exitCode !== 0) {
|
||||
throw new Error(`Failed to create remote directory: ${mkdirResult.stderr}`)
|
||||
}
|
||||
|
||||
onProgress?.('Uploading binary...')
|
||||
const remotePath = `${REMOTE_BIN_DIR}/${REMOTE_CLI_FILE}`
|
||||
const scpProc = Bun.spawn(
|
||||
['scp', '-o', 'ConnectTimeout=10', localBinary, `${host}:${remotePath}`],
|
||||
{ stdout: 'pipe', stderr: 'pipe' },
|
||||
)
|
||||
const scpTimer = setTimeout(() => scpProc.kill(), SSH_TIMEOUT_MS)
|
||||
const scpStderr = await new Response(scpProc.stderr).text()
|
||||
const scpExit = await scpProc.exited
|
||||
clearTimeout(scpTimer)
|
||||
|
||||
if (scpExit !== 0) {
|
||||
throw new Error(`SCP upload failed (exit ${scpExit}): ${scpStderr.trim()}`)
|
||||
}
|
||||
|
||||
onProgress?.('Installing wrapper script...')
|
||||
const wrapperScript = [
|
||||
`cat > ${REMOTE_BIN_DIR}/${REMOTE_WRAPPER} << 'WRAPPER'`,
|
||||
'#!/bin/sh',
|
||||
`exec bun ${REMOTE_BIN_DIR}/${REMOTE_CLI_FILE} "$@"`,
|
||||
'WRAPPER',
|
||||
`chmod +x ${REMOTE_BIN_DIR}/${REMOTE_WRAPPER}`,
|
||||
].join('\n')
|
||||
|
||||
const wrapperResult = await runSshCommand(host, wrapperScript)
|
||||
if (wrapperResult.exitCode !== 0) {
|
||||
throw new Error(`Failed to install wrapper script: ${wrapperResult.stderr}`)
|
||||
}
|
||||
|
||||
onProgress?.('Verifying installation...')
|
||||
const verifyResult = await runSshCommand(
|
||||
host,
|
||||
`${REMOTE_BIN_DIR}/${REMOTE_WRAPPER} --version`,
|
||||
)
|
||||
if (verifyResult.exitCode !== 0) {
|
||||
throw new Error(
|
||||
`Binary deployed but verification failed (exit ${verifyResult.exitCode}): ${verifyResult.stderr}`,
|
||||
)
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`[SSHDeploy] deployed successfully, remote version: ${verifyResult.stdout}`,
|
||||
)
|
||||
onProgress?.(`Deployed v${verifyResult.stdout}`)
|
||||
|
||||
return `${REMOTE_BIN_DIR}/${REMOTE_WRAPPER}`
|
||||
}
|
||||
99
src/ssh/SSHProbe.ts
Normal file
99
src/ssh/SSHProbe.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { logForDebugging } from 'src/utils/debug.js'
|
||||
|
||||
const PROBE_TIMEOUT_MS = 15_000
|
||||
|
||||
export interface ProbeResult {
|
||||
hasBinary: boolean
|
||||
remoteVersion: string | null
|
||||
remotePlatform: 'linux' | 'darwin'
|
||||
remoteArch: 'x64' | 'arm64'
|
||||
defaultCwd: string
|
||||
binaryPath: string | null
|
||||
}
|
||||
|
||||
export class SSHProbeError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message)
|
||||
this.name = 'SSHProbeError'
|
||||
}
|
||||
}
|
||||
|
||||
export async function probeRemote(
|
||||
host: string,
|
||||
onProgress?: (msg: string) => void,
|
||||
): Promise<ProbeResult> {
|
||||
onProgress?.('Probing remote host…')
|
||||
|
||||
const proc = Bun.spawn(
|
||||
[
|
||||
'ssh',
|
||||
'-o',
|
||||
'BatchMode=yes',
|
||||
'-o',
|
||||
'ConnectTimeout=10',
|
||||
host,
|
||||
'CLAUDE_BIN=$(test -x "$HOME/.local/bin/claude" && echo "$HOME/.local/bin/claude" || command -v claude 2>/dev/null); echo "$CLAUDE_BIN"; $CLAUDE_BIN --version 2>/dev/null; uname -sm; pwd',
|
||||
],
|
||||
{ stdin: 'ignore', stdout: 'pipe', stderr: 'pipe' },
|
||||
)
|
||||
|
||||
const result = await Promise.race([
|
||||
proc.exited,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new SSHProbeError(
|
||||
`SSH probe timed out after ${PROBE_TIMEOUT_MS / 1000}s`,
|
||||
),
|
||||
),
|
||||
PROBE_TIMEOUT_MS,
|
||||
),
|
||||
),
|
||||
])
|
||||
|
||||
const stdout = await new Response(proc.stdout).text()
|
||||
const stderr = await new Response(proc.stderr).text()
|
||||
|
||||
if (result !== 0) {
|
||||
const detail = stderr.trim() || `exit code ${result}`
|
||||
throw new SSHProbeError(`SSH probe failed: ${detail}`)
|
||||
}
|
||||
|
||||
const lines = stdout
|
||||
.split('\n')
|
||||
.map(l => l.trim())
|
||||
.filter(Boolean)
|
||||
logForDebugging(`[SSHProbe] raw lines: ${JSON.stringify(lines)}`)
|
||||
|
||||
const unameIdx = lines.findIndex(l => /^(Linux|Darwin)\s/.test(l))
|
||||
if (unameIdx === -1) {
|
||||
throw new SSHProbeError(
|
||||
'Could not detect remote platform (uname output missing)',
|
||||
)
|
||||
}
|
||||
|
||||
const binaryPath = unameIdx >= 2 ? lines[unameIdx - 2] || null : null
|
||||
const versionLine = unameIdx >= 1 ? lines[unameIdx - 1] || null : null
|
||||
const remoteVersion =
|
||||
versionLine && /^\d+\.\d+/.test(versionLine) ? versionLine : null
|
||||
const hasBinary = binaryPath !== null && binaryPath.startsWith('/')
|
||||
const defaultCwd = lines[unameIdx + 1] || '/'
|
||||
|
||||
const [osName, arch] = lines[unameIdx]!.split(/\s+/)
|
||||
|
||||
const remotePlatform = osName === 'Darwin' ? 'darwin' : 'linux'
|
||||
const remoteArch: 'x64' | 'arm64' =
|
||||
arch === 'aarch64' || arch === 'arm64' ? 'arm64' : 'x64'
|
||||
|
||||
onProgress?.(`Detected ${remotePlatform}/${remoteArch}`)
|
||||
|
||||
return {
|
||||
hasBinary: hasBinary && remoteVersion !== null,
|
||||
remoteVersion,
|
||||
remotePlatform,
|
||||
remoteArch,
|
||||
defaultCwd,
|
||||
binaryPath: hasBinary ? binaryPath : null,
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,26 @@
|
||||
// Auto-generated stub — replace with real implementation
|
||||
import type { SDKMessage } from '../entrypoints/sdk/coreTypes.js'
|
||||
import type { Subprocess } from 'bun'
|
||||
import type { SDKMessage } from '../entrypoints/agentSdkTypes.js'
|
||||
import type {
|
||||
SDKControlPermissionRequest,
|
||||
StdoutMessage,
|
||||
} from '../entrypoints/sdk/controlTypes.js'
|
||||
import type { PermissionUpdate } from '../types/permissions.js'
|
||||
import { logForDebugging } from '../utils/debug.js'
|
||||
import { jsonParse, jsonStringify } from '../utils/slowOperations.js'
|
||||
import type { RemoteMessageContent } from '../utils/teleport/api.js'
|
||||
|
||||
export interface SSHSessionManagerOptions {
|
||||
onMessage: (sdkMessage: SDKMessage) => void
|
||||
onPermissionRequest: (request: SSHPermissionRequest, requestId: string) => void
|
||||
onPermissionRequest: (
|
||||
request: SSHPermissionRequest,
|
||||
requestId: string,
|
||||
) => void
|
||||
onConnected: () => void
|
||||
onReconnecting: (attempt: number, max: number) => void
|
||||
onDisconnected: () => void
|
||||
onError: (error: Error) => void
|
||||
reconnect?: () => Promise<Subprocess>
|
||||
maxReconnectAttempts?: number
|
||||
}
|
||||
|
||||
export interface SSHPermissionRequest {
|
||||
@@ -26,5 +37,317 @@ export interface SSHSessionManager {
|
||||
disconnect(): void
|
||||
sendMessage(content: RemoteMessageContent): Promise<boolean>
|
||||
sendInterrupt(): void
|
||||
respondToPermissionRequest(requestId: string, response: { behavior: string; message?: string; updatedInput?: unknown }): void
|
||||
respondToPermissionRequest(
|
||||
requestId: string,
|
||||
response: { behavior: string; message?: string; updatedInput?: unknown },
|
||||
): void
|
||||
}
|
||||
|
||||
function isStdoutMessage(value: unknown): value is StdoutMessage {
|
||||
return (
|
||||
typeof value === 'object' &&
|
||||
value !== null &&
|
||||
'type' in value &&
|
||||
typeof (value as Record<string, unknown>).type === 'string'
|
||||
)
|
||||
}
|
||||
|
||||
const BASE_RECONNECT_DELAY_MS = 2_000
|
||||
const MAX_RECONNECT_DELAY_MS = 15_000
|
||||
const DEFAULT_MAX_RECONNECT_ATTEMPTS = 3
|
||||
|
||||
export class SSHSessionManagerImpl implements SSHSessionManager {
|
||||
private proc: Subprocess
|
||||
private options: SSHSessionManagerOptions
|
||||
private connected = false
|
||||
private disconnected = false
|
||||
private readLoopAbort: AbortController | null = null
|
||||
private reconnectAttempt = 0
|
||||
private readonly maxReconnectAttempts: number
|
||||
private userInitiatedDisconnect = false
|
||||
private reconnecting = false
|
||||
|
||||
constructor(proc: Subprocess, options: SSHSessionManagerOptions) {
|
||||
this.proc = proc
|
||||
this.options = options
|
||||
this.maxReconnectAttempts =
|
||||
options.maxReconnectAttempts ?? DEFAULT_MAX_RECONNECT_ATTEMPTS
|
||||
}
|
||||
|
||||
connect(): void {
|
||||
if (this.connected) return
|
||||
|
||||
this.readLoopAbort = new AbortController()
|
||||
this.startReadLoop()
|
||||
this.monitorExit()
|
||||
|
||||
this.connected = true
|
||||
this.options.onConnected()
|
||||
}
|
||||
|
||||
private async startReadLoop(): Promise<void> {
|
||||
const stdout = this.proc.stdout
|
||||
if (!stdout) {
|
||||
this.options.onError(new Error('SSH process stdout is not available'))
|
||||
return
|
||||
}
|
||||
|
||||
const reader = (stdout as ReadableStream<Uint8Array>).getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let lineBuffer = ''
|
||||
|
||||
try {
|
||||
while (!this.disconnected) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
lineBuffer += decoder.decode(value, { stream: true })
|
||||
const lines = lineBuffer.split('\n')
|
||||
lineBuffer = lines.pop() ?? ''
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim()
|
||||
if (!trimmed) continue
|
||||
this.processLine(trimmed)
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (!this.disconnected) {
|
||||
this.options.onError(
|
||||
err instanceof Error ? err : new Error(String(err)),
|
||||
)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
if (!this.disconnected && !this.userInitiatedDisconnect) {
|
||||
void this.handleProcessExit()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private monitorExit(): void {
|
||||
if (this.proc.exitCode !== null) {
|
||||
if (!this.userInitiatedDisconnect) {
|
||||
void this.handleProcessExit()
|
||||
}
|
||||
return
|
||||
}
|
||||
this.proc.exited
|
||||
.then(() => {
|
||||
if (!this.disconnected && !this.userInitiatedDisconnect) {
|
||||
void this.handleProcessExit()
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
if (!this.disconnected && !this.userInitiatedDisconnect) {
|
||||
void this.handleProcessExit()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private async handleProcessExit(): Promise<void> {
|
||||
if (this.disconnected || this.reconnecting) return
|
||||
this.connected = false
|
||||
|
||||
if (!this.options.reconnect) {
|
||||
this.disconnected = true
|
||||
this.options.onDisconnected()
|
||||
return
|
||||
}
|
||||
|
||||
if (this.reconnectAttempt >= this.maxReconnectAttempts) {
|
||||
this.disconnected = true
|
||||
this.options.onDisconnected()
|
||||
return
|
||||
}
|
||||
|
||||
this.reconnecting = true
|
||||
try {
|
||||
await this.attemptReconnect()
|
||||
} finally {
|
||||
this.reconnecting = false
|
||||
}
|
||||
}
|
||||
|
||||
private async attemptReconnect(): Promise<void> {
|
||||
const reconnect = this.options.reconnect!
|
||||
|
||||
while (this.reconnectAttempt < this.maxReconnectAttempts) {
|
||||
this.reconnectAttempt++
|
||||
this.options.onReconnecting(
|
||||
this.reconnectAttempt,
|
||||
this.maxReconnectAttempts,
|
||||
)
|
||||
|
||||
const delay = Math.min(
|
||||
BASE_RECONNECT_DELAY_MS * 2 ** (this.reconnectAttempt - 1),
|
||||
MAX_RECONNECT_DELAY_MS,
|
||||
)
|
||||
await new Promise<void>(r => setTimeout(r, delay))
|
||||
|
||||
if (this.userInitiatedDisconnect) return
|
||||
|
||||
try {
|
||||
const newProc = await reconnect()
|
||||
this.proc = newProc
|
||||
this.reconnectAttempt = 0
|
||||
this.connected = true
|
||||
this.startReadLoop()
|
||||
this.monitorExit()
|
||||
this.options.onConnected()
|
||||
return
|
||||
} catch (err) {
|
||||
logForDebugging(
|
||||
`[SSH] reconnect attempt ${this.reconnectAttempt} failed: ${err instanceof Error ? err.message : String(err)}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
this.disconnected = true
|
||||
this.options.onDisconnected()
|
||||
}
|
||||
|
||||
private processLine(line: string): void {
|
||||
let raw: unknown
|
||||
try {
|
||||
raw = jsonParse(line)
|
||||
} catch {
|
||||
return
|
||||
}
|
||||
|
||||
if (!isStdoutMessage(raw)) return
|
||||
const parsed = raw
|
||||
|
||||
if (parsed.type === 'control_request') {
|
||||
const request = parsed as unknown as {
|
||||
request_id: string
|
||||
request: SDKControlPermissionRequest & { subtype: string }
|
||||
}
|
||||
if (request.request.subtype === 'can_use_tool') {
|
||||
this.options.onPermissionRequest(
|
||||
request.request as unknown as SSHPermissionRequest,
|
||||
request.request_id,
|
||||
)
|
||||
} else {
|
||||
logForDebugging(
|
||||
`[SSH] Unsupported control request subtype: ${request.request.subtype}`,
|
||||
)
|
||||
this.sendErrorResponse(
|
||||
request.request_id,
|
||||
`Unsupported control request subtype: ${request.request.subtype}`,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (
|
||||
parsed.type !== 'control_response' &&
|
||||
parsed.type !== 'keep_alive' &&
|
||||
parsed.type !== 'control_cancel_request' &&
|
||||
parsed.type !== 'streamlined_text' &&
|
||||
parsed.type !== 'streamlined_tool_use_summary' &&
|
||||
!(
|
||||
parsed.type === 'system' &&
|
||||
(parsed as Record<string, unknown>).subtype === 'post_turn_summary'
|
||||
)
|
||||
) {
|
||||
this.options.onMessage(parsed as SDKMessage)
|
||||
}
|
||||
}
|
||||
|
||||
private writeToStdin(data: string): boolean {
|
||||
try {
|
||||
const stdin = this.proc.stdin
|
||||
if (!stdin || typeof stdin === 'number' || this.disconnected) return false
|
||||
const encoded = new TextEncoder().encode(data + '\n')
|
||||
;(stdin as unknown as { write(d: Uint8Array): number }).write(encoded)
|
||||
;(stdin as unknown as { flush?(): void }).flush?.()
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
async sendMessage(content: RemoteMessageContent): Promise<boolean> {
|
||||
const message = jsonStringify({
|
||||
type: 'user',
|
||||
message: {
|
||||
role: 'user',
|
||||
content,
|
||||
},
|
||||
parent_tool_use_id: null,
|
||||
session_id: '',
|
||||
})
|
||||
return this.writeToStdin(message)
|
||||
}
|
||||
|
||||
sendInterrupt(): void {
|
||||
const request = jsonStringify({
|
||||
type: 'control_request',
|
||||
request_id: crypto.randomUUID(),
|
||||
request: {
|
||||
subtype: 'interrupt',
|
||||
},
|
||||
})
|
||||
this.writeToStdin(request)
|
||||
}
|
||||
|
||||
respondToPermissionRequest(
|
||||
requestId: string,
|
||||
response: { behavior: string; message?: string; updatedInput?: unknown },
|
||||
): void {
|
||||
const msg = jsonStringify({
|
||||
type: 'control_response',
|
||||
response: {
|
||||
subtype: 'success',
|
||||
request_id: requestId,
|
||||
response: {
|
||||
behavior: response.behavior,
|
||||
...(response.behavior === 'allow'
|
||||
? { updatedInput: response.updatedInput }
|
||||
: { message: response.message }),
|
||||
},
|
||||
},
|
||||
})
|
||||
this.writeToStdin(msg)
|
||||
}
|
||||
|
||||
private sendErrorResponse(requestId: string, error: string): void {
|
||||
const response = jsonStringify({
|
||||
type: 'control_response',
|
||||
response: {
|
||||
subtype: 'error',
|
||||
request_id: requestId,
|
||||
error,
|
||||
},
|
||||
})
|
||||
this.writeToStdin(response)
|
||||
}
|
||||
|
||||
disconnect(): void {
|
||||
if (this.disconnected) return
|
||||
this.userInitiatedDisconnect = true
|
||||
this.disconnected = true
|
||||
this.connected = false
|
||||
this.readLoopAbort?.abort()
|
||||
|
||||
try {
|
||||
const stdin = this.proc.stdin
|
||||
if (stdin && typeof stdin !== 'number') {
|
||||
;(stdin as unknown as { end?(): void }).end?.()
|
||||
}
|
||||
} catch {
|
||||
// stdin may already be closed
|
||||
}
|
||||
|
||||
try {
|
||||
this.proc.kill()
|
||||
} catch {
|
||||
// process may already be dead
|
||||
}
|
||||
}
|
||||
|
||||
isConnected(): boolean {
|
||||
return this.connected && !this.disconnected
|
||||
}
|
||||
}
|
||||
|
||||
413
src/ssh/__tests__/SSHSessionManager.test.ts
Normal file
413
src/ssh/__tests__/SSHSessionManager.test.ts
Normal file
@@ -0,0 +1,413 @@
|
||||
import { describe, test, expect, mock, beforeEach } from 'bun:test'
|
||||
import { debugMock } from '../../../tests/mocks/debug'
|
||||
|
||||
mock.module('src/utils/debug.ts', debugMock)
|
||||
|
||||
import { SSHSessionManagerImpl } from '../SSHSessionManager'
|
||||
import type { SSHSessionManagerOptions } from '../SSHSessionManager'
|
||||
import type { Subprocess } from 'bun'
|
||||
|
||||
function createMockSubprocess(options?: {
|
||||
exitCode?: number | null
|
||||
stdoutLines?: string[]
|
||||
}): {
|
||||
proc: Subprocess
|
||||
writeToStdout: (data: string) => void
|
||||
simulateExit: (code?: number) => void
|
||||
} {
|
||||
let stdoutController: ReadableStreamDefaultController<Uint8Array>
|
||||
const exitResolvers: Array<(code: number) => void> = []
|
||||
let exitCode: number | null = options?.exitCode ?? null
|
||||
|
||||
const stdout = new ReadableStream<Uint8Array>({
|
||||
start(controller) {
|
||||
stdoutController = controller
|
||||
if (options?.stdoutLines) {
|
||||
const encoder = new TextEncoder()
|
||||
for (const line of options.stdoutLines) {
|
||||
controller.enqueue(encoder.encode(line + '\n'))
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
const stdinChunks: Uint8Array[] = []
|
||||
const stdin = {
|
||||
write(d: Uint8Array) {
|
||||
stdinChunks.push(d)
|
||||
return d.length
|
||||
},
|
||||
flush() {},
|
||||
end() {},
|
||||
}
|
||||
|
||||
const exited = new Promise<number>(resolve => {
|
||||
exitResolvers.push(resolve)
|
||||
if (exitCode !== null) resolve(exitCode)
|
||||
})
|
||||
|
||||
const proc = {
|
||||
stdout,
|
||||
stdin,
|
||||
stderr: null,
|
||||
get exitCode() {
|
||||
return exitCode
|
||||
},
|
||||
exited,
|
||||
kill: mock(() => {}),
|
||||
pid: 12345,
|
||||
killed: false,
|
||||
signalCode: null,
|
||||
ref: () => {},
|
||||
unref: () => {},
|
||||
} as unknown as Subprocess
|
||||
|
||||
return {
|
||||
proc,
|
||||
writeToStdout(data: string) {
|
||||
const encoder = new TextEncoder()
|
||||
stdoutController.enqueue(encoder.encode(data + '\n'))
|
||||
},
|
||||
simulateExit(code = 0) {
|
||||
exitCode = code
|
||||
try {
|
||||
stdoutController.close()
|
||||
} catch {
|
||||
// may already be closed
|
||||
}
|
||||
for (const resolve of exitResolvers) resolve(code)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
interface MockState {
|
||||
messages: unknown[]
|
||||
permissionRequests: Array<{ request: unknown; requestId: string }>
|
||||
reconnectingCalls: Array<{ attempt: number; max: number }>
|
||||
connectedCount: number
|
||||
disconnectedCount: number
|
||||
errors: Error[]
|
||||
}
|
||||
|
||||
function createMockOptions(
|
||||
overrides?: Partial<SSHSessionManagerOptions>,
|
||||
): SSHSessionManagerOptions & { state: MockState } {
|
||||
const state: MockState = {
|
||||
messages: [],
|
||||
permissionRequests: [],
|
||||
reconnectingCalls: [],
|
||||
connectedCount: 0,
|
||||
disconnectedCount: 0,
|
||||
errors: [],
|
||||
}
|
||||
|
||||
return {
|
||||
state,
|
||||
onMessage: msg => {
|
||||
state.messages.push(msg)
|
||||
},
|
||||
onPermissionRequest: (request, requestId) => {
|
||||
state.permissionRequests.push({ request, requestId })
|
||||
},
|
||||
onConnected: () => {
|
||||
state.connectedCount++
|
||||
},
|
||||
onReconnecting: (attempt, max) => {
|
||||
state.reconnectingCalls.push({ attempt, max })
|
||||
},
|
||||
onDisconnected: () => {
|
||||
state.disconnectedCount++
|
||||
},
|
||||
onError: err => {
|
||||
state.errors.push(err)
|
||||
},
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
describe('SSHSessionManagerImpl', () => {
|
||||
test('connect() sets connected state and calls onConnected', () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
|
||||
expect(manager.isConnected()).toBe(true)
|
||||
expect(opts.state.connectedCount).toBe(1)
|
||||
})
|
||||
|
||||
test('connect() is idempotent', () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
manager.connect()
|
||||
|
||||
expect(opts.state.connectedCount).toBe(1)
|
||||
})
|
||||
|
||||
test('disconnect() sets disconnected state and kills process', () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
manager.disconnect()
|
||||
|
||||
expect(manager.isConnected()).toBe(false)
|
||||
expect((proc.kill as ReturnType<typeof mock>).mock.calls.length).toBe(1)
|
||||
})
|
||||
|
||||
test('disconnect() is idempotent', () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
manager.disconnect()
|
||||
manager.disconnect()
|
||||
|
||||
expect((proc.kill as ReturnType<typeof mock>).mock.calls.length).toBe(1)
|
||||
})
|
||||
|
||||
test('processLine routes SDK messages to onMessage', async () => {
|
||||
const sdkMessage = JSON.stringify({
|
||||
type: 'assistant',
|
||||
message: { role: 'assistant', content: 'hello' },
|
||||
})
|
||||
|
||||
const { proc, writeToStdout, simulateExit } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
writeToStdout(sdkMessage)
|
||||
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
simulateExit(0)
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
|
||||
expect(opts.state.messages.length).toBe(1)
|
||||
expect((opts.state.messages[0] as Record<string, unknown>).type).toBe(
|
||||
'assistant',
|
||||
)
|
||||
})
|
||||
|
||||
test('processLine filters noise types', async () => {
|
||||
const noiseTypes = [
|
||||
'control_response',
|
||||
'keep_alive',
|
||||
'control_cancel_request',
|
||||
'streamlined_text',
|
||||
'streamlined_tool_use_summary',
|
||||
]
|
||||
|
||||
const { proc, writeToStdout, simulateExit } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
|
||||
for (const type of noiseTypes) {
|
||||
writeToStdout(JSON.stringify({ type }))
|
||||
}
|
||||
writeToStdout(
|
||||
JSON.stringify({ type: 'system', subtype: 'post_turn_summary' }),
|
||||
)
|
||||
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
simulateExit(0)
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
|
||||
expect(opts.state.messages.length).toBe(0)
|
||||
})
|
||||
|
||||
test('processLine routes control_request to onPermissionRequest', async () => {
|
||||
const controlRequest = JSON.stringify({
|
||||
type: 'control_request',
|
||||
request_id: 'req-123',
|
||||
request: {
|
||||
subtype: 'can_use_tool',
|
||||
tool_name: 'Bash',
|
||||
tool_use_id: 'tool-456',
|
||||
input: { command: 'ls' },
|
||||
},
|
||||
})
|
||||
|
||||
const { proc, writeToStdout, simulateExit } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
writeToStdout(controlRequest)
|
||||
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
simulateExit(0)
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
|
||||
expect(opts.state.permissionRequests.length).toBe(1)
|
||||
expect(opts.state.permissionRequests[0]!.requestId).toBe('req-123')
|
||||
})
|
||||
|
||||
test('sendMessage writes NDJSON to stdin', async () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
const result = await manager.sendMessage('hello world')
|
||||
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
test('sendInterrupt writes interrupt control request', () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
manager.sendInterrupt()
|
||||
|
||||
const stdin = proc.stdin as unknown as { write: ReturnType<typeof mock> }
|
||||
expect(stdin.write).toBeDefined()
|
||||
})
|
||||
|
||||
test('respondToPermissionRequest sends allow response', () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
manager.respondToPermissionRequest('req-123', {
|
||||
behavior: 'allow',
|
||||
updatedInput: { command: 'ls -la' },
|
||||
})
|
||||
})
|
||||
|
||||
test('respondToPermissionRequest sends deny response', () => {
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
manager.respondToPermissionRequest('req-123', {
|
||||
behavior: 'deny',
|
||||
message: 'User denied',
|
||||
})
|
||||
})
|
||||
|
||||
test('process exit without reconnect calls onDisconnected', async () => {
|
||||
const { proc, simulateExit } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
simulateExit(1)
|
||||
|
||||
await new Promise(r => setTimeout(r, 100))
|
||||
|
||||
expect(opts.state.disconnectedCount).toBe(1)
|
||||
expect(manager.isConnected()).toBe(false)
|
||||
})
|
||||
|
||||
test('user disconnect does not trigger reconnect', async () => {
|
||||
let reconnectCalled = false
|
||||
const { proc } = createMockSubprocess()
|
||||
const opts = createMockOptions({
|
||||
reconnect: async () => {
|
||||
reconnectCalled = true
|
||||
return createMockSubprocess().proc
|
||||
},
|
||||
maxReconnectAttempts: 3,
|
||||
})
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
manager.disconnect()
|
||||
|
||||
await new Promise(r => setTimeout(r, 200))
|
||||
|
||||
expect(reconnectCalled).toBe(false)
|
||||
expect(opts.state.reconnectingCalls.length).toBe(0)
|
||||
})
|
||||
|
||||
test('invalid JSON lines are silently skipped', async () => {
|
||||
const { proc, writeToStdout, simulateExit } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
writeToStdout('not valid json')
|
||||
writeToStdout('{also: broken')
|
||||
writeToStdout(
|
||||
JSON.stringify({ type: 'assistant', message: { role: 'assistant' } }),
|
||||
)
|
||||
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
simulateExit(0)
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
|
||||
expect(opts.state.messages.length).toBe(1)
|
||||
expect(opts.state.errors.length).toBe(0)
|
||||
})
|
||||
|
||||
test('non-StdoutMessage objects are skipped', async () => {
|
||||
const { proc, writeToStdout, simulateExit } = createMockSubprocess()
|
||||
const opts = createMockOptions()
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
writeToStdout(JSON.stringify({ noTypeField: true }))
|
||||
writeToStdout(JSON.stringify([1, 2, 3]))
|
||||
writeToStdout(JSON.stringify('string'))
|
||||
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
simulateExit(0)
|
||||
await new Promise(r => setTimeout(r, 50))
|
||||
|
||||
expect(opts.state.messages.length).toBe(0)
|
||||
})
|
||||
|
||||
test('process exit with reconnect factory attempts reconnection', async () => {
|
||||
const { proc: proc1, simulateExit } = createMockSubprocess()
|
||||
const { proc: proc2 } = createMockSubprocess()
|
||||
|
||||
const opts = createMockOptions({
|
||||
reconnect: mock(async () => proc2),
|
||||
maxReconnectAttempts: 3,
|
||||
})
|
||||
const manager = new SSHSessionManagerImpl(proc1, opts)
|
||||
|
||||
manager.connect()
|
||||
simulateExit(1)
|
||||
|
||||
await new Promise(r => setTimeout(r, 3000))
|
||||
|
||||
expect(opts.state.reconnectingCalls.length).toBeGreaterThanOrEqual(1)
|
||||
expect(opts.state.reconnectingCalls[0]!.attempt).toBe(1)
|
||||
expect(opts.state.reconnectingCalls[0]!.max).toBe(3)
|
||||
})
|
||||
|
||||
test('reconnect failure exhausts attempts then disconnects', async () => {
|
||||
const { proc, simulateExit } = createMockSubprocess()
|
||||
|
||||
const opts = createMockOptions({
|
||||
reconnect: mock(async () => {
|
||||
throw new Error('SSH connection refused')
|
||||
}),
|
||||
maxReconnectAttempts: 2,
|
||||
})
|
||||
const manager = new SSHSessionManagerImpl(proc, opts)
|
||||
|
||||
manager.connect()
|
||||
simulateExit(1)
|
||||
|
||||
await new Promise(r => setTimeout(r, 12000))
|
||||
|
||||
expect(opts.state.reconnectingCalls.length).toBe(2)
|
||||
expect(opts.state.disconnectedCount).toBe(1)
|
||||
expect(manager.isConnected()).toBe(false)
|
||||
}, 15000)
|
||||
})
|
||||
@@ -1,10 +1,21 @@
|
||||
// Auto-generated stub — replace with real implementation
|
||||
import type { Subprocess } from 'bun'
|
||||
import type { SSHSessionManager, SSHSessionManagerOptions } from './SSHSessionManager.js'
|
||||
import { SSHSessionManagerImpl } from './SSHSessionManager.js'
|
||||
import type {
|
||||
SSHSessionManager,
|
||||
SSHSessionManagerOptions,
|
||||
} from './SSHSessionManager.js'
|
||||
import { createAuthProxy } from './SSHAuthProxy.js'
|
||||
export type { SSHAuthProxy } from './SSHAuthProxy.js'
|
||||
import type { SSHAuthProxy } from './SSHAuthProxy.js'
|
||||
import { probeRemote } from './SSHProbe.js'
|
||||
import { deployBinary } from './SSHDeploy.js'
|
||||
import { buildCliLaunch } from '../utils/cliLaunch.js'
|
||||
import { logForDebugging } from '../utils/debug.js'
|
||||
import { jsonParse } from '../utils/slowOperations.js'
|
||||
import { randomUUID } from 'crypto'
|
||||
|
||||
export interface SSHAuthProxy {
|
||||
stop(): void
|
||||
}
|
||||
const INIT_TIMEOUT_MS = 30_000
|
||||
const STDERR_TAIL_LINES = 20
|
||||
|
||||
export interface SSHSession {
|
||||
remoteCwd: string
|
||||
@@ -21,9 +32,419 @@ export class SSHSessionError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
export const createSSHSession: (...args: unknown[]) => Promise<SSHSession> = (async () => {
|
||||
throw new SSHSessionError('SSH sessions are not supported in this build')
|
||||
});
|
||||
export const createLocalSSHSession: (...args: unknown[]) => Promise<SSHSession> = (async () => {
|
||||
throw new SSHSessionError('Local SSH sessions are not supported in this build')
|
||||
});
|
||||
export async function createSSHSession(
|
||||
config: {
|
||||
host: string
|
||||
cwd?: string
|
||||
localVersion: string
|
||||
permissionMode?: string
|
||||
dangerouslySkipPermissions?: boolean
|
||||
extraCliArgs: string[]
|
||||
remoteBin?: string
|
||||
},
|
||||
callbacks?: {
|
||||
onProgress?: (msg: string) => void
|
||||
},
|
||||
): Promise<SSHSession> {
|
||||
const { host, localVersion, extraCliArgs, remoteBin } = config
|
||||
const onProgress = callbacks?.onProgress
|
||||
|
||||
let remoteBinaryPath: string
|
||||
let defaultCwd = '/'
|
||||
|
||||
if (remoteBin) {
|
||||
onProgress?.('Using custom remote binary, skipping probe/deploy…')
|
||||
remoteBinaryPath = remoteBin
|
||||
logForDebugging(`[SSH] custom remoteBin: ${remoteBin}`)
|
||||
// Quick SSH to get remote home directory for default CWD
|
||||
try {
|
||||
const pwdProc = Bun.spawn(
|
||||
['ssh', '-o', 'BatchMode=yes', '-o', 'ConnectTimeout=5', host, 'pwd'],
|
||||
{
|
||||
stdin: 'ignore',
|
||||
stdout: 'pipe',
|
||||
stderr: 'ignore',
|
||||
},
|
||||
)
|
||||
await pwdProc.exited
|
||||
const pwd = (await new Response(pwdProc.stdout).text()).trim()
|
||||
if (pwd.startsWith('/')) defaultCwd = pwd
|
||||
} catch {
|
||||
/* use fallback */
|
||||
}
|
||||
} else {
|
||||
// 1. Probe remote host
|
||||
const probe = await probeRemote(host, onProgress)
|
||||
logForDebugging(`[SSH] probe result: ${JSON.stringify(probe)}`)
|
||||
defaultCwd = probe.defaultCwd
|
||||
|
||||
// 2. Deploy if binary missing or version mismatch
|
||||
remoteBinaryPath = probe.binaryPath ?? '~/.local/bin/claude'
|
||||
if (!probe.hasBinary || probe.remoteVersion !== localVersion) {
|
||||
onProgress?.(
|
||||
probe.hasBinary
|
||||
? `Updating remote binary (${probe.remoteVersion} → ${localVersion})…`
|
||||
: 'Deploying binary to remote…',
|
||||
)
|
||||
remoteBinaryPath = await deployBinary({
|
||||
host,
|
||||
remotePlatform: probe.remotePlatform,
|
||||
remoteArch: probe.remoteArch,
|
||||
localVersion,
|
||||
onProgress,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Start local auth proxy
|
||||
const { proxy, localAddress, authEnv } = await createAuthProxy()
|
||||
logForDebugging(`[SSH] auth proxy listening on ${localAddress}`)
|
||||
|
||||
// 4. Build SSH command with -R reverse forward and remote CLI
|
||||
const remoteSocketId = randomUUID().slice(0, 8)
|
||||
const isWindows = process.platform === 'win32'
|
||||
|
||||
const remoteCli: string[] = []
|
||||
for (const [k, v] of Object.entries(authEnv)) {
|
||||
remoteCli.push(`${k}=${v}`)
|
||||
}
|
||||
remoteCli.push(
|
||||
remoteBinaryPath,
|
||||
'--output-format',
|
||||
'stream-json',
|
||||
'--input-format',
|
||||
'stream-json',
|
||||
'--verbose',
|
||||
'-p',
|
||||
)
|
||||
if (config.cwd) remoteCli.push('--cwd', config.cwd)
|
||||
if (config.permissionMode)
|
||||
remoteCli.push('--permission-mode', config.permissionMode)
|
||||
if (config.dangerouslySkipPermissions)
|
||||
remoteCli.push('--dangerously-skip-permissions')
|
||||
remoteCli.push(...extraCliArgs)
|
||||
|
||||
const sshArgs = ['ssh']
|
||||
|
||||
if (!isWindows) {
|
||||
const remoteSocket = `/tmp/claude-ssh-auth-${remoteSocketId}.sock`
|
||||
sshArgs.push('-R', `${remoteSocket}:${localAddress}`)
|
||||
sshArgs.push('-o', 'StreamLocalBindUnlink=yes')
|
||||
// Override auth env to use the remote socket path
|
||||
const idx = remoteCli.indexOf(
|
||||
`ANTHROPIC_AUTH_SOCKET=${authEnv.ANTHROPIC_AUTH_SOCKET}`,
|
||||
)
|
||||
if (idx !== -1) {
|
||||
remoteCli[idx] = `ANTHROPIC_AUTH_SOCKET=${remoteSocket}`
|
||||
}
|
||||
} else {
|
||||
// Windows: TCP reverse forward
|
||||
const localPort = localAddress.split(':')[1]
|
||||
const remotePort = 10000 + Math.floor(Math.random() * 50000)
|
||||
sshArgs.push('-R', `${remotePort}:127.0.0.1:${localPort}`)
|
||||
// Override auth env to use remote TCP address
|
||||
const baseIdx = remoteCli.findIndex(s =>
|
||||
s.startsWith('ANTHROPIC_BASE_URL='),
|
||||
)
|
||||
if (baseIdx !== -1) {
|
||||
remoteCli[baseIdx] = `ANTHROPIC_BASE_URL=http://127.0.0.1:${remotePort}`
|
||||
}
|
||||
}
|
||||
|
||||
sshArgs.push(host, remoteCli.join(' '))
|
||||
|
||||
onProgress?.('Starting remote session…')
|
||||
logForDebugging(`[SSH] spawning: ${sshArgs.join(' ')}`)
|
||||
|
||||
let proc: Subprocess
|
||||
try {
|
||||
proc = Bun.spawn(sshArgs, {
|
||||
stdin: 'pipe',
|
||||
stdout: 'pipe',
|
||||
stderr: 'pipe',
|
||||
})
|
||||
} catch (err) {
|
||||
proxy.stop()
|
||||
throw new SSHSessionError(
|
||||
`Failed to spawn SSH process: ${err instanceof Error ? err.message : String(err)}`,
|
||||
)
|
||||
}
|
||||
|
||||
const stderrChunks: string[] = []
|
||||
collectStderr(proc, stderrChunks)
|
||||
|
||||
let remoteCwd: string
|
||||
if (remoteBin) {
|
||||
// Custom binary mode: the remote CLI in print+stream-json mode emits
|
||||
// init only after receiving the first user message (QueryEngine yield).
|
||||
// Waiting for init here would deadlock. Instead, verify the process
|
||||
// is alive and use the configured or probed CWD.
|
||||
const earlyExit = await Promise.race([
|
||||
proc.exited.then(code => code),
|
||||
new Promise<null>(r => setTimeout(() => r(null), 3_000)),
|
||||
])
|
||||
if (earlyExit !== null) {
|
||||
proxy.stop()
|
||||
const tail = stderrChunks.join('').trim()
|
||||
throw new SSHSessionError(
|
||||
`Remote process exited immediately (code ${earlyExit})${tail ? `: ${tail}` : ''}`,
|
||||
)
|
||||
}
|
||||
remoteCwd = config.cwd || defaultCwd || '/'
|
||||
} else {
|
||||
try {
|
||||
remoteCwd = await waitForInit(proc, config.cwd || defaultCwd)
|
||||
} catch (err) {
|
||||
proxy.stop()
|
||||
proc.kill()
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
logForDebugging(`[SSH] remote session initialized, remoteCwd=${remoteCwd}`)
|
||||
|
||||
let currentProc = proc
|
||||
|
||||
const reconnect = async (): Promise<Subprocess> => {
|
||||
logForDebugging('[SSH] reconnect: re-spawning SSH process with --continue')
|
||||
const reconnectArgs = [...sshArgs]
|
||||
const cmdIdx = reconnectArgs.length - 1
|
||||
const existingCmd = reconnectArgs[cmdIdx]!
|
||||
if (!existingCmd.includes('--continue')) {
|
||||
reconnectArgs[cmdIdx] = existingCmd.replace(
|
||||
/ -p(?:\s|$)/,
|
||||
' -p --continue ',
|
||||
)
|
||||
}
|
||||
|
||||
const newProc = Bun.spawn(reconnectArgs, {
|
||||
stdin: 'pipe',
|
||||
stdout: 'pipe',
|
||||
stderr: 'pipe',
|
||||
})
|
||||
|
||||
const newStderrChunks: string[] = []
|
||||
collectStderr(newProc, newStderrChunks)
|
||||
|
||||
await waitForInit(newProc, remoteCwd)
|
||||
currentProc = newProc
|
||||
stderrChunks.length = 0
|
||||
stderrChunks.push(...newStderrChunks)
|
||||
|
||||
return newProc
|
||||
}
|
||||
|
||||
return {
|
||||
remoteCwd,
|
||||
get proc() {
|
||||
return currentProc
|
||||
},
|
||||
proxy,
|
||||
createManager(options: SSHSessionManagerOptions): SSHSessionManager {
|
||||
return new SSHSessionManagerImpl(currentProc, {
|
||||
...options,
|
||||
reconnect,
|
||||
})
|
||||
},
|
||||
getStderrTail(): string {
|
||||
return stderrChunks.slice(-STDERR_TAIL_LINES).join('')
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
export async function createLocalSSHSession(config: {
|
||||
cwd?: string
|
||||
permissionMode?: string
|
||||
dangerouslySkipPermissions?: boolean
|
||||
}): Promise<SSHSession> {
|
||||
const { proxy, authEnv } = await createAuthProxy()
|
||||
|
||||
const cliArgs: string[] = [
|
||||
'--output-format',
|
||||
'stream-json',
|
||||
'--input-format',
|
||||
'stream-json',
|
||||
'-p',
|
||||
]
|
||||
if (config.cwd) {
|
||||
cliArgs.push('--cwd', config.cwd)
|
||||
}
|
||||
if (config.permissionMode) {
|
||||
cliArgs.push('--permission-mode', config.permissionMode)
|
||||
}
|
||||
if (config.dangerouslySkipPermissions) {
|
||||
cliArgs.push('--dangerously-skip-permissions')
|
||||
}
|
||||
|
||||
const spec = buildCliLaunch(cliArgs)
|
||||
|
||||
let proc: Subprocess
|
||||
try {
|
||||
proc = Bun.spawn([spec.execPath, ...spec.args], {
|
||||
stdin: 'pipe',
|
||||
stdout: 'pipe',
|
||||
stderr: 'pipe',
|
||||
env: { ...spec.env, ...authEnv },
|
||||
})
|
||||
} catch (err) {
|
||||
proxy.stop()
|
||||
throw new SSHSessionError(
|
||||
`Failed to spawn local CLI process: ${err instanceof Error ? err.message : String(err)}`,
|
||||
)
|
||||
}
|
||||
|
||||
logForDebugging('[SSH] local session spawned, waiting for init message...')
|
||||
|
||||
const stderrChunks: string[] = []
|
||||
collectStderr(proc, stderrChunks)
|
||||
|
||||
let remoteCwd: string
|
||||
try {
|
||||
remoteCwd = await waitForInit(proc, config.cwd)
|
||||
} catch (err) {
|
||||
proxy.stop()
|
||||
proc.kill()
|
||||
throw err
|
||||
}
|
||||
|
||||
logForDebugging(`[SSH] local session initialized, remoteCwd=${remoteCwd}`)
|
||||
|
||||
let currentProc = proc
|
||||
|
||||
const reconnect = async (): Promise<Subprocess> => {
|
||||
logForDebugging('[SSH] local reconnect: re-spawning CLI with --continue')
|
||||
const reconnectCliArgs = [...cliArgs]
|
||||
if (!reconnectCliArgs.includes('--continue')) {
|
||||
reconnectCliArgs.push('--continue')
|
||||
}
|
||||
|
||||
const reconnectSpec = buildCliLaunch(reconnectCliArgs)
|
||||
const newProc = Bun.spawn([reconnectSpec.execPath, ...reconnectSpec.args], {
|
||||
stdin: 'pipe',
|
||||
stdout: 'pipe',
|
||||
stderr: 'pipe',
|
||||
env: { ...reconnectSpec.env, ...authEnv },
|
||||
})
|
||||
|
||||
const newStderrChunks: string[] = []
|
||||
collectStderr(newProc, newStderrChunks)
|
||||
|
||||
await waitForInit(newProc, remoteCwd)
|
||||
currentProc = newProc
|
||||
stderrChunks.length = 0
|
||||
stderrChunks.push(...newStderrChunks)
|
||||
|
||||
return newProc
|
||||
}
|
||||
|
||||
return {
|
||||
remoteCwd,
|
||||
get proc() {
|
||||
return currentProc
|
||||
},
|
||||
proxy,
|
||||
createManager(options: SSHSessionManagerOptions): SSHSessionManager {
|
||||
return new SSHSessionManagerImpl(currentProc, {
|
||||
...options,
|
||||
reconnect,
|
||||
})
|
||||
},
|
||||
getStderrTail(): string {
|
||||
return stderrChunks.slice(-STDERR_TAIL_LINES).join('')
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForInit(
|
||||
proc: Subprocess,
|
||||
fallbackCwd?: string,
|
||||
): Promise<string> {
|
||||
const stdout = proc.stdout
|
||||
if (!stdout) {
|
||||
throw new SSHSessionError('Child process stdout is not readable')
|
||||
}
|
||||
|
||||
const reader = (stdout as ReadableStream<Uint8Array>).getReader()
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
const deadline = Date.now() + INIT_TIMEOUT_MS
|
||||
|
||||
try {
|
||||
while (Date.now() < deadline) {
|
||||
const remaining = deadline - Date.now()
|
||||
const result = await Promise.race([
|
||||
reader.read(),
|
||||
new Promise<{ done: true; value: undefined }>((_, reject) =>
|
||||
setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new SSHSessionError(
|
||||
'Remote CLI did not initialize within 30 seconds. Check SSH connectivity and remote binary.',
|
||||
),
|
||||
),
|
||||
remaining,
|
||||
),
|
||||
),
|
||||
])
|
||||
|
||||
if (result.done) {
|
||||
throw new SSHSessionError(
|
||||
'Child process exited before sending init message',
|
||||
)
|
||||
}
|
||||
|
||||
buffer += decoder.decode(result.value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() ?? ''
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim()
|
||||
if (!trimmed) continue
|
||||
try {
|
||||
const msg = jsonParse(trimmed) as Record<string, unknown>
|
||||
if (msg.type === 'system' && msg.subtype === 'init') {
|
||||
reader.releaseLock()
|
||||
return (msg.cwd as string) || fallbackCwd || process.cwd()
|
||||
}
|
||||
} catch {
|
||||
// not valid JSON — skip
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
reader.releaseLock()
|
||||
throw err instanceof SSHSessionError
|
||||
? err
|
||||
: new SSHSessionError(
|
||||
`Error reading init message: ${err instanceof Error ? err.message : String(err)}`,
|
||||
)
|
||||
}
|
||||
|
||||
reader.releaseLock()
|
||||
throw new SSHSessionError(
|
||||
'Remote CLI did not initialize within 30 seconds. Check SSH connectivity and remote binary.',
|
||||
)
|
||||
}
|
||||
|
||||
function collectStderr(proc: Subprocess, chunks: string[]): void {
|
||||
const stderr = proc.stderr
|
||||
if (!stderr) return
|
||||
|
||||
const reader = (stderr as ReadableStream<Uint8Array>).getReader()
|
||||
const decoder = new TextDecoder()
|
||||
|
||||
void (async () => {
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
chunks.push(decoder.decode(value, { stream: true }))
|
||||
if (chunks.length > STDERR_TAIL_LINES * 2) {
|
||||
chunks.splice(0, chunks.length - STDERR_TAIL_LINES)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// stderr closed — expected on process exit
|
||||
}
|
||||
})()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user