feat: add Grok (xAI) API adapter with custom model mapping (#152)

Add xAI Grok as a new API provider. Reuses OpenAI-compatible message/tool
converters and stream adapter with Grok-specific client and model mapping.

Default model mapping:
  opus   → grok-4.20-reasoning
  sonnet → grok-3-mini-fast
  haiku  → grok-3-mini-fast

Users can customize mapping via:
  - GROK_MODEL env var (override all)
  - GROK_MODEL_MAP env var (JSON family map, e.g. {"opus":"grok-4"})
  - GROK_DEFAULT_{FAMILY}_MODEL env vars

Activation: CLAUDE_CODE_USE_GROK=1 or modelType: "grok" in settings.json
Also integrates with /provider command for runtime switching.
This commit is contained in:
uk0
2026-04-07 09:24:55 +08:00
committed by GitHub
parent dfa7aa1d29
commit 70baa6f7db
9 changed files with 488 additions and 6 deletions

View File

@@ -0,0 +1,192 @@
import type { BetaToolUnion } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
import type { SystemPrompt } from '../../../utils/systemPromptType.js'
import type { Message, StreamEvent, SystemAPIErrorMessage, AssistantMessage } from '../../../types/message.js'
import type { Tools } from '../../../Tool.js'
import { getGrokClient } from './client.js'
import { anthropicMessagesToOpenAI } from '../openai/convertMessages.js'
import { anthropicToolsToOpenAI, anthropicToolChoiceToOpenAI } from '../openai/convertTools.js'
import { adaptOpenAIStreamToAnthropic } from '../openai/streamAdapter.js'
import { resolveGrokModel } from './modelMapping.js'
import { normalizeMessagesForAPI } from '../../../utils/messages.js'
import { toolToAPISchema } from '../../../utils/api.js'
import { logForDebugging } from '../../../utils/debug.js'
import { addToTotalSessionCost } from '../../../cost-tracker.js'
import { calculateUSDCost } from '../../../utils/modelCost.js'
import type { Options } from '../claude.js'
import { randomUUID } from 'crypto'
import {
createAssistantAPIErrorMessage,
normalizeContentFromAPI,
} from '../../../utils/messages.js'
/**
 * Grok (xAI) query path. Grok uses an OpenAI-compatible API, so we reuse
 * the OpenAI message/tool converters and stream adapter. Only the client
 * (different base URL + API key) and model mapping are Grok-specific.
 *
 * What the generator yields, in order of arrival:
 *  - for every adapted stream event, a `{ type: 'stream_event', event }`
 *    wrapper (the wrapper for `message_start` additionally carries `ttftMs`);
 *  - on each `content_block_stop`, an `AssistantMessage` holding just that
 *    finished block, yielded BEFORE the wrapper for the same event;
 *  - if anything inside the request/stream path throws, a single assistant
 *    API error message. The error is logged and converted, never rethrown.
 *
 * @param messages - Conversation so far; normalized via `normalizeMessagesForAPI`.
 * @param systemPrompt - System prompt handed to the OpenAI message converter.
 * @param tools - Tools to expose; converted to API schemas, then to OpenAI tools.
 * @param signal - Abort signal forwarded to the HTTP request.
 * @param options - Query options (model, tool choice, temperature override, ...).
 */
export async function* queryModelGrok(
  messages: Message[],
  systemPrompt: SystemPrompt,
  tools: Tools,
  signal: AbortSignal,
  options: Options,
): AsyncGenerator<
  StreamEvent | AssistantMessage | SystemAPIErrorMessage,
  void
> {
  try {
    const grokModel = resolveGrokModel(options.model)
    const messagesForAPI = normalizeMessagesForAPI(messages, tools)

    const toolSchemas = await Promise.all(
      tools.map(tool =>
        toolToAPISchema(tool, {
          getToolPermissionContext: options.getToolPermissionContext,
          tools,
          agents: options.agents,
          allowedAgentTypes: options.allowedAgentTypes,
          model: options.model,
        }),
      ),
    )

    // Drop the advisor / computer-use schema types before converting to the
    // OpenAI tool format.
    // NOTE(review): presumably these server-side tool types have no
    // OpenAI-compatible representation — confirm against the converter.
    const standardTools = toolSchemas.filter(
      (t): t is BetaToolUnion & { type: string } => {
        const anyT = t as Record<string, unknown>
        return anyT.type !== 'advisor_20260301' && anyT.type !== 'computer_20250124'
      },
    )

    const openaiMessages = anthropicMessagesToOpenAI(messagesForAPI, systemPrompt)
    const openaiTools = anthropicToolsToOpenAI(standardTools)
    const openaiToolChoice = anthropicToolChoiceToOpenAI(options.toolChoice)

    const client = getGrokClient({
      maxRetries: 0,
      fetchOverride: options.fetchOverride,
      source: options.querySource,
    })

    logForDebugging(`[Grok] Calling model=${grokModel}, messages=${openaiMessages.length}, tools=${openaiTools.length}`)

    // Start the clock BEFORE issuing the request so that ttftMs includes the
    // request round-trip, not just the gap between the resolved `create()`
    // call and the first adapted event.
    const start = Date.now()

    const stream = await client.chat.completions.create(
      {
        model: grokModel,
        messages: openaiMessages,
        // tool_choice is only sent when there is at least one tool.
        ...(openaiTools.length > 0 && {
          tools: openaiTools,
          ...(openaiToolChoice && { tool_choice: openaiToolChoice }),
        }),
        stream: true,
        stream_options: { include_usage: true },
        ...(options.temperatureOverride !== undefined && {
          temperature: options.temperatureOverride,
        }),
      },
      {
        signal,
      },
    )

    const adaptedStream = adaptOpenAIStreamToAnthropic(stream, grokModel)

    // Content blocks under construction, keyed by the event's block index.
    const contentBlocks: Record<number, any> = {}
    // The `message` payload of `message_start`; used as the envelope for
    // every AssistantMessage yielded afterwards.
    let partialMessage: any = undefined
    let usage = {
      input_tokens: 0,
      output_tokens: 0,
      cache_creation_input_tokens: 0,
      cache_read_input_tokens: 0,
    }
    let ttftMs = 0

    for await (const event of adaptedStream) {
      switch (event.type) {
        case 'message_start': {
          partialMessage = (event as any).message
          ttftMs = Date.now() - start
          if ((event as any).message?.usage) {
            usage = { ...usage, ...((event as any).message.usage) }
          }
          break
        }
        case 'content_block_start': {
          const idx = (event as any).index
          const cb = (event as any).content_block
          // Reset the accumulating field(s) so deltas can be appended.
          if (cb.type === 'tool_use') {
            // Tool input is accumulated as a raw JSON string from
            // input_json_delta fragments and passed on in that form.
            contentBlocks[idx] = { ...cb, input: '' }
          } else if (cb.type === 'text') {
            contentBlocks[idx] = { ...cb, text: '' }
          } else if (cb.type === 'thinking') {
            contentBlocks[idx] = { ...cb, thinking: '', signature: '' }
          } else {
            contentBlocks[idx] = { ...cb }
          }
          break
        }
        case 'content_block_delta': {
          const idx = (event as any).index
          const delta = (event as any).delta
          const block = contentBlocks[idx]
          // Deltas for a block we never saw start are ignored (the event is
          // still forwarded as a stream_event below).
          if (!block) break
          if (delta.type === 'text_delta') {
            block.text = (block.text || '') + delta.text
          } else if (delta.type === 'input_json_delta') {
            block.input = (block.input || '') + delta.partial_json
          } else if (delta.type === 'thinking_delta') {
            block.thinking = (block.thinking || '') + delta.thinking
          } else if (delta.type === 'signature_delta') {
            // Signatures replace rather than append.
            block.signature = delta.signature
          }
          break
        }
        case 'content_block_stop': {
          const idx = (event as any).index
          const block = contentBlocks[idx]
          if (!block || !partialMessage) break
          // One AssistantMessage per finished block, each wrapped in the
          // message_start envelope with only that block as content.
          const m: AssistantMessage = {
            message: {
              ...partialMessage,
              content: normalizeContentFromAPI([block], tools, options.agentId),
            },
            requestId: undefined,
            type: 'assistant',
            uuid: randomUUID(),
            timestamp: new Date().toISOString(),
          }
          yield m
          break
        }
        case 'message_delta': {
          const deltaUsage = (event as any).usage
          if (deltaUsage) {
            usage = { ...usage, ...deltaUsage }
          }
          break
        }
        case 'message_stop':
          break
      }

      // Record cost once, at the end of the message, and only if any tokens
      // were reported. Cost is computed with the resolved Grok model name but
      // attributed to the originally requested `options.model`.
      if (event.type === 'message_stop' && usage.input_tokens + usage.output_tokens > 0) {
        const costUSD = calculateUSDCost(grokModel, usage as any)
        addToTotalSessionCost(costUSD, usage as any, options.model)
      }

      // Every adapted event is forwarded to the consumer, after any
      // AssistantMessage it produced.
      yield {
        type: 'stream_event',
        event,
        ...(event.type === 'message_start' ? { ttftMs } : undefined),
      } as StreamEvent
    }
  } catch (error) {
    // Any failure (setup, request, or mid-stream) is surfaced as a yielded
    // error message; the generator then completes normally.
    const errorMessage = error instanceof Error ? error.message : String(error)
    logForDebugging(`[Grok] Error: ${errorMessage}`, { level: 'error' })
    yield createAssistantAPIErrorMessage({
      content: `API Error: ${errorMessage}`,
      apiError: 'api_error',
      error: error instanceof Error ? error : new Error(String(error)),
    })
  }
}