Files
zn-ai/electron/gateway/handlers/chat.ts
DEV_DSW 4c61e93c3e Add unit tests for skill capabilities, skill planner, and UV setup
- Implement tests for random ID generation, ensuring preference for crypto.randomUUID.
- Create tests for runtime context capabilities, validating the injection of enabled skill capabilities.
- Add tests for skill capability parsing, including classification and command example extraction.
- Introduce tests for the skill planner, verifying tool call planning based on user requests and attachment requirements.
- Establish tests for UV setup, ensuring proper handling of Python installation scenarios and environment checks.
2026-04-24 17:02:59 +08:00

849 lines
23 KiB
TypeScript

import { createProvider } from '@electron/providers';
import type {
BaseProvider,
ProviderCapabilities,
GatewayChatContentBlock,
GatewayChatMessage,
} from '@electron/providers/BaseProvider';
import { DEFAULT_PROVIDER_CAPABILITIES } from '@electron/providers/BaseProvider';
import { providerApiService } from '@electron/service/provider-api-service';
import logManager from '@electron/service/logger';
import { normalizeAgentSessionKey } from '@runtime/lib/models';
import type {
ContentBlock,
RawMessage,
ToolCallPayload,
ToolStatus,
} from '@runtime/shared/chat-model';
import { appendTranscriptLine } from '@electron/utils/token-usage-writer';
import {
createChatToolRuntime,
createGatewayToolDefinitions,
mapSkillCapabilitiesToRegistryInputs,
} from '../chat-tooling';
import { createRandomId } from '../random-id';
import { buildRuntimeContextMessages } from '../runtime-context';
import { sessionStore } from '../session-store';
import { getEnabledSkillCapabilities } from '../skill-capability-registry';
import { planToolCall } from '../skill-planner';
import { createToolRegistry } from '../tool-registry';
import type { GatewayEvent, GatewayRpcParams, GatewayRpcReturns } from '../types';
import type { ToolRuntime } from '../tool-runtime';
/**
 * Provider selection for one chat turn, as returned by `resolveProviderTarget`.
 */
type ResolvedProviderTarget = {
/** Id of the provider account that was selected (explicit option or the default account). */
accountId: string;
/** Model name read from the selected account. */
model: string;
/** Provider instance created for `accountId`. */
provider: BaseProvider;
/** Display name used in metadata and transcripts (vendor id, label, model, or 'unknown'). */
providerName: string;
};
/**
 * Accumulator for one tool call that arrives as a series of streamed deltas.
 * Maintained by `applyProviderToolCallDelta` and consumed by `finalizeProviderToolCalls`.
 */
type StreamedToolCallState = {
/** Position of the call within the response; also the key in the accumulator map. */
index: number;
/** Tool call id (taken from the deltas, or generated when the first delta carries none). */
id: string;
/** Tool name (taken from the deltas, or 'tool' when the first delta carries none). */
name: string;
/** Concatenation of every `argumentsDelta` fragment received so far. */
argumentsText: string;
};
/**
 * Collapses message content into plain text.
 *
 * String content is returned unchanged. For block arrays, the text of each
 * `text` / `thinking` block and the best available text of each tool-result
 * block is collected, empty pieces are dropped, and the rest is joined with
 * newlines. Blocks of any other type contribute nothing.
 *
 * @param content - Raw message content: a string or an array of content blocks.
 * @returns The flattened text, or `''` when there is nothing to extract
 *   (including when `content` is missing or not an array).
 */
function flattenMessageContent(content: RawMessage['content']): string {
  if (typeof content === 'string') {
    return content;
  }
  // Messages may carry no content at all (buildChatMessages explicitly guards
  // for that), so treat anything that is not a block array as empty text
  // instead of crashing on `.map`.
  if (!Array.isArray(content)) {
    return '';
  }
  return content
    .map((block) => {
      if (!block || typeof block !== 'object') {
        return '';
      }
      if (block.type === 'text' && typeof block.text === 'string') {
        return block.text;
      }
      if (block.type === 'thinking' && typeof block.thinking === 'string') {
        return block.thinking;
      }
      if (block.type === 'tool_result' || block.type === 'toolResult') {
        // Preference order: literal content, nested blocks, block summary,
        // then the summary carried by the structured result.
        if (typeof block.content === 'string') {
          return block.content;
        }
        if (Array.isArray(block.content)) {
          return flattenMessageContent(block.content as RawMessage['content']);
        }
        if (typeof block.summary === 'string') {
          return block.summary;
        }
        if (
          block.result
          && typeof block.result === 'object'
          && 'summary' in block.result
          && typeof block.result.summary === 'string'
        ) {
          return block.result.summary;
        }
      }
      return '';
    })
    .filter(Boolean)
    .join('\n');
}
/**
 * Translates one stored content block into the gateway block shape sent to providers.
 *
 * - `text` / `thinking` blocks are kept only when their payload is a string.
 * - `tool_use` / `toolCall` blocks become `tool_use`, with id, name and input
 *   filled from whichever alias is present (a random id and the name 'tool'
 *   are the last-resort defaults).
 * - `tool_result` / `toolResult` blocks become `tool_result`; nested block
 *   arrays are converted recursively and unconvertible children are dropped.
 *
 * @param block - Stored content block.
 * @returns The gateway block, or `null` when the block cannot be represented.
 */
function contentBlockToGatewayBlock(block: ContentBlock): GatewayChatContentBlock | null {
  if (block.type === 'text') {
    if (typeof block.text !== 'string') {
      return null;
    }
    return { type: 'text', text: block.text };
  }

  if (block.type === 'thinking') {
    if (typeof block.thinking !== 'string') {
      return null;
    }
    return { type: 'thinking', thinking: block.thinking };
  }

  if (block.type === 'tool_use' || block.type === 'toolCall') {
    return {
      type: 'tool_use',
      id: block.id || block.toolCallId || createRandomId(),
      name: block.name || 'tool',
      input: block.input ?? block.arguments,
      summary: block.summary,
    };
  }

  if (block.type === 'tool_result' || block.type === 'toolResult') {
    return {
      type: 'tool_result',
      toolCallId: block.toolCallId || block.id,
      content: Array.isArray(block.content)
        ? block.content
          .map((nested) => contentBlockToGatewayBlock(nested))
          .filter((nested): nested is GatewayChatContentBlock => nested !== null)
        : block.content,
      result: block.result,
      summary: block.summary,
      ok: block.ok,
      error: block.error,
    };
  }

  return null;
}
/**
 * Converts stored session messages into the message list handed to a provider.
 *
 * Messages without a role or content are skipped, and the legacy role
 * `toolresult` is normalised to `tool_result`. String content is trimmed and
 * only forwarded for user / assistant / system / tool_result roles. Block
 * content is converted block by block; when no block survives conversion the
 * flattened text is sent instead (or the message is skipped if that is empty).
 *
 * @param sessionMessages - Messages in session order.
 * @returns Provider-ready messages, in the same order.
 */
function buildChatMessages(sessionMessages: RawMessage[]): GatewayChatMessage[] {
  const prepared: GatewayChatMessage[] = [];

  for (const message of sessionMessages) {
    if (!message.role || !message.content) {
      continue;
    }
    const normalizedRole = message.role === 'toolresult' ? 'tool_result' : message.role;

    if (typeof message.content === 'string') {
      const text = message.content.trim();
      if (!text) {
        continue;
      }
      if (
        normalizedRole === 'user'
        || normalizedRole === 'assistant'
        || normalizedRole === 'system'
        || normalizedRole === 'tool_result'
      ) {
        prepared.push({
          role: normalizedRole,
          content: text,
          name: message.toolName,
          toolCallId: message.toolCallId,
        });
      }
      continue;
    }

    const gatewayBlocks = message.content
      .map((block) => contentBlockToGatewayBlock(block))
      .filter((block): block is GatewayChatContentBlock => block !== null);

    if (gatewayBlocks.length > 0) {
      prepared.push({
        role: normalizedRole,
        content: gatewayBlocks,
        name: message.toolName,
        toolCallId: message.toolCallId,
      });
      continue;
    }

    // Nothing convertible: fall back to whatever plain text the blocks carry.
    const fallbackText = flattenMessageContent(message.content).trim();
    if (fallbackText) {
      prepared.push({
        role: normalizedRole,
        content: fallbackText,
        name: message.toolName,
        toolCallId: message.toolCallId,
      });
    }
  }

  return prepared;
}
/**
 * Writes one message to the session transcript via `appendTranscriptLine`.
 *
 * The content is flattened to text and both tool-result role spellings are
 * recorded as `toolResult`. Keys in `extras` are spread last, so they override
 * the fields derived from the message.
 *
 * @param sessionKey - Session whose transcript receives the line.
 * @param message - Message to record.
 * @param extras - Optional additional fields merged into the recorded message.
 */
function appendTranscriptMessage(
  sessionKey: string,
  message: RawMessage,
  extras?: Record<string, unknown>,
): void {
  const timestamp = new Date().toISOString();
  const text = flattenMessageContent(message.content);

  appendTranscriptLine(sessionKey, {
    type: 'message',
    timestamp,
    message: {
      role: message.role === 'tool_result' || message.role === 'toolresult' ? 'toolResult' : message.role,
      content: text,
      toolCallId: message.toolCallId,
      tool: message.toolName,
      details: message.toolResult,
      ...extras,
    },
  });
}
/**
 * Builds the assistant message that records a single tool invocation.
 *
 * The call is described twice: as a `tool_use` content block and as the
 * message-level `toolCall` / `toolCallId` / `toolName` fields. A missing tool
 * name is replaced by 'tool'.
 *
 * @param toolCallId - Identifier assigned to this invocation.
 * @param toolCall - Tool name, input and optional summary.
 * @returns The assistant message, timestamped with the current time.
 */
function buildToolUseMessage(toolCallId: string, toolCall: ToolCallPayload): RawMessage {
  const { input, summary } = toolCall;
  const toolName = toolCall.name || 'tool';

  return {
    role: 'assistant',
    content: [{ type: 'tool_use', id: toolCallId, name: toolName, input, summary }],
    timestamp: Date.now(),
    toolCallId,
    toolName,
    toolCall: { id: toolCallId, name: toolName, input, summary },
  };
}
/**
 * Builds one assistant message recording every tool call requested in a model turn.
 *
 * Each call becomes a `tool_use` content block. `toolCallId` is always the id
 * of the first call (if any); `toolName` and `toolCall` are only populated
 * when exactly one call was requested, otherwise `undefined` / `null`.
 *
 * @param toolCalls - Requested tool calls, each with its id.
 * @returns The assistant message, timestamped with the current time.
 */
function buildMultiToolUseMessage(toolCalls: Array<ToolCallPayload & { id: string }>): RawMessage {
  const [leading] = toolCalls;
  // Only a lone call is mirrored into the message-level toolName/toolCall fields.
  const sole = toolCalls.length === 1 ? leading : undefined;

  return {
    role: 'assistant',
    content: toolCalls.map(({ id, name, input, summary }) => ({
      type: 'tool_use' as const,
      id,
      name: name || 'tool',
      input,
      summary,
    })),
    timestamp: Date.now(),
    toolCallId: leading?.id,
    toolName: sole?.name,
    toolCall: sole
      ? {
        id: sole.id,
        name: sole.name,
        input: sole.input,
        summary: sole.summary,
      }
      : null,
  };
}
/**
 * Assembles a tool status snapshot for a tool call.
 *
 * @param toolCallId - Invocation id; used for both `id` and `toolCallId`.
 * @param toolCall - The call being reported; supplies the name (default 'tool') and input.
 * @param status - Status value to report.
 * @param summary - Human-readable status line.
 * @param updatedAt - Timestamp of this snapshot.
 * @param result - Optional tool result payload.
 * @param durationMs - Optional execution time in milliseconds.
 * @returns The assembled status record.
 */
function buildToolStatus(
  toolCallId: string,
  toolCall: ToolCallPayload,
  status: ToolStatus['status'],
  summary: string,
  updatedAt: number,
  result?: unknown,
  durationMs?: number,
): ToolStatus {
  const snapshot: ToolStatus = {
    id: toolCallId,
    toolCallId,
    name: toolCall.name || 'tool',
    status,
    updatedAt,
    durationMs,
    summary,
    input: toolCall.input,
    result,
  };
  return snapshot;
}
/**
 * Gathers the distinct file attachments referenced anywhere in a session.
 *
 * Messages are scanned newest-first, so when the same file (same path, name
 * and MIME type) was attached several times the most recent attachment wins
 * and the result is ordered newest-first. Attachments that carry no path,
 * name or MIME type at all cannot be identified and are skipped.
 *
 * @param sessionMessages - Session messages in chronological order.
 * @returns De-duplicated attachments, most recent first.
 */
function collectSessionFiles(sessionMessages: RawMessage[]): RawMessage['_attachedFiles'] {
  const files = new Map<string, NonNullable<RawMessage['_attachedFiles']>[number]>();
  for (let index = sessionMessages.length - 1; index >= 0; index -= 1) {
    const message = sessionMessages[index];
    for (const attachment of message?._attachedFiles || []) {
      const identityParts = [attachment.filePath, attachment.fileName, attachment.mimeType]
        .map((part) => `${part || ''}`);
      // The joined key always contains the '|' separators, so emptiness has to
      // be checked on the individual parts rather than on the key itself.
      if (identityParts.every((part) => !part.trim())) {
        continue;
      }
      const key = identityParts.join('|');
      if (files.has(key)) {
        continue;
      }
      files.set(key, attachment);
    }
  }
  return Array.from(files.values());
}
/**
 * Decodes the accumulated argument text of a streamed tool call.
 *
 * @param argumentsText - Concatenated argument fragments, expected to be JSON.
 * @returns `{}` for blank text, the parsed JSON value when the text is valid
 *   JSON, otherwise `{ rawArguments }` holding the trimmed text so nothing is lost.
 */
function parseProviderToolCallInput(argumentsText: string): unknown {
  const payload = argumentsText.trim();
  if (payload.length === 0) {
    return {};
  }

  let decoded: unknown;
  try {
    decoded = JSON.parse(payload) as unknown;
  } catch {
    // Malformed JSON: hand the raw text to the tool instead of failing the turn.
    return { rawArguments: payload };
  }
  return decoded;
}
/**
 * Folds one streamed tool-call delta into the per-call accumulator map.
 *
 * The delta's numeric `index` selects the entry; a delta without one is placed
 * at the next free position (`states.size`). A new entry starts with the
 * delta's id (or a random id) and name (or 'tool'). Non-blank `id` / `name`
 * values overwrite the stored ones, and `argumentsDelta` text is appended.
 *
 * @param states - Accumulators keyed by tool-call index; mutated in place.
 * @param delta - Tool-call fragment from a provider stream chunk.
 */
function applyProviderToolCallDelta(
  states: Map<number, StreamedToolCallState>,
  delta: NonNullable<Awaited<ReturnType<BaseProvider['chat']>> extends AsyncIterable<infer T> ? T : never>['toolCalls'][number],
): void {
  const slot = typeof delta.index === 'number' ? delta.index : states.size;

  let state = states.get(slot);
  if (!state) {
    state = {
      index: slot,
      id: delta.id || createRandomId(),
      name: delta.name || 'tool',
      argumentsText: '',
    };
  }

  if (typeof delta.id === 'string' && delta.id.trim()) {
    state.id = delta.id;
  }
  if (typeof delta.name === 'string' && delta.name.trim()) {
    state.name = delta.name;
  }
  if (typeof delta.argumentsDelta === 'string') {
    state.argumentsText += delta.argumentsDelta;
  }

  states.set(slot, state);
}
/**
 * Turns the accumulated streaming state into executable tool calls.
 *
 * Entries are ordered by index, entries with a blank name are dropped, and the
 * argument text of each remaining entry is decoded with
 * `parseProviderToolCallInput`.
 *
 * @param states - Accumulators filled by `applyProviderToolCallDelta`.
 * @returns Tool calls in index order, each with an id, name, input and summary.
 */
function finalizeProviderToolCalls(
  states: Map<number, StreamedToolCallState>,
): Array<ToolCallPayload & { id: string }> {
  const ordered = Array.from(states.values());
  ordered.sort((a, b) => a.index - b.index);

  const toolCalls: Array<ToolCallPayload & { id: string }> = [];
  for (const { id, name, argumentsText } of ordered) {
    if (!name.trim()) {
      continue;
    }
    toolCalls.push({
      id,
      name,
      input: parseProviderToolCallInput(argumentsText),
      summary: `Model requested ${name}.`,
    });
  }
  return toolCalls;
}
/**
 * Completes a run with its final assistant message.
 *
 * In order: stores the message in the session, clears the session's active
 * run, records the message in the transcript, and broadcasts `chat:final`.
 *
 * @param sessionKey - Session the run belongs to.
 * @param runId - Run being completed.
 * @param message - Final assistant message.
 * @param broadcast - Event sink for gateway events.
 * @param extras - Optional extra fields for the transcript entry.
 */
function finalizeAssistantMessage(
  sessionKey: string,
  runId: string,
  message: RawMessage,
  broadcast: (event: GatewayEvent) => void,
  extras?: Record<string, unknown>,
): void {
  // Persist and release the run first so listeners of the final event observe settled state.
  sessionStore.appendMessage(sessionKey, message);
  sessionStore.clearActiveRun(sessionKey);
  appendTranscriptMessage(sessionKey, message, extras);

  broadcast({ type: 'chat:final', sessionKey, runId, message });
}
/**
 * Runs one tool call and records its outcome.
 *
 * Sequence: broadcast a `running` status, execute the call through `runtime`
 * (passing the active run's abort signal and the session's attachments),
 * store the resulting tool-result message in the session and transcript, then
 * broadcast the final status including duration and result.
 *
 * @param sessionKey - Session the call belongs to.
 * @param runId - Run the call belongs to.
 * @param runtime - Tool runtime that executes the call.
 * @param toolCallId - Identifier of this invocation.
 * @param toolCall - Tool name, input and optional summary.
 * @param broadcast - Event sink for gateway events.
 * @returns The final status and the persisted tool-result message.
 */
async function executeToolCallAndPersist(
  sessionKey: string,
  runId: string,
  runtime: ToolRuntime,
  toolCallId: string,
  toolCall: ToolCallPayload,
  broadcast: (event: GatewayEvent) => void,
): Promise<{ finalStatus: ToolStatus; toolResultMessage: RawMessage }> {
  const toolName = toolCall.name || 'tool';

  // 1. Tell listeners the tool has started.
  const pending = buildToolStatus(
    toolCallId,
    toolCall,
    'running',
    toolCall.summary || `Running ${toolName}`,
    Date.now(),
  );
  broadcast({
    type: 'tool:status',
    sessionKey,
    runId,
    toolCallId,
    toolName: pending.name,
    status: pending.status,
    updatedAt: pending.updatedAt,
    summary: pending.summary,
    input: pending.input,
  });

  // 2. Execute the tool with the run's cancellation signal and session files.
  const { execution, normalized } = await runtime.run(
    {
      toolCallId,
      toolName,
      input: toolCall.input,
      summary: toolCall.summary,
      source: 'planner',
    },
    {
      sessionKey,
      runId,
      signal: sessionStore.getActiveRun(sessionKey)?.abortController.signal,
      files: collectSessionFiles(sessionStore.getOrCreate(sessionKey).messages),
      metadata: {
        requestedBy: 'chat.send',
      },
    },
  );

  // 3. Persist the outcome.
  const settled = buildToolStatus(
    toolCallId,
    toolCall,
    execution.status,
    normalized.summary || toolCall.summary || `Finished ${toolName}`,
    Date.now(),
    normalized.payload,
    execution.durationMs,
  );
  const resultMessage: RawMessage = {
    ...normalized.transcriptMessage,
    _toolStatuses: [settled],
  };
  sessionStore.appendMessage(sessionKey, resultMessage);
  appendTranscriptMessage(sessionKey, resultMessage, {
    tool: toolCall.name,
    toolCallId,
  });

  // 4. Publish the final status.
  broadcast({
    type: 'tool:status',
    sessionKey,
    runId,
    toolCallId,
    toolName: settled.name,
    status: settled.status,
    updatedAt: settled.updatedAt,
    durationMs: settled.durationMs,
    summary: settled.summary,
    input: settled.input,
    result: settled.result,
  });

  return { finalStatus: settled, toolResultMessage: resultMessage };
}
/**
 * Picks the provider account, model and provider instance for a chat turn.
 *
 * The account named in `options.providerAccountId` is used when given,
 * otherwise the service's default account.
 *
 * @param options - Optional `chat.send` options.
 * @returns The resolved account id, model, provider instance and display name.
 * @throws Error when no account is selected, the account does not exist, or
 *   the account has no model configured.
 */
function resolveProviderTarget(
  options?: GatewayRpcParams['chat.send']['options'],
): ResolvedProviderTarget {
  const accountId = options?.providerAccountId || providerApiService.getDefault().accountId;
  if (!accountId) {
    throw new Error('No provider account selected');
  }

  const account = providerApiService.getAccounts().find(({ id }) => id === accountId);
  if (!account) {
    throw new Error(`Provider account ${accountId} not found`);
  }

  const { model } = account;
  if (!model) {
    throw new Error(`Provider account ${accountId} has no model configured`);
  }

  const provider = createProvider(accountId);
  // First non-empty of vendor id, label, model; 'unknown' as the last resort.
  const providerName = account.vendorId || account.label || account.model || 'unknown';
  return { accountId, model, provider, providerName };
}
/**
 * Non-throwing variant of `resolveProviderTarget`.
 *
 * @param options - Optional `chat.send` options.
 * @returns The resolved target, or `null` (after logging a warning) when
 *   resolution fails.
 */
function tryResolveProviderTarget(
  options?: GatewayRpcParams['chat.send']['options'],
): ResolvedProviderTarget | null {
  let target: ResolvedProviderTarget | null = null;
  try {
    target = resolveProviderTarget(options);
  } catch (error) {
    logManager.warn('Provider resolution skipped for this chat turn:', error);
  }
  return target;
}
/**
 * Reads a provider's capabilities.
 *
 * @param provider - Provider to inspect.
 * @returns The result of `provider.getCapabilities()` when the provider
 *   implements it, otherwise `DEFAULT_PROVIDER_CAPABILITIES`.
 */
function getProviderCapabilities(provider: BaseProvider): ProviderCapabilities {
  return typeof provider.getCapabilities === 'function'
    ? provider.getCapabilities()
    : DEFAULT_PROVIDER_CAPABILITIES;
}
/**
 * Drives one provider conversation turn, including any model-requested tool rounds.
 *
 * Each round streams a provider response. Text is accumulated; when the
 * provider does not support tool calls it is also forwarded immediately as
 * `chat:delta` events, otherwise it is held back and emitted as one delta once
 * the round turns out to contain no tool calls. A round without tool calls
 * ends the run with `chat:final`. A round with tool calls records them,
 * executes each in order, appends the results to the conversation and starts
 * the next round.
 *
 * The number of rounds is capped (4 when tools are offered, otherwise 1). If
 * the cap is reached while the model is still requesting tools, the run is
 * ended with a `chat:error` event so it never stays active without a terminal
 * event. When `signal` is aborted the function stops quietly; whoever aborted
 * is responsible for clearing the run and notifying listeners.
 *
 * Errors are not rethrown: the active run is cleared and `chat:error` is broadcast.
 *
 * @param sessionKey - Session the run belongs to.
 * @param runId - Run identifier echoed in every event.
 * @param provider - Provider used for the chat calls.
 * @param model - Model name passed to the provider.
 * @param providerName - Display name recorded in metadata and the transcript.
 * @param messages - Conversation sent in the first round.
 * @param signal - Abort signal for the run.
 * @param broadcast - Event sink for gateway events.
 */
async function processChatStream(
  sessionKey: string,
  runId: string,
  provider: BaseProvider,
  model: string,
  providerName: string,
  messages: GatewayChatMessage[],
  signal: AbortSignal,
  broadcast: (event: GatewayEvent) => void,
) {
  const capabilities = getEnabledSkillCapabilities();
  const capabilityInputs = mapSkillCapabilitiesToRegistryInputs(capabilities);
  const registry = createToolRegistry({
    capabilities: capabilityInputs,
  });
  const runtime = createChatToolRuntime(capabilities);
  const providerCapabilities = getProviderCapabilities(provider);
  const toolDefinitions = providerCapabilities.toolCalls
    ? createGatewayToolDefinitions(registry)
    : undefined;
  const maxToolRounds = providerCapabilities.toolCalls && toolDefinitions && toolDefinitions.length > 0 ? 4 : 1;
  // Copy so tool messages appended between rounds do not leak into the caller's array.
  const currentMessages = [...messages];
  let finalUsage: unknown = undefined;
  try {
    for (let round = 0; round < maxToolRounds; round += 1) {
      let assistantContent = '';
      const streamedToolCalls = new Map<number, StreamedToolCallState>();
      const chunks = await provider.chat(currentMessages, model, {
        signal,
        ...(toolDefinitions?.length ? { tools: toolDefinitions, toolChoice: 'auto' as const } : {}),
        metadata: {
          sessionKey,
          runId,
          provider: providerName,
          round,
        },
      });
      for await (const chunk of chunks) {
        if (signal.aborted) {
          break;
        }
        if (chunk.result) {
          assistantContent += chunk.result;
          // With tool-capable providers the text is withheld until the round
          // is known to be a plain answer (see below).
          if (!providerCapabilities.toolCalls) {
            broadcast({
              type: 'chat:delta',
              sessionKey,
              runId,
              delta: chunk.result,
            });
          }
        }
        if (chunk.toolCalls?.length) {
          for (const toolCallDelta of chunk.toolCalls) {
            applyProviderToolCallDelta(streamedToolCalls, toolCallDelta);
          }
        }
        if (chunk.usage !== undefined) {
          finalUsage = chunk.usage;
        }
      }
      if (signal.aborted) {
        break;
      }
      const providerToolCalls = finalizeProviderToolCalls(streamedToolCalls);
      if (providerToolCalls.length === 0) {
        // Plain answer: flush the withheld text, then finish the run.
        if (providerCapabilities.toolCalls && assistantContent) {
          broadcast({
            type: 'chat:delta',
            sessionKey,
            runId,
            delta: assistantContent,
          });
        }
        const finalMessage: RawMessage = {
          role: 'assistant',
          content: assistantContent,
          timestamp: Date.now(),
        };
        finalizeAssistantMessage(sessionKey, runId, finalMessage, broadcast, {
          model,
          provider: providerName,
          usage: finalUsage,
        });
        return;
      }
      const toolUseMessage = buildMultiToolUseMessage(providerToolCalls);
      sessionStore.appendMessage(sessionKey, toolUseMessage);
      appendTranscriptMessage(sessionKey, toolUseMessage, {
        toolCalls: providerToolCalls.map((toolCall) => ({
          id: toolCall.id,
          name: toolCall.name,
        })),
      });
      currentMessages.push(...buildChatMessages([toolUseMessage]));
      for (const providerToolCall of providerToolCalls) {
        const { toolResultMessage } = await executeToolCallAndPersist(
          sessionKey,
          runId,
          runtime,
          providerToolCall.id,
          providerToolCall,
          broadcast,
        );
        currentMessages.push(...buildChatMessages([toolResultMessage]));
      }
    }
    // Reaching this point without an abort means every allowed round ended in
    // more tool requests. Fail the run explicitly; otherwise it would remain
    // active forever and the client would never receive a terminal event.
    if (!signal.aborted) {
      throw new Error(
        `Tool round limit (${maxToolRounds}) reached before the model produced a final answer.`,
      );
    }
  } catch (error) {
    sessionStore.clearActiveRun(sessionKey);
    broadcast({
      type: 'chat:error',
      sessionKey,
      runId,
      error: error instanceof Error ? error.message : String(error),
    });
  }
}
/**
 * Executes a planner-selected tool call and then produces the assistant reply.
 *
 * After the tool has run, the reply comes from the provider (via
 * `processChatStream`, with the tool result already in the session history).
 * When no provider can be resolved, the tool's own summary - or, failing
 * that, the flattened tool result - is used as the final assistant message.
 *
 * If the run is aborted while the tool is executing, nothing further is sent:
 * the abort handler has already cleared the run and notified listeners.
 *
 * Errors are not rethrown: the active run is cleared and `chat:error` is broadcast.
 *
 * @param sessionKey - Session the run belongs to.
 * @param runId - Run identifier echoed in every event.
 * @param userMessage - The triggering user message (currently unused; kept for
 *   signature compatibility).
 * @param toolCallId - Identifier of the planned invocation.
 * @param toolCall - Planned tool call.
 * @param options - `chat.send` options used to resolve the provider.
 * @param broadcast - Event sink for gateway events.
 */
async function processPlannedToolRun(
  sessionKey: string,
  runId: string,
  userMessage: RawMessage,
  toolCallId: string,
  toolCall: ToolCallPayload,
  options: GatewayRpcParams['chat.send']['options'] | undefined,
  broadcast: (event: GatewayEvent) => void,
): Promise<void> {
  // Capture the run's signal before the first await. An abort clears the
  // active run, so looking it up again after the tool has finished would find
  // nothing and hand the follow-up stream a fresh, never-aborted signal.
  const signal = sessionStore.getActiveRun(sessionKey)?.abortController.signal
    || new AbortController().signal;
  try {
    // Built inside the try block so a setup failure is reported as chat:error
    // instead of surfacing as an unhandled rejection.
    const capabilities = getEnabledSkillCapabilities();
    const runtime = createChatToolRuntime(capabilities);
    const { finalStatus, toolResultMessage } = await executeToolCallAndPersist(
      sessionKey,
      runId,
      runtime,
      toolCallId,
      toolCall,
      broadcast,
    );
    if (signal.aborted) {
      return;
    }
    const providerTarget = tryResolveProviderTarget(options);
    if (!providerTarget) {
      finalizeAssistantMessage(
        sessionKey,
        runId,
        {
          role: 'assistant',
          // finalStatus.summary carries the tool run's summary; the raw tool
          // run object itself is not visible in this scope.
          content: finalStatus.summary || flattenMessageContent(toolResultMessage.content),
          timestamp: Date.now(),
          _toolStatuses: [finalStatus],
        },
        broadcast,
        {
          tool: toolCall.name,
        },
      );
      return;
    }
    const session = sessionStore.getOrCreate(sessionKey);
    const messages = [
      ...buildRuntimeContextMessages(sessionKey),
      ...buildChatMessages(session.messages),
    ];
    await processChatStream(
      sessionKey,
      runId,
      providerTarget.provider,
      providerTarget.model,
      providerTarget.providerName,
      messages,
      signal,
      broadcast,
    );
  } catch (error) {
    sessionStore.clearActiveRun(sessionKey);
    broadcast({
      type: 'chat:error',
      sessionKey,
      runId,
      error: error instanceof Error ? error.message : String(error),
    });
  }
}
/**
 * Answers a `chat.send` request directly with a planner-provided message,
 * without running a tool or calling a provider.
 *
 * @param sessionKey - Session receiving the message.
 * @param runId - Run identifier for the response.
 * @param summary - Text of the assistant message.
 * @param broadcast - Event sink for gateway events.
 * @returns The RPC result carrying `runId`.
 */
function buildPlannerResponse(
  sessionKey: string,
  runId: string,
  summary: string,
  broadcast: (event: GatewayEvent) => void,
): GatewayRpcReturns['chat.send'] {
  const reply: RawMessage = { role: 'assistant', content: summary, timestamp: Date.now() };
  finalizeAssistantMessage(sessionKey, runId, reply, broadcast);
  return { runId };
}
/**
 * Handles the `chat.send` RPC: records the user message and starts a run.
 *
 * The planner decides how the turn proceeds:
 * - `tool` with a tool call: the call is recorded and executed in the
 *   background, followed by the assistant reply.
 * - `no-tool` with a blocking issue: the issue message is returned
 *   immediately as the assistant reply.
 * - otherwise: the conversation is streamed from the resolved provider in the
 *   background.
 *
 * Background work reports its outcome through `broadcast`; this function
 * returns as soon as the run has been started.
 *
 * @param params - Session key, user message and optional send options.
 * @param broadcast - Event sink for gateway events.
 * @returns The id of the run that was started.
 * @throws Error when no tool is planned and the provider cannot be resolved.
 */
export function handleChatSend(
  params: GatewayRpcParams['chat.send'],
  broadcast: (event: GatewayEvent) => void,
): GatewayRpcReturns['chat.send'] {
  const sessionKey = normalizeAgentSessionKey(params.sessionKey);
  const { message, options } = params;
  const runId = createRandomId();

  // Shared last-resort handler for the fire-and-forget background tasks, so a
  // rejection is always logged, releases the run and reaches the client.
  const reportRunFailure = (label: string, error: unknown): void => {
    logManager.error(label, error);
    sessionStore.clearActiveRun(sessionKey);
    broadcast({
      type: 'chat:error',
      sessionKey,
      runId,
      error: error instanceof Error ? error.message : String(error),
    });
  };

  const userMessage: RawMessage = {
    ...message,
    timestamp: message.timestamp || Date.now(),
  };
  sessionStore.appendMessage(sessionKey, userMessage);
  appendTranscriptMessage(sessionKey, userMessage);
  const session = sessionStore.getOrCreate(sessionKey);
  const capabilities = getEnabledSkillCapabilities();
  const capabilityInputs = mapSkillCapabilitiesToRegistryInputs(capabilities);
  const registry = createToolRegistry({
    capabilities: capabilityInputs,
  });
  const decision = planToolCall({
    message: userMessage,
    attachments: userMessage._attachedFiles,
    // History excludes the user message that was just appended.
    history: session.messages.slice(0, -1),
    capabilities: capabilityInputs,
    registry,
  });
  if (decision.kind === 'tool' && decision.toolCall) {
    const toolCallId = `${decision.toolCall.name || 'tool'}:${runId}`;
    const toolUseMessage = buildToolUseMessage(toolCallId, decision.toolCall);
    sessionStore.appendMessage(sessionKey, toolUseMessage);
    appendTranscriptMessage(sessionKey, toolUseMessage, {
      tool: decision.toolCall.name,
      toolCallId,
    });
    const abortController = new AbortController();
    sessionStore.setActiveRun(sessionKey, runId, abortController);
    void processPlannedToolRun(
      sessionKey,
      runId,
      userMessage,
      toolCallId,
      decision.toolCall,
      options,
      broadcast,
    ).catch((error) => {
      reportRunFailure('Unexpected error in processPlannedToolRun:', error);
    });
    return { runId };
  }
  if (decision.kind === 'no-tool' && decision.blockingIssue) {
    return buildPlannerResponse(
      sessionKey,
      runId,
      decision.blockingIssue.message,
      broadcast,
    );
  }
  const providerTarget = resolveProviderTarget(options);
  const messages = [
    ...buildRuntimeContextMessages(sessionKey),
    ...buildChatMessages(session.messages),
  ];
  const abortController = new AbortController();
  sessionStore.setActiveRun(sessionKey, runId, abortController);
  void processChatStream(
    sessionKey,
    runId,
    providerTarget.provider,
    providerTarget.model,
    providerTarget.providerName,
    messages,
    abortController.signal,
    broadcast,
  ).catch((error) => {
    reportRunFailure('Unexpected error in processChatStream:', error);
  });
  return { runId };
}
/**
 * Handles the `chat.history` RPC.
 *
 * @param params - Session key and optional message limit (defaults to 50).
 * @returns The messages returned by the session store for that session and limit.
 */
export function handleChatHistory(
  params: GatewayRpcParams['chat.history'],
): GatewayRpcReturns['chat.history'] {
  const sessionKey = normalizeAgentSessionKey(params.sessionKey);
  const limit = params.limit ?? 50;
  return sessionStore.getMessages(sessionKey, limit);
}
/**
 * Handles the `chat.abort` RPC.
 *
 * When the session has an active run, its abort controller is triggered, the
 * run is cleared and `chat:aborted` is broadcast. Without an active run this
 * is a no-op.
 *
 * @param params - Session key of the run to abort.
 * @param broadcast - Event sink for gateway events.
 */
export function handleChatAbort(
  params: GatewayRpcParams['chat.abort'],
  broadcast: (event: GatewayEvent) => void,
): GatewayRpcReturns['chat.abort'] {
  const sessionKey = normalizeAgentSessionKey(params.sessionKey);
  const running = sessionStore.getActiveRun(sessionKey);
  if (!running) {
    return;
  }

  running.abortController.abort();
  sessionStore.clearActiveRun(sessionKey);
  broadcast({ type: 'chat:aborted', sessionKey, runId: running.runId });
}
/**
 * Handles the `session.list` RPC.
 *
 * @returns Every session key held by the session store.
 */
export function handleSessionList(): GatewayRpcReturns['session.list'] {
  const sessionKeys = sessionStore.getAllKeys();
  return sessionKeys;
}
/**
 * Handles the `session.delete` RPC by removing the session from the store.
 *
 * @param params - Session key of the session to delete.
 * @returns `{ success: true }` once the store call has returned.
 */
export function handleSessionDelete(
  params: GatewayRpcParams['session.delete'],
): GatewayRpcReturns['session.delete'] {
  const sessionKey = normalizeAgentSessionKey(params.sessionKey);
  sessionStore.deleteSession(sessionKey);
  return { success: true };
}