196 lines
6.3 KiB
Python
196 lines
6.3 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
import time
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.config import settings
|
||
|
||
|
||
def _repair_truncated_json(s: str) -> str:
|
||
"""Best-effort repair of truncated JSON by closing unclosed brackets/braces.
|
||
|
||
Handles the common case where a model hits max_tokens mid-output and the
|
||
JSON stream ends without closing delimiters.
|
||
"""
|
||
s = s.rstrip()
|
||
# Strip trailing comma/colon that appears right before the cutoff
|
||
s = re.sub(r"[,:\s]+$", "", s)
|
||
# Walk the string tracking bracket depth (ignore chars inside strings)
|
||
stack: list[str] = []
|
||
in_string = False
|
||
escape = False
|
||
for ch in s:
|
||
if escape:
|
||
escape = False
|
||
continue
|
||
if ch == "\\" and in_string:
|
||
escape = True
|
||
continue
|
||
if ch == '"':
|
||
in_string = not in_string
|
||
continue
|
||
if in_string:
|
||
continue
|
||
if ch in "{[":
|
||
stack.append(ch)
|
||
elif ch in "}]":
|
||
if stack:
|
||
stack.pop()
|
||
# If we ended inside a string literal, close it first
|
||
if in_string:
|
||
s += '"'
|
||
# Close every open bracket in reverse order
|
||
closing = {"{": "}", "[": "]"}
|
||
for opener in reversed(stack):
|
||
s += closing[opener]
|
||
return s
|
||
|
||
|
||
def _extract_json(s: str) -> Any:
|
||
"""Provider-agnostic JSON extraction (Claude/APIYI often wrap or fence)."""
|
||
s = (s or "").strip()
|
||
if not s:
|
||
raise ValueError("LLM 返回空内容")
|
||
try:
|
||
return json.loads(s)
|
||
except Exception:
|
||
pass
|
||
m = re.search(r"```(?:json)?\s*(.*?)```", s, re.S)
|
||
if m:
|
||
try:
|
||
return json.loads(m.group(1).strip())
|
||
except Exception:
|
||
pass
|
||
s = re.sub(r"^```(?:json)?\s*", "", s, flags=re.I).strip().removesuffix("```").strip()
|
||
for op, cl in (("{", "}"), ("[", "]")):
|
||
i, j = s.find(op), s.rfind(cl)
|
||
if i != -1 and j > i:
|
||
try:
|
||
return json.loads(s[i:j + 1])
|
||
except Exception:
|
||
pass
|
||
if i != -1:
|
||
try:
|
||
return json.loads(_repair_truncated_json(s[i:]))
|
||
except Exception:
|
||
pass
|
||
# Last resort: repair a truncated JSON stream and retry
|
||
repaired = _repair_truncated_json(s)
|
||
if repaired != s:
|
||
for op, cl in (("{", "}"), ("[", "]")):
|
||
i, j = repaired.find(op), repaired.rfind(cl)
|
||
if i != -1 and j > i:
|
||
try:
|
||
return json.loads(repaired[i:j + 1])
|
||
except Exception:
|
||
pass
|
||
raise ValueError("LLM 未返回合法 JSON: " + s[:160])
|
||
|
||
|
||
class LlmClient:
    """Minimal synchronous client for an OpenAI-compatible ``/chat/completions`` API.

    Each call issues its own ``httpx.post``; an instance only holds
    configuration (base URL, key, model, timeout, max_tokens).
    """

    def __init__(self, api_base: str, api_key: str, model: str,
                 timeout: int = 30, max_tokens: int = 4000) -> None:
        # Drop a trailing "/" so endpoint paths can be appended directly.
        self.api_base = api_base.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.max_tokens = max_tokens

    @classmethod
    def from_settings(cls) -> "LlmClient":
        """Build a client from ``app.config.settings``; ``max_tokens`` keeps its default."""
        return cls(
            api_base=settings.llm_api_base,
            api_key=settings.llm_api_key,
            model=settings.llm_model,
            timeout=settings.llm_timeout_seconds,
        )

    def available(self) -> bool:
        """Return True when both an API base URL and an API key are configured."""
        return bool(self.api_base and self.api_key)

    def _headers(self) -> dict[str, str]:
        """Return the bearer-auth JSON headers shared by every request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _complete(self, body: dict[str, Any]) -> dict[str, Any]:
        """POST *body* to ``{api_base}/chat/completions`` and return ``choices[0]``.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        resp = httpx.post(
            f"{self.api_base}/chat/completions",
            headers=self._headers(),
            json=body,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp.json()["choices"][0]

    def chat_json(self, system: str, user: str) -> dict[str, Any]:
        """Ask the model for a JSON reply and return it parsed.

        A strict "JSON only" instruction is appended to *system*, and
        ``response_format={"type": "json_object"}`` is requested. Parsing is
        delegated to ``_extract_json``, which returns any JSON value, so the
        result may be a list if the model replies with a top-level array
        despite the ``dict`` annotation.

        Raises:
            RuntimeError: if the client is not configured.
            httpx.HTTPStatusError: on HTTP failure. A 400/422 is retried once
                without ``response_format`` before being surfaced.
            ValueError: if the reply is not valid JSON. When the model
                stopped with ``finish_reason == "length"``, the message says
                the output was truncated by ``max_tokens``.
        """
        if not self.available():
            raise RuntimeError("LLM not configured")

        sys_json = (system or "") + (
            "\n\n严格要求:只输出一个合法 JSON(对象或数组),"
            "不要 markdown 代码块、不要任何解释或前后缀文字。"
        )
        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": sys_json},
                {"role": "user", "content": user},
            ],
            "temperature": 0.1,
            "max_tokens": self.max_tokens,
            "response_format": {"type": "json_object"},
        }
        try:
            choice = self._complete(body)
        except httpx.HTTPStatusError as exc:
            # A few OpenAI-compatible gateways still do not accept response_format.
            # Retry once without it; other errors are surfaced normally.
            if exc.response.status_code not in {400, 422}:
                raise
            body.pop("response_format", None)
            choice = self._complete(body)

        finish_reason = choice.get("finish_reason", "")
        content = choice["message"]["content"]
        try:
            return _extract_json(content)
        except ValueError as exc:
            if finish_reason == "length":
                raise ValueError(
                    f"[finish_reason=length] 输出被 max_tokens={self.max_tokens} 截断,"
                    f"模型={self.model},已尝试 JSON 修复但仍失败:{exc}"
                ) from exc
            raise

    def chat_text(self, system: str, user: str) -> str:
        """Return the model's plain-text reply to *system* + *user*.

        Raises:
            RuntimeError: if the client is not configured.
            httpx.HTTPStatusError: on a non-2xx response.
        """
        if not self.available():
            raise RuntimeError("LLM not configured")

        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "temperature": 0.3,
            "max_tokens": self.max_tokens,
        }
        choice = self._complete(body)
        # "content" can be null in some OpenAI-compatible responses; keep the
        # declared str return type instead of leaking None to callers.
        return choice["message"]["content"] or ""
|