feat: add Grok (xAI) API adapter with custom model mapping (#152)

Add xAI Grok as a new API provider. Reuses OpenAI-compatible message/tool
converters and stream adapter with Grok-specific client and model mapping.

Default model mapping:
  opus   → grok-4.20-reasoning
  sonnet → grok-3-mini-fast
  haiku  → grok-3-mini-fast

Users can customize mapping via:
  - GROK_MODEL env var (override all)
  - GROK_MODEL_MAP env var (JSON family map, e.g. {"opus":"grok-4"})
  - GROK_DEFAULT_{FAMILY}_MODEL env vars

Activation: CLAUDE_CODE_USE_GROK=1 or modelType: "grok" in settings.json
Also integrates with /provider command for runtime switching.
This commit is contained in:
uk0
2026-04-07 09:24:55 +08:00
committed by GitHub
parent dfa7aa1d29
commit 70baa6f7db
9 changed files with 488 additions and 6 deletions

View File

@@ -15,6 +15,8 @@ function getEnvVarForProvider(provider: string): string {
return 'CLAUDE_CODE_USE_FOUNDRY'
case 'gemini':
return 'CLAUDE_CODE_USE_GEMINI'
case 'grok':
return 'CLAUDE_CODE_USE_GROK'
default:
throw new Error(`Unknown provider: ${provider}`)
}
@@ -48,6 +50,7 @@ const call: LocalCommandCall = async (args, context) => {
delete process.env.CLAUDE_CODE_USE_FOUNDRY
delete process.env.CLAUDE_CODE_USE_OPENAI
delete process.env.CLAUDE_CODE_USE_GEMINI
delete process.env.CLAUDE_CODE_USE_GROK
return {
type: 'text',
value: 'API provider cleared (will use environment variables).',
@@ -59,6 +62,7 @@ const call: LocalCommandCall = async (args, context) => {
'anthropic',
'openai',
'gemini',
'grok',
'bedrock',
'vertex',
'foundry',
@@ -87,6 +91,19 @@ const call: LocalCommandCall = async (args, context) => {
}
}
// Check env vars when switching to grok (including settings.env)
if (arg === 'grok') {
const mergedEnv = getMergedEnv()
const hasKey = !!(mergedEnv.GROK_API_KEY || mergedEnv.XAI_API_KEY)
if (!hasKey) {
updateSettingsForSource('userSettings', { modelType: 'grok' })
return {
type: 'text',
value: `Switched to Grok provider.\nWarning: Missing env var: GROK_API_KEY (or XAI_API_KEY)\nConfigure it via settings.json env or set manually.`,
}
}
}
// Check env vars when switching to gemini (including settings.env)
if (arg === 'gemini') {
const mergedEnv = getMergedEnv()
@@ -104,13 +121,14 @@ const call: LocalCommandCall = async (args, context) => {
// Handle different provider types
// - 'anthropic', 'openai', 'gemini', 'grok' are stored in settings.json (persistent)
// - 'bedrock', 'vertex', 'foundry' are env-only (do NOT touch settings.json)
if (arg === 'anthropic' || arg === 'openai' || arg === 'gemini') {
if (arg === 'anthropic' || arg === 'openai' || arg === 'gemini' || arg === 'grok') {
// Clear any cloud provider env vars to avoid conflicts
delete process.env.CLAUDE_CODE_USE_BEDROCK
delete process.env.CLAUDE_CODE_USE_VERTEX
delete process.env.CLAUDE_CODE_USE_FOUNDRY
delete process.env.CLAUDE_CODE_USE_OPENAI
delete process.env.CLAUDE_CODE_USE_GEMINI
delete process.env.CLAUDE_CODE_USE_GROK
// Update settings.json
updateSettingsForSource('userSettings', { modelType: arg })
// Ensure settings.env gets applied to process.env
@@ -122,6 +140,7 @@ const call: LocalCommandCall = async (args, context) => {
delete process.env.OPENAI_API_KEY
delete process.env.OPENAI_BASE_URL
delete process.env.CLAUDE_CODE_USE_GEMINI
delete process.env.CLAUDE_CODE_USE_GROK
process.env[getEnvVarForProvider(arg)] = '1'
// Do not modify settings.json - cloud providers controlled solely by env vars
applyConfigEnvironmentVariables()
@@ -136,9 +155,9 @@ const provider = {
type: 'local',
name: 'provider',
description:
'Switch API provider (anthropic/openai/gemini/bedrock/vertex/foundry)',
'Switch API provider (anthropic/openai/gemini/grok/bedrock/vertex/foundry)',
aliases: ['api'],
argumentHint: '[anthropic|openai|gemini|bedrock|vertex|foundry|unset]',
argumentHint: '[anthropic|openai|gemini|grok|bedrock|vertex|foundry|unset]',
supportsNonInteractive: true,
load: () => Promise.resolve({ call }),
} satisfies Command

View File

@@ -1350,6 +1350,12 @@ async function* queryModel(
return
}
if (getAPIProvider() === 'grok') {
const { queryModelGrok } = await import('./grok/index.js')
yield* queryModelGrok(messagesForAPI, systemPrompt, filteredTools, signal, options)
return
}
// Instrumentation: Track message count after normalization
logEvent('tengu_api_after_normalize', {
postNormalizedMessageCount: messagesForAPI.length,

View File

@@ -0,0 +1,44 @@
import { describe, expect, test, beforeEach, afterEach } from 'bun:test'
import { getGrokClient, clearGrokClientCache } from '../client.js'
describe('getGrokClient', () => {
  // Snapshot of the environment taken once, before any test mutates it.
  const originalEnv = { ...process.env }

  /**
   * Restores the snapshot by mutating process.env in place.
   *
   * Reassigning `process.env = {...}` would swap the runtime's environment
   * object for a plain object, so the restore is done key by key instead:
   * variables added during a test are deleted, then snapshot values are
   * written back.
   */
  const restoreEnv = (): void => {
    for (const key of Object.keys(process.env)) {
      if (!(key in originalEnv)) delete process.env[key]
    }
    for (const [key, value] of Object.entries(originalEnv)) {
      if (value !== undefined) process.env[key] = value
    }
  }

  beforeEach(() => {
    // Each test starts with an empty client cache, a key, and no custom base URL.
    clearGrokClientCache()
    process.env.GROK_API_KEY = 'test-key'
    delete process.env.GROK_BASE_URL
  })

  afterEach(() => {
    clearGrokClientCache()
    restoreEnv()
  })

  test('creates client with default base URL', () => {
    const client = getGrokClient()
    expect(client).toBeDefined()
    expect(client.baseURL).toBe('https://api.x.ai/v1')
  })

  test('uses GROK_BASE_URL when set', () => {
    process.env.GROK_BASE_URL = 'https://custom.grok.api/v1'
    // The env var is read when the client is built, so drop any cached instance first.
    clearGrokClientCache()
    const client = getGrokClient()
    expect(client.baseURL).toBe('https://custom.grok.api/v1')
  })

  test('returns cached client on second call', () => {
    const client1 = getGrokClient()
    const client2 = getGrokClient()
    expect(client1).toBe(client2)
  })

  test('clearGrokClientCache resets cache', () => {
    const client1 = getGrokClient()
    clearGrokClientCache()
    process.env.GROK_BASE_URL = 'https://other.api/v1'
    const client2 = getGrokClient()
    expect(client1).not.toBe(client2)
  })
})

View File

@@ -0,0 +1,67 @@
import { describe, expect, test, beforeEach, afterEach } from 'bun:test'
import { resolveGrokModel } from '../modelMapping.js'
describe('resolveGrokModel', () => {
  // Snapshot of the environment taken once, before any test mutates it.
  const originalEnv = { ...process.env }

  // Every env var that resolveGrokModel consults; cleared before each test so
  // the ambient environment cannot influence the expected mappings.
  const managedVars = [
    'GROK_MODEL',
    'GROK_MODEL_MAP',
    'GROK_DEFAULT_SONNET_MODEL',
    'GROK_DEFAULT_OPUS_MODEL',
    'GROK_DEFAULT_HAIKU_MODEL',
    'ANTHROPIC_DEFAULT_SONNET_MODEL',
    'ANTHROPIC_DEFAULT_OPUS_MODEL',
    'ANTHROPIC_DEFAULT_HAIKU_MODEL',
  ]

  /**
   * Restores the snapshot by mutating process.env in place.
   *
   * Reassigning `process.env = {...}` would swap the runtime's environment
   * object for a plain object, so the restore is done key by key instead:
   * variables added during a test are deleted, then snapshot values are
   * written back.
   */
  const restoreEnv = (): void => {
    for (const key of Object.keys(process.env)) {
      if (!(key in originalEnv)) delete process.env[key]
    }
    for (const [key, value] of Object.entries(originalEnv)) {
      if (value !== undefined) process.env[key] = value
    }
  }

  beforeEach(() => {
    for (const name of managedVars) {
      delete process.env[name]
    }
  })

  afterEach(() => {
    restoreEnv()
  })

  test('GROK_MODEL env var takes highest priority', () => {
    process.env.GROK_MODEL = 'grok-custom'
    expect(resolveGrokModel('claude-sonnet-4-6')).toBe('grok-custom')
  })

  test('maps opus models to grok-4.20-reasoning', () => {
    expect(resolveGrokModel('claude-opus-4-6')).toBe('grok-4.20-reasoning')
  })

  test('maps sonnet models to grok-3-mini-fast', () => {
    expect(resolveGrokModel('claude-sonnet-4-6')).toBe('grok-3-mini-fast')
  })

  test('maps haiku models to grok-3-mini-fast', () => {
    expect(resolveGrokModel('claude-haiku-4-5-20251001')).toBe('grok-3-mini-fast')
  })

  test('GROK_MODEL_MAP overrides family mapping', () => {
    process.env.GROK_MODEL_MAP = '{"opus":"grok-4","sonnet":"grok-3","haiku":"grok-mini"}'
    expect(resolveGrokModel('claude-opus-4-6')).toBe('grok-4')
    expect(resolveGrokModel('claude-sonnet-4-6')).toBe('grok-3')
    expect(resolveGrokModel('claude-haiku-4-5-20251001')).toBe('grok-mini')
  })

  test('GROK_MODEL_MAP ignores invalid JSON', () => {
    process.env.GROK_MODEL_MAP = 'not-json'
    expect(resolveGrokModel('claude-opus-4-6')).toBe('grok-4.20-reasoning')
  })

  test('GROK_DEFAULT_{FAMILY}_MODEL overrides default map', () => {
    process.env.GROK_DEFAULT_OPUS_MODEL = 'grok-2-latest'
    expect(resolveGrokModel('claude-opus-4-6')).toBe('grok-2-latest')
  })

  test('passes through unknown model names', () => {
    expect(resolveGrokModel('some-unknown-model')).toBe('some-unknown-model')
  })

  test('strips [1m] suffix before lookup', () => {
    expect(resolveGrokModel('claude-sonnet-4-6[1m]')).toBe('grok-3-mini-fast')
  })

  test('falls back to family default for unlisted model', () => {
    expect(resolveGrokModel('claude-opus-99-20300101')).toBe('grok-4.20-reasoning')
  })
})

View File

@@ -0,0 +1,44 @@
import OpenAI from 'openai'
import { getProxyFetchOptions } from 'src/utils/proxy.js'
/**
 * Environment variables read when the Grok client is built:
 *
 * GROK_API_KEY (or XAI_API_KEY): API key for the xAI Grok endpoint.
 *   GROK_API_KEY wins when both are set; an empty string is used when neither is.
 * GROK_BASE_URL: Optional. Defaults to https://api.x.ai/v1.
 * API_TIMEOUT_MS: Optional. Request timeout in milliseconds; defaults to 600000.
 */
// Base URL used when GROK_BASE_URL is unset or empty.
const DEFAULT_BASE_URL = 'https://api.x.ai/v1'
// Memoized client: null until getGrokClient() stores one, and null again after clearGrokClientCache().
let cachedClient: OpenAI | null = null
/**
 * Returns an OpenAI-SDK client configured for the xAI Grok endpoint.
 *
 * The first client built without a `fetchOverride` is memoized and reused by
 * later calls that also have no override. A call that supplies `fetchOverride`
 * always gets a fresh, uncached client so the override is actually honoured
 * (previously a populated cache was returned and the override was dropped).
 *
 * Note that the memoized client keeps the settings it was created with:
 * `maxRetries` and the environment variables are only read when a client is
 * constructed. Call `clearGrokClientCache()` to force a rebuild.
 *
 * @param options.maxRetries    SDK retry count; defaults to 0.
 * @param options.fetchOverride Custom fetch implementation; bypasses the cache.
 * @param options.source        Accepted for call-site compatibility; not used here.
 * @returns The cached or newly constructed client.
 */
export function getGrokClient(options?: {
  maxRetries?: number
  fetchOverride?: typeof fetch
  source?: string
}): OpenAI {
  const fetchOverride = options?.fetchOverride

  // Only reuse the cached client when the caller did not ask for a custom fetch.
  if (cachedClient && !fetchOverride) return cachedClient

  const apiKey = process.env.GROK_API_KEY || process.env.XAI_API_KEY || ''
  const baseURL = process.env.GROK_BASE_URL || DEFAULT_BASE_URL

  // A non-numeric, zero, or negative API_TIMEOUT_MS falls back to 10 minutes
  // rather than handing NaN/0 to the SDK as the request timeout.
  const fallbackTimeoutMs = 600 * 1000
  const parsedTimeoutMs = parseInt(process.env.API_TIMEOUT_MS || '', 10)
  const timeout =
    Number.isFinite(parsedTimeoutMs) && parsedTimeoutMs > 0
      ? parsedTimeoutMs
      : fallbackTimeoutMs

  const client = new OpenAI({
    apiKey,
    baseURL,
    maxRetries: options?.maxRetries ?? 0,
    timeout,
    dangerouslyAllowBrowser: true,
    fetchOptions: getProxyFetchOptions({ forAnthropicAPI: false }) as RequestInit,
    ...(fetchOverride && { fetch: fetchOverride }),
  })

  // Clients carrying a caller-specific fetch are never stored in the cache.
  if (!fetchOverride) {
    cachedClient = client
  }
  return client
}
/**
 * Discards the memoized client so the next getGrokClient() call constructs a
 * new one and re-reads the environment variables.
 */
export function clearGrokClientCache(): void {
cachedClient = null
}

View File

@@ -0,0 +1,192 @@
import type { BetaToolUnion } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
import type { SystemPrompt } from '../../../utils/systemPromptType.js'
import type { Message, StreamEvent, SystemAPIErrorMessage, AssistantMessage } from '../../../types/message.js'
import type { Tools } from '../../../Tool.js'
import { getGrokClient } from './client.js'
import { anthropicMessagesToOpenAI } from '../openai/convertMessages.js'
import { anthropicToolsToOpenAI, anthropicToolChoiceToOpenAI } from '../openai/convertTools.js'
import { adaptOpenAIStreamToAnthropic } from '../openai/streamAdapter.js'
import { resolveGrokModel } from './modelMapping.js'
import { normalizeMessagesForAPI } from '../../../utils/messages.js'
import { toolToAPISchema } from '../../../utils/api.js'
import { logForDebugging } from '../../../utils/debug.js'
import { addToTotalSessionCost } from '../../../cost-tracker.js'
import { calculateUSDCost } from '../../../utils/modelCost.js'
import type { Options } from '../claude.js'
import { randomUUID } from 'crypto'
import {
createAssistantAPIErrorMessage,
normalizeContentFromAPI,
} from '../../../utils/messages.js'
/**
 * Grok (xAI) query path. Grok uses an OpenAI-compatible API, so we reuse
 * the OpenAI message/tool converters and stream adapter. Only the client
 * (different base URL + API key) and model mapping are Grok-specific.
 *
 * The generator yields, in stream order:
 *   - an `AssistantMessage` each time a content block finishes, and
 *   - a `stream_event` wrapper for every adapted event (the `message_start`
 *     wrapper additionally carries `ttftMs`).
 * Any thrown error is logged and converted into a single API error message;
 * the generator itself does not throw.
 *
 * @param messages     Conversation history to send.
 * @param systemPrompt System prompt passed to the message converter.
 * @param tools        Tools offered to the model.
 * @param signal       Abort signal forwarded to the HTTP request.
 * @param options      Query options (model, tool choice, temperature override, ...).
 */
export async function* queryModelGrok(
  messages: Message[],
  systemPrompt: SystemPrompt,
  tools: Tools,
  signal: AbortSignal,
  options: Options,
): AsyncGenerator<
  StreamEvent | AssistantMessage | SystemAPIErrorMessage,
  void
> {
  try {
    const grokModel = resolveGrokModel(options.model)
    const messagesForAPI = normalizeMessagesForAPI(messages, tools)

    const toolSchemas = await Promise.all(
      tools.map(tool =>
        toolToAPISchema(tool, {
          getToolPermissionContext: options.getToolPermissionContext,
          tools,
          agents: options.agents,
          allowedAgentTypes: options.allowedAgentTypes,
          model: options.model,
        }),
      ),
    )

    // Drop the 'advisor_20260301' and 'computer_20250124' tool types before conversion.
    // NOTE(review): presumably these have no OpenAI-style function equivalent — confirm.
    const standardTools = toolSchemas.filter(
      (t): t is BetaToolUnion & { type: string } => {
        const anyT = t as Record<string, unknown>
        return anyT.type !== 'advisor_20260301' && anyT.type !== 'computer_20250124'
      },
    )

    const openaiMessages = anthropicMessagesToOpenAI(messagesForAPI, systemPrompt)
    const openaiTools = anthropicToolsToOpenAI(standardTools)
    const openaiToolChoice = anthropicToolChoiceToOpenAI(options.toolChoice)

    const client = getGrokClient({
      maxRetries: 0,
      fetchOverride: options.fetchOverride,
      source: options.querySource,
    })

    logForDebugging(`[Grok] Calling model=${grokModel}, messages=${openaiMessages.length}, tools=${openaiTools.length}`)

    // Start the clock before the request is issued so that ttftMs includes the
    // time spent waiting for the response, not just the gap between the
    // response arriving and the first streamed event.
    const start = Date.now()

    const stream = await client.chat.completions.create(
      {
        model: grokModel,
        messages: openaiMessages,
        // tool_choice is only sent when there is at least one tool to choose from.
        ...(openaiTools.length > 0 && {
          tools: openaiTools,
          ...(openaiToolChoice && { tool_choice: openaiToolChoice }),
        }),
        stream: true,
        stream_options: { include_usage: true },
        ...(options.temperatureOverride !== undefined && {
          temperature: options.temperatureOverride,
        }),
      },
      {
        signal,
      },
    )

    const adaptedStream = adaptOpenAIStreamToAnthropic(stream, grokModel)

    // In-progress content blocks keyed by their stream index.
    const contentBlocks: Record<number, any> = {}
    let partialMessage: any = undefined
    let usage = {
      input_tokens: 0,
      output_tokens: 0,
      cache_creation_input_tokens: 0,
      cache_read_input_tokens: 0,
    }
    let ttftMs = 0

    for await (const event of adaptedStream) {
      switch (event.type) {
        case 'message_start': {
          partialMessage = (event as any).message
          ttftMs = Date.now() - start
          if ((event as any).message?.usage) {
            usage = { ...usage, ...((event as any).message.usage) }
          }
          break
        }
        case 'content_block_start': {
          const idx = (event as any).index
          const cb = (event as any).content_block
          // Seed the accumulator fields that the matching deltas append to.
          if (cb.type === 'tool_use') {
            contentBlocks[idx] = { ...cb, input: '' }
          } else if (cb.type === 'text') {
            contentBlocks[idx] = { ...cb, text: '' }
          } else if (cb.type === 'thinking') {
            contentBlocks[idx] = { ...cb, thinking: '', signature: '' }
          } else {
            contentBlocks[idx] = { ...cb }
          }
          break
        }
        case 'content_block_delta': {
          const idx = (event as any).index
          const delta = (event as any).delta
          const block = contentBlocks[idx]
          // A delta for a block we never saw start is ignored.
          if (!block) break
          if (delta.type === 'text_delta') {
            block.text = (block.text || '') + delta.text
          } else if (delta.type === 'input_json_delta') {
            // Tool input is accumulated as a raw JSON string.
            block.input = (block.input || '') + delta.partial_json
          } else if (delta.type === 'thinking_delta') {
            block.thinking = (block.thinking || '') + delta.thinking
          } else if (delta.type === 'signature_delta') {
            block.signature = delta.signature
          }
          break
        }
        case 'content_block_stop': {
          const idx = (event as any).index
          const block = contentBlocks[idx]
          if (!block || !partialMessage) break
          // Emit one assistant message per finished block, before the raw
          // stream event for the same stop is forwarded below.
          const m: AssistantMessage = {
            message: {
              ...partialMessage,
              content: normalizeContentFromAPI([block], tools, options.agentId),
            },
            requestId: undefined,
            type: 'assistant',
            uuid: randomUUID(),
            timestamp: new Date().toISOString(),
          }
          yield m
          break
        }
        case 'message_delta': {
          const deltaUsage = (event as any).usage
          if (deltaUsage) {
            usage = { ...usage, ...deltaUsage }
          }
          break
        }
        case 'message_stop':
          break
      }

      // Record cost once, at the end of the message, and only if any tokens were reported.
      // Pricing is looked up by the Grok model; the session total is attributed to options.model.
      if (event.type === 'message_stop' && usage.input_tokens + usage.output_tokens > 0) {
        const costUSD = calculateUSDCost(grokModel, usage as any)
        addToTotalSessionCost(costUSD, usage as any, options.model)
      }

      yield {
        type: 'stream_event',
        event,
        ...(event.type === 'message_start' ? { ttftMs } : undefined),
      } as StreamEvent
    }
  } catch (error) {
    const errorMessage = error instanceof Error ? error.message : String(error)
    logForDebugging(`[Grok] Error: ${errorMessage}`, { level: 'error' })
    yield createAssistantAPIErrorMessage({
      content: `API Error: ${errorMessage}`,
      apiError: 'api_error',
      error: error instanceof Error ? error : new Error(String(error)),
    })
  }
}

View File

@@ -0,0 +1,107 @@
/**
 * Built-in translation table from exact Anthropic model IDs to Grok model IDs.
 *
 * This table is consulted only after the environment-based overrides
 * (GROK_MODEL, the GROK_MODEL_MAP JSON family map such as
 * '{"opus":"grok-4","sonnet":"grok-3","haiku":"grok-3-mini-fast"}', and the
 * GROK_DEFAULT_{FAMILY}_MODEL / ANTHROPIC_DEFAULT_{FAMILY}_MODEL variables)
 * have failed to produce a model.
 */
const DEFAULT_MODEL_MAP: Record<string, string> = {
  // Opus family
  'claude-opus-4-20250514': 'grok-4.20-reasoning',
  'claude-opus-4-1-20250805': 'grok-4.20-reasoning',
  'claude-opus-4-5-20251101': 'grok-4.20-reasoning',
  'claude-opus-4-6': 'grok-4.20-reasoning',

  // Sonnet family
  'claude-3-5-sonnet-20241022': 'grok-3-mini-fast',
  'claude-3-7-sonnet-20250219': 'grok-3-mini-fast',
  'claude-sonnet-4-20250514': 'grok-3-mini-fast',
  'claude-sonnet-4-5-20250929': 'grok-3-mini-fast',
  'claude-sonnet-4-6': 'grok-3-mini-fast',

  // Haiku family
  'claude-3-5-haiku-20241022': 'grok-3-mini-fast',
  'claude-haiku-4-5-20251001': 'grok-3-mini-fast',
}

/**
 * Per-family fallback used when a model belongs to a known family but its
 * exact ID is not listed in DEFAULT_MODEL_MAP.
 */
const DEFAULT_FAMILY_MAP: Record<string, string> = {
  haiku: 'grok-3-mini-fast',
  sonnet: 'grok-3-mini-fast',
  opus: 'grok-4.20-reasoning',
}
/**
 * Classifies a model name into a Claude family by case-insensitive substring.
 *
 * @param model Model name to inspect.
 * @returns The first family whose name occurs in `model`, or null if none does.
 */
function getModelFamily(model: string): 'haiku' | 'sonnet' | 'opus' | null {
  // Probe order is significant: the first pattern that matches wins.
  const probes: Array<['haiku' | 'sonnet' | 'opus', RegExp]> = [
    ['haiku', /haiku/i],
    ['opus', /opus/i],
    ['sonnet', /sonnet/i],
  ]
  for (const [familyName, pattern] of probes) {
    if (pattern.test(model)) return familyName
  }
  return null
}
/**
 * Parse the user-provided model map from the GROK_MODEL_MAP env var.
 * Accepts JSON like: {"opus":"grok-4","sonnet":"grok-3","haiku":"grok-3-mini-fast"}
 *
 * Only entries whose value is a non-empty string are kept, so a malformed
 * entry such as {"opus": 123} or {"opus": null} can never be returned to a
 * caller as a model name.
 *
 * @returns The validated map (possibly empty), or null when the variable is
 *          unset/empty, is not valid JSON, or is not a JSON object.
 */
function getUserModelMap(): Record<string, string> | null {
  const raw = process.env.GROK_MODEL_MAP
  if (!raw) return null

  let parsed: unknown
  try {
    parsed = JSON.parse(raw)
  } catch {
    // Invalid JSON is deliberately treated the same as "no map configured".
    return null
  }

  if (parsed === null || typeof parsed !== 'object' || Array.isArray(parsed)) {
    return null
  }

  const validated: Record<string, string> = {}
  for (const [key, value] of Object.entries(parsed)) {
    if (typeof value === 'string' && value !== '') {
      validated[key] = value
    }
  }
  return validated
}
/**
 * Resolve the Grok model name for a given Anthropic model.
 *
 * Priority:
 * 1. GROK_MODEL env var (override all)
 * 2. GROK_MODEL_MAP env var — JSON family map (e.g. {"opus":"grok-4"})
 * 3. GROK_DEFAULT_{FAMILY}_MODEL env var (e.g. GROK_DEFAULT_OPUS_MODEL)
 * 4. ANTHROPIC_DEFAULT_{FAMILY}_MODEL env var (backward compat)
 * 5. DEFAULT_MODEL_MAP lookup
 * 6. Family-level default
 * 7. Pass through original model name
 *
 * A trailing "[1m]" suffix is stripped from the input before family detection
 * and table lookup, and is also absent from a passed-through name.
 *
 * @param anthropicModel Anthropic-style model name, optionally ending in "[1m]".
 * @returns The model name to send to the Grok endpoint.
 */
export function resolveGrokModel(anthropicModel: string): string {
  // 1. Global override
  const globalOverride = process.env.GROK_MODEL
  if (globalOverride) {
    return globalOverride
  }

  const cleanModel = anthropicModel.replace(/\[1m\]$/, '')
  const family = getModelFamily(cleanModel)

  if (family) {
    // 2. User-provided model map
    const userMapped = getUserModelMap()?.[family]
    if (userMapped) return userMapped

    // 3. Grok-specific family override
    const grokOverride = process.env[`GROK_DEFAULT_${family.toUpperCase()}_MODEL`]
    if (grokOverride) return grokOverride

    // 4. Anthropic env var (backward compat)
    const anthropicOverride = process.env[`ANTHROPIC_DEFAULT_${family.toUpperCase()}_MODEL`]
    if (anthropicOverride) return anthropicOverride
  }

  // 5. Exact model name lookup. Restricted to the table's own keys so that a
  //    model literally named e.g. "constructor" or "toString" does not pick up
  //    an inherited Object.prototype member and return a non-string.
  const exactMatch = Object.prototype.hasOwnProperty.call(DEFAULT_MODEL_MAP, cleanModel)
    ? DEFAULT_MODEL_MAP[cleanModel]
    : undefined
  if (exactMatch) return exactMatch

  // 6. Family-level default
  const familyDefault = family ? DEFAULT_FAMILY_MAP[family] : undefined
  if (familyDefault) return familyDefault

  // 7. Pass through
  return cleanModel
}

View File

@@ -9,11 +9,13 @@ export type APIProvider =
| 'foundry'
| 'openai'
| 'gemini'
| 'grok'
export function getAPIProvider(): APIProvider {
const modelType = getInitialSettings().modelType
if (modelType === 'openai') return 'openai'
if (modelType === 'gemini') return 'gemini'
if (modelType === 'grok') return 'grok'
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_BEDROCK)) return 'bedrock'
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_VERTEX)) return 'vertex'
@@ -21,6 +23,7 @@ export function getAPIProvider(): APIProvider {
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_OPENAI)) return 'openai'
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_GEMINI)) return 'gemini'
if (isEnvTruthy(process.env.CLAUDE_CODE_USE_GROK)) return 'grok'
return 'firstParty'
}

View File

@@ -373,11 +373,11 @@ export const SettingsSchema = lazySchema(() =>
.optional()
.describe('Tool usage permissions configuration'),
modelType: z
.enum(['anthropic', 'openai', 'gemini'])
.enum(['anthropic', 'openai', 'gemini', 'grok'])
.optional()
.describe(
'API provider type. "anthropic" uses the Anthropic API (default), "openai" uses the OpenAI Chat Completions API (/v1/chat/completions), and "gemini" uses the Gemini Generate Content API. ' +
'When set to "openai", configure OPENAI_API_KEY, OPENAI_BASE_URL, and OPENAI_MODEL in env. When set to "gemini", configure GEMINI_API_KEY, optional GEMINI_BASE_URL, and either GEMINI_MODEL or ANTHROPIC_DEFAULT_*_MODEL family env vars.',
'API provider type. "anthropic" uses the Anthropic API (default), "openai" uses the OpenAI Chat Completions API, "gemini" uses the Gemini API, and "grok" uses the xAI Grok API (OpenAI-compatible). ' +
'When set to "openai", configure OPENAI_API_KEY, OPENAI_BASE_URL, and OPENAI_MODEL. When set to "gemini", configure GEMINI_API_KEY and optional GEMINI_BASE_URL. When set to "grok", configure GROK_API_KEY (or XAI_API_KEY), optional GROK_BASE_URL, GROK_MODEL, and GROK_MODEL_MAP.',
),
model: z
.string()