diff --git a/packages/builtin-tools/src/tools/WebSearchTool/WebSearchTool.ts b/packages/builtin-tools/src/tools/WebSearchTool/WebSearchTool.ts index 43a032585..478704fd1 100644 --- a/packages/builtin-tools/src/tools/WebSearchTool/WebSearchTool.ts +++ b/packages/builtin-tools/src/tools/WebSearchTool/WebSearchTool.ts @@ -23,6 +23,26 @@ const inputSchema = lazySchema(() => .array(z.string()) .optional() .describe('Never include search results from these domains'), + num_results: z + .number() + .optional() + .describe('Number of search results to return (default: 8)'), + livecrawl: z + .enum(['fallback', 'preferred']) + .optional() + .describe( + "Live crawl mode - 'fallback': use live crawling as backup if cached content unavailable, 'preferred': prioritize live crawling (default: 'fallback')", + ), + search_type: z + .enum(['auto', 'fast', 'deep']) + .optional() + .describe( + "Search type - 'auto': balanced search (default), 'fast': quick results, 'deep': comprehensive search", + ), + context_max_characters: z + .number() + .optional() + .describe('Maximum characters for context string optimized for LLMs (default: 10000)'), }), ) type InputSchema = ReturnType @@ -148,6 +168,10 @@ export const WebSearchTool = buildTool({ const adapterResults = await adapter.search(query, { allowedDomains: input.allowed_domains, blockedDomains: input.blocked_domains, + numResults: input.num_results, + livecrawl: input.livecrawl, + searchType: input.search_type, + contextMaxCharacters: input.context_max_characters, signal: context.abortController.signal, onProgress(progress) { if (onProgress) { diff --git a/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts b/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts index 14eb7fecb..4e5353d89 100644 --- a/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts +++ b/packages/builtin-tools/src/tools/WebSearchTool/__tests__/adapterFactory.test.ts @@ -52,10 +52,10 
import { afterEach, describe, expect, mock, test } from 'bun:test'

// Replacement for the project's error utilities so the adapter under test and
// the assertions below share one AbortError class.
const errorsModuleStub = () => ({
  AbortError: class AbortError extends Error {
    constructor(message?: string) {
      super(message)
      this.name = 'AbortError'
    }
  },
  isAbortError: (e: unknown) => e instanceof Error && (e as Error).name === 'AbortError',
})
mock.module('src/utils/errors.js', errorsModuleStub)
mock.module('src/utils/errors', errorsModuleStub)

describe('ExaSearchAdapter.search', () => {
  /** Imports the adapter lazily so the axios stub registered by the test is the one it sees. */
  const createAdapter = async () => {
    const adapterModule = await import('../adapters/exaAdapter')
    return new adapterModule.ExaSearchAdapter()
  }

  // Exa MCP returns SSE lines like: data: {"result":{"content":[{"type":"text","text":"..."}]}}
  const buildSseResponse = (text: string) => {
    const envelope = { result: { content: [{ type: 'text', text }] } }
    return `data: ${JSON.stringify(envelope)}\n`
  }

  /** Registers an axios module stub whose `post` is the supplied mock function. */
  const stubAxios = (post: (...args: any[]) => unknown) => {
    mock.module('axios', () => ({
      default: {
        post,
        isCancel: () => false,
      },
    }))
  }

  /** Builds a `post` mock that resolves with the given raw response body. */
  const respondingWith = (data: string) => mock(() => Promise.resolve({ data }))

  /** Joins fixture lines the same way the Exa text payload is laid out. */
  const lines = (...parts: string[]) => parts.join('\n')

  const STRUCTURED_TEXT = lines(
    'Title: Example Result 1',
    'URL: https://example.com/page1',
    'Content: This is the content snippet for page 1.',
    '',
    '---',
    '',
    'Title: Example Result 2',
    'URL: https://example.com/page2',
    'Content: This is the content snippet for page 2.',
  )

  afterEach(() => {
    mock.restore()
  })

  test('parses structured Title/URL/Content blocks from SSE response', async () => {
    stubAxios(respondingWith(buildSseResponse(STRUCTURED_TEXT)))

    const adapter = await createAdapter()
    const results = await adapter.search('test query', {})

    expect(results).toHaveLength(2)
    const [first, second] = results
    expect(first).toEqual({
      title: 'Example Result 1',
      url: 'https://example.com/page1',
      snippet: 'This is the content snippet for page 1.',
    })
    expect(second).toEqual({
      title: 'Example Result 2',
      url: 'https://example.com/page2',
      snippet: 'This is the content snippet for page 2.',
    })
  })

  test('parses markdown link fallback when no structured blocks', async () => {
    const markdownText = '- [React Docs](https://react.dev/docs)\n- [React Hooks](https://react.dev/hooks)'
    stubAxios(respondingWith(buildSseResponse(markdownText)))

    const adapter = await createAdapter()
    const results = await adapter.search('react', {})

    expect(results).toHaveLength(2)
    expect(results[0]).toEqual({
      title: 'React Docs',
      url: 'https://react.dev/docs',
      snippet: undefined,
    })
    expect(results[1].url).toBe('https://react.dev/hooks')
  })

  test('parses plain URL fallback', async () => {
    const plainUrlText = 'https://example.com/page1\nhttps://example.com/page2'
    stubAxios(respondingWith(buildSseResponse(plainUrlText)))

    const adapter = await createAdapter()
    const results = await adapter.search('test', {})

    expect(results).toHaveLength(2)
    expect(results[0].url).toBe('https://example.com/page1')
  })

  test('returns empty array for empty response', async () => {
    stubAxios(respondingWith(''))

    const adapter = await createAdapter()
    const results = await adapter.search('test', {})

    expect(results).toHaveLength(0)
  })

  test('parses direct JSON response (non-SSE fallback)', async () => {
    const jsonResponse = JSON.stringify({
      result: { content: [{ type: 'text', text: STRUCTURED_TEXT }] },
    })
    stubAxios(respondingWith(jsonResponse))

    const adapter = await createAdapter()
    const results = await adapter.search('test', {})

    expect(results).toHaveLength(2)
    expect(results[0].url).toBe('https://example.com/page1')
  })

  test('calls onProgress with query_update and search_results_received', async () => {
    stubAxios(respondingWith(buildSseResponse(STRUCTURED_TEXT)))

    const progressCalls: any[] = []
    const onProgress = (p: any) => progressCalls.push(p)

    const adapter = await createAdapter()
    await adapter.search('test', { onProgress })

    expect(progressCalls).toHaveLength(2)
    const [started, finished] = progressCalls
    expect(started).toEqual({ type: 'query_update', query: 'test' })
    expect(finished).toEqual({
      type: 'search_results_received',
      resultCount: 2,
      query: 'test',
    })
  })

  test('filters results by allowedDomains', async () => {
    const mixedText = lines(
      'Title: Allowed',
      'URL: https://allowed.com/a',
      '---',
      'Title: Blocked',
      'URL: https://blocked.com/b',
    )
    stubAxios(respondingWith(buildSseResponse(mixedText)))

    const adapter = await createAdapter()
    const results = await adapter.search('test', { allowedDomains: ['allowed.com'] })

    expect(results).toHaveLength(1)
    expect(results[0].url).toBe('https://allowed.com/a')
  })

  test('filters results by blockedDomains', async () => {
    const mixedText = lines(
      'Title: Good',
      'URL: https://good.com/a',
      '---',
      'Title: Spam',
      'URL: https://spam.com/b',
    )
    stubAxios(respondingWith(buildSseResponse(mixedText)))

    const adapter = await createAdapter()
    const results = await adapter.search('test', { blockedDomains: ['spam.com'] })

    expect(results).toHaveLength(1)
    expect(results[0].url).toBe('https://good.com/a')
  })

  test('filters subdomains with allowedDomains', async () => {
    const text = lines(
      'Title: Subdomain',
      'URL: https://docs.example.com/page',
      '---',
      'Title: Other',
      'URL: https://other.com/page',
    )
    stubAxios(respondingWith(buildSseResponse(text)))

    const adapter = await createAdapter()
    const results = await adapter.search('test', { allowedDomains: ['example.com'] })

    expect(results).toHaveLength(1)
    expect(results[0].url).toBe('https://docs.example.com/page')
  })

  test('throws AbortError when signal is already aborted', async () => {
    stubAxios(respondingWith(buildSseResponse(STRUCTURED_TEXT)))

    const adapter = await createAdapter()
    const controller = new AbortController()
    controller.abort()

    const { AbortError } = await import('src/utils/errors')
    const pending = adapter.search('test', { signal: controller.signal })
    await expect(pending).rejects.toThrow(AbortError)
  })

  test('re-throws non-abort axios errors', async () => {
    const networkError = new Error('Network error')
    stubAxios(mock(() => Promise.reject(networkError)))

    const adapter = await createAdapter()
    await expect(adapter.search('test', {})).rejects.toThrow('Network error')
  })

  test('sends correct MCP request payload to Exa endpoint', async () => {
    const axiosPost = respondingWith(buildSseResponse(STRUCTURED_TEXT))
    stubAxios(axiosPost)

    const adapter = await createAdapter()
    await adapter.search('hello world', {})

    expect(axiosPost.mock.calls).toHaveLength(1)
    const [url, body, config] = (axiosPost.mock.calls as any[][])[0]
    expect(url).toBe('https://mcp.exa.ai/mcp')
    expect(body.jsonrpc).toBe('2.0')
    expect(body.method).toBe('tools/call')
    expect(body.params.name).toBe('web_search_exa')

    const sentArguments = body.params.arguments
    expect(sentArguments.query).toBe('hello world')
    expect(sentArguments.type).toBe('auto')
    expect(sentArguments.numResults).toBe(8)
    expect(sentArguments.livecrawl).toBe('fallback')
    expect(sentArguments.contextMaxCharacters).toBe(10000)
    expect(config.headers.Accept).toBe('application/json, text/event-stream')
  })

  test('passes custom search options to MCP request', async () => {
    const axiosPost = respondingWith(buildSseResponse(STRUCTURED_TEXT))
    stubAxios(axiosPost)

    const adapter = await createAdapter()
    await adapter.search('test', {
      numResults: 15,
      livecrawl: 'preferred',
      searchType: 'deep',
      contextMaxCharacters: 20000,
    })

    const [, body] = (axiosPost.mock.calls as any[][])[0]
    const sentArguments = body.params.arguments
    expect(sentArguments.numResults).toBe(15)
    expect(sentArguments.livecrawl).toBe('preferred')
    expect(sentArguments.type).toBe('deep')
    expect(sentArguments.contextMaxCharacters).toBe(20000)
  })
})
const EXA_MCP_URL = 'https://mcp.exa.ai/mcp'
const FETCH_TIMEOUT_MS = 25_000

/** Longest snippet (in characters) kept for a single search result. */
const MAX_SNIPPET_CHARS = 300

/**
 * Captures everything after "Content:" up to the next field line or the true
 * end of the block. `$(?![\s\S])` is used instead of a bare `$` because the
 * `m` flag makes `$` match at every line end, which would cut multi-line
 * content down to its first line.
 */
const CONTENT_FIELD_RE = /^Content:\s*([\s\S]+?)(?=\n(?:Title:|URL:|---)|$(?![\s\S]))/m
/** Title must sit on the same line as the label so it can never swallow the following "URL:" line. */
const TITLE_FIELD_RE = /^Title:[ \t]*(.+)$/m
const URL_FIELD_RE = /^URL:\s*(https?:\/\/\S+)[ \t]*$/m
/** `[text](url)` or `[text](url "title")`; the URL capture stops at whitespace or `)`. */
const MARKDOWN_LINK_RE = /\[([^\]]+)\]\((https?:\/\/[^\s)]+)[^)]*\)/g
const BARE_URL_RE = /^https?:\/\/[^\s<>"\]]+/gm

/**
 * Search adapter that posts a JSON-RPC `tools/call` request for the
 * `web_search_exa` tool to the Exa MCP endpoint and converts the text payload
 * of the reply into `SearchResult` entries.
 *
 * Request defaults when the corresponding option is omitted: 8 results,
 * `livecrawl: 'fallback'`, `type: 'auto'`, `contextMaxCharacters: 10000`.
 */
export class ExaSearchAdapter implements WebSearchAdapter {
  /**
   * Runs one search.
   *
   * @param query - Search query forwarded verbatim to Exa.
   * @param options - Domain filters, abort signal, progress callback and Exa tuning options.
   * @returns Results that survive the allow/block domain filters.
   * @throws AbortError when `options.signal` is aborted before, during or right after the request.
   * @throws Any non-cancellation error raised by axios is re-thrown unchanged.
   */
  async search(
    query: string,
    options: SearchOptions,
  ): Promise<SearchResult[]> {
    const { signal, onProgress, allowedDomains, blockedDomains } = options

    if (signal?.aborted) {
      throw new AbortError()
    }

    onProgress?.({ type: 'query_update', query })

    let rawBody: unknown
    try {
      const response = await axios.post(
        EXA_MCP_URL,
        {
          jsonrpc: '2.0',
          id: 1,
          method: 'tools/call',
          params: {
            name: 'web_search_exa',
            arguments: {
              query,
              type: options.searchType ?? 'auto',
              numResults: options.numResults ?? 8,
              livecrawl: options.livecrawl ?? 'fallback',
              contextMaxCharacters: options.contextMaxCharacters ?? 10000,
            },
          },
        },
        {
          // The caller's signal is handed to axios directly. Wrapping it in a
          // second AbortController required an 'abort' listener on the caller's
          // signal that was never removed when the request finished normally.
          signal,
          timeout: FETCH_TIMEOUT_MS,
          headers: {
            'Content-Type': 'application/json',
            Accept: 'application/json, text/event-stream',
          },
          responseType: 'text',
        },
      )
      rawBody = response.data
    } catch (e) {
      if (axios.isCancel(e) || signal?.aborted) {
        throw new AbortError()
      }
      throw e
    }

    // The abort may have landed after the response arrived but before we resumed.
    if (signal?.aborted) {
      throw new AbortError()
    }

    // `responseType: 'text'` should yield a string; tolerate a pre-parsed body
    // instead of crashing inside the string-based parser.
    const responseText =
      typeof rawBody === 'string' ? rawBody : rawBody == null ? '' : JSON.stringify(rawBody)

    const results = this.parseResults(this.parseSse(responseText))
    const filteredResults = results.filter(r =>
      this.isUrlPermitted(r.url, allowedDomains, blockedDomains),
    )

    onProgress?.({
      type: 'search_results_received',
      resultCount: filteredResults.length,
      query,
    })

    return filteredResults
  }

  /**
   * Client-side domain filter. A domain entry matches the hostname itself and
   * any of its subdomains. Comparison is case-insensitive because `URL`
   * lower-cases hostnames while filter entries arrive as typed by the caller.
   * Unparseable or empty URLs are rejected.
   */
  private isUrlPermitted(
    url: string | undefined,
    allowedDomains?: string[],
    blockedDomains?: string[],
  ): boolean {
    if (!url) return false

    let hostname: string
    try {
      hostname = new URL(url).hostname.toLowerCase()
    } catch {
      return false
    }

    const matchesHost = (domain: string): boolean => {
      const normalized = domain.trim().toLowerCase()
      return normalized !== '' && (hostname === normalized || hostname.endsWith('.' + normalized))
    }

    if (allowedDomains?.length && !allowedDomains.some(matchesHost)) return false
    if (blockedDomains?.length && blockedDomains.some(matchesHost)) return false
    return true
  }

  /**
   * Extracts the tool's text payload from the response body.
   *
   * Each SSE `data:` line (with or without a space after the colon, LF or CRLF
   * terminated) is tried first; if none yields text, the whole body is tried
   * as a plain JSON document.
   *
   * @returns The first non-empty text item found, or `undefined`.
   */
  private parseSse(body: string): string | undefined {
    for (const line of body.split(/\r?\n/)) {
      if (!line.startsWith('data:')) continue
      const payload = line.slice('data:'.length).trim()
      if (!payload || payload === '[DONE]' || payload === 'null') continue

      const text = this.extractResultText(payload)
      if (text) return text
    }

    // Fallback: direct JSON response (non-SSE)
    return this.extractResultText(body)
  }

  /** Parses `json` and returns the first non-empty string `text` in `result.content`, if any. */
  private extractResultText(json: string): string | undefined {
    let parsed: unknown
    try {
      parsed = JSON.parse(json)
    } catch {
      return undefined // not JSON — the caller moves on to the next candidate
    }

    const content = (parsed as { result?: { content?: unknown } } | null)?.result?.content
    if (!Array.isArray(content)) return undefined

    for (const item of content) {
      const text = (item as { text?: unknown } | null)?.text
      if (typeof text === 'string' && text) return text
    }
    return undefined
  }

  /**
   * Converts the text payload into results using three strategies, each tried
   * only when the previous one produced nothing:
   *   1. "Title:/URL:/Content:" blocks separated by a line containing `---`
   *   2. markdown links
   *   3. bare URLs at the start of a line
   */
  private parseResults(text: string | undefined): SearchResult[] {
    if (!text) return []

    // Normalise line endings so the block separator and field patterns also
    // work on CRLF payloads.
    const normalized = text.replace(/\r\n?/g, '\n')
    const results: SearchResult[] = []

    for (const block of normalized.split(/\n---\n/)) {
      const url = block.match(URL_FIELD_RE)?.[1]
      if (!url) continue

      const title = block.match(TITLE_FIELD_RE)?.[1]?.trim()
      const content = block.match(CONTENT_FIELD_RE)?.[1]?.trim()

      results.push({
        title: title || url,
        url,
        snippet: content ? content.slice(0, MAX_SNIPPET_CHARS) : undefined,
      })
    }

    // Fallback: markdown links
    if (results.length === 0) {
      for (const match of normalized.matchAll(MARKDOWN_LINK_RE)) {
        const linkText = match[1]
        const linkUrl = match[2]
        if (linkText === undefined || linkUrl === undefined) continue
        results.push({ title: linkText.trim(), url: linkUrl })
      }
    }

    // Fallback: plain URLs
    if (results.length === 0) {
      for (const match of normalized.matchAll(BARE_URL_RE)) {
        results.push({ title: match[0], url: match[0] })
      }
    }

    return results
  }
}
Fallback → bing const adapterKey = - envAdapter === 'api' || envAdapter === 'bing' || envAdapter === 'brave' + envAdapter === 'api' || envAdapter === 'bing' || envAdapter === 'brave' || envAdapter === 'exa' ? envAdapter : isThirdPartyProvider() ? 'bing' @@ -56,9 +57,14 @@ export function createAdapter(): WebSearchAdapter { return cachedAdapter } if (adapterKey === 'brave') { - cachedAdapter = new BraveSearchAdapter() - cachedAdapterKey = 'brave' - return cachedAdapter + cachedAdapter = new BraveSearchAdapter() + cachedAdapterKey = 'brave' + return cachedAdapter + } + if (adapterKey === 'exa') { + cachedAdapter = new ExaSearchAdapter() + cachedAdapterKey = 'exa' + return cachedAdapter } cachedAdapter = new BingSearchAdapter() diff --git a/packages/builtin-tools/src/tools/WebSearchTool/adapters/types.ts b/packages/builtin-tools/src/tools/WebSearchTool/adapters/types.ts index cd04762fb..a867c5d92 100644 --- a/packages/builtin-tools/src/tools/WebSearchTool/adapters/types.ts +++ b/packages/builtin-tools/src/tools/WebSearchTool/adapters/types.ts @@ -9,6 +9,14 @@ export interface SearchOptions { blockedDomains?: string[] signal?: AbortSignal onProgress?: (progress: SearchProgress) => void + /** Number of search results to return (default: 8) */ + numResults?: number + /** Live crawl mode (default: 'fallback') */ + livecrawl?: 'fallback' | 'preferred' + /** Search type (default: 'auto') */ + searchType?: 'auto' | 'fast' | 'deep' + /** Maximum characters for context string (default: 10000) */ + contextMaxCharacters?: number } export interface SearchProgress {