Files
bxh/app/llm_client.py

196 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import re
import time
from typing import Any
import httpx
from app.config import settings
def _repair_truncated_json(s: str) -> str:
    """Best-effort repair of truncated JSON by closing unclosed delimiters.

    Handles the common case where a model hits ``max_tokens`` mid-output and
    the JSON stream ends without its closing quotes, brackets or braces.

    Args:
        s: Possibly truncated JSON text.

    Returns:
        ``s`` with trailing whitespace removed and then:

        * if the text ends inside a string literal, a dangling trailing
          backslash dropped and the literal closed with ``"``;
        * otherwise, any trailing run of commas/colons/whitespace dropped;
        * every still-open ``{`` / ``[`` closed, innermost first.

        The result is not guaranteed to be valid JSON; callers must still
        parse it and handle failure.
    """
    s = s.rstrip()

    # Walk the text tracking open brackets, ignoring anything inside strings.
    stack: list[str] = []
    in_string = False
    escape = False
    for ch in s:
        if escape:
            # Character following a backslash is part of the escape sequence.
            escape = False
            continue
        if ch == "\\" and in_string:
            escape = True
            continue
        if ch == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if ch in "{[":
            stack.append(ch)
        elif ch in "}]" and stack:
            stack.pop()

    if in_string:
        # The cutoff landed inside a string literal: keep its content intact.
        if escape:
            # A lone trailing backslash would escape the quote we append.
            s = s[:-1]
        s += '"'
    else:
        # Outside a string a trailing comma/colon has nothing after it.
        s = re.sub(r"[,:\s]+$", "", s)

    # Close every open bracket in reverse (innermost-first) order.
    closing = {"{": "}", "[": "]"}
    return s + "".join(closing[opener] for opener in reversed(stack))
def _extract_json(s: str) -> Any:
    """Provider-agnostic JSON extraction (Claude/APIYI often wrap or fence).

    Candidate substrings are tried in order and the first one that parses
    wins:

    1. the whole (stripped) text;
    2. the contents of the first ``` fenced block, if any;
    3. with a leading/trailing fence marker removed, for each bracket kind
       (the kind whose opener appears first in the text is tried first, so
       the outermost container is preferred over a nested one):
       the span from the first opener to the last matching closer, then a
       repaired copy of everything from the first opener onward;
    4. the same opener-to-closer spans taken from a repaired copy of the
       whole text, if repairing changed it.

    Args:
        s: Raw model output; ``None`` or empty is treated as empty content.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If the content is empty or no candidate parses as JSON.
    """
    s = (s or "").strip()
    if not s:
        raise ValueError("LLM 返回空内容")

    # Text with a leading ```/```json marker and a trailing ``` removed.
    unfenced = (
        re.sub(r"^```(?:json)?\s*", "", s, flags=re.I)
        .strip()
        .removesuffix("```")
        .strip()
    )

    # Try the bracket kind that opens first so that e.g. `[{"a": 1}]` wrapped
    # in prose yields the array rather than its first inner object.
    delimiters = sorted(
        (("{", "}"), ("[", "]")),
        key=lambda pair: unfenced.find(pair[0]),
    )

    def candidates():
        # Lazily produce candidate strings, cheapest and most exact first.
        yield s
        fenced = re.search(r"```(?:json)?\s*(.*?)```", s, re.S)
        if fenced:
            yield fenced.group(1).strip()
        for opener, closer in delimiters:
            start, end = unfenced.find(opener), unfenced.rfind(closer)
            if start == -1:
                continue
            if end > start:
                yield unfenced[start:end + 1]
            yield _repair_truncated_json(unfenced[start:])
        # Last resort: repair the whole truncated stream and retry the spans.
        repaired = _repair_truncated_json(unfenced)
        if repaired != unfenced:
            for opener, closer in delimiters:
                start, end = repaired.find(opener), repaired.rfind(closer)
                if start != -1 and end > start:
                    yield repaired[start:end + 1]

    for candidate in candidates():
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; move on to the next guess.
            continue
    raise ValueError("LLM 未返回合法 JSON: " + unfenced[:160])
class LlmClient:
    """Small synchronous client for an OpenAI-compatible chat endpoint.

    Requests are sent with ``httpx.post`` to ``{api_base}/chat/completions``
    using bearer-token authentication.
    """

    def __init__(self, api_base: str, api_key: str, model: str,
                 timeout: int = 30, max_tokens: int = 4000) -> None:
        """Store connection parameters.

        Args:
            api_base: Base URL of the API; trailing slashes are removed.
            api_key: Token sent in the ``Authorization: Bearer`` header.
            model: Model identifier sent in each request body.
            timeout: Value passed to ``httpx.post`` as ``timeout``.
            max_tokens: Value sent as ``max_tokens`` in each request body.
        """
        self.api_base = api_base.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.max_tokens = max_tokens

    @classmethod
    def from_settings(cls) -> "LlmClient":
        """Build a client from the ``llm_*`` values on the app settings.

        ``max_tokens`` is not read from settings and keeps its default.
        """
        return cls(
            api_base=settings.llm_api_base,
            api_key=settings.llm_api_key,
            model=settings.llm_model,
            timeout=settings.llm_timeout_seconds,
        )

    def available(self) -> bool:
        """Return True when both an API base URL and an API key are set."""
        return bool(self.api_base and self.api_key)

    def _post_chat(self, body: dict[str, Any]) -> "httpx.Response":
        """POST ``body`` to the chat-completions endpoint.

        Raises:
            httpx.HTTPStatusError: If the response status is an error
                (via ``raise_for_status``).
        """
        resp = httpx.post(
            f"{self.api_base}/chat/completions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=body,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp

    def chat_json(self, system: str, user: str) -> dict[str, Any]:
        """Run one chat completion and decode the reply as JSON.

        A JSON-only instruction is appended to the system prompt and
        ``response_format={"type": "json_object"}`` is requested. If the
        server answers 400 or 422, the request is retried once without
        ``response_format``.

        Args:
            system: System prompt (``None``/empty is treated as empty).
            user: User message.

        Returns:
            The value decoded by ``_extract_json`` from the first choice's
            message content. The prompt permits an array, so despite the
            annotation the result may be a list.

        Raises:
            RuntimeError: If the client is not configured.
            httpx.HTTPStatusError: For error statuses other than the retried
                400/422, or if the retry also fails.
            ValueError: If no valid JSON can be extracted; when the reply was
                cut off (``finish_reason == "length"``) the message says so.
        """
        if not self.available():
            raise RuntimeError("LLM not configured")
        sys_json = (system or "") + (
            "\n\n严格要求:只输出一个合法 JSON对象或数组"
            "不要 markdown 代码块、不要任何解释或前后缀文字。"
        )
        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": sys_json},
                {"role": "user", "content": user},
            ],
            "temperature": 0.1,
            "max_tokens": self.max_tokens,
            "response_format": {"type": "json_object"},
        }
        try:
            resp = self._post_chat(body)
        except httpx.HTTPStatusError as exc:
            # A few OpenAI-compatible gateways still do not accept
            # response_format. Retry once without it; other errors are
            # surfaced normally.
            if exc.response.status_code not in {400, 422}:
                raise
            body.pop("response_format", None)
            resp = self._post_chat(body)
        choice = resp.json()["choices"][0]
        finish_reason = choice.get("finish_reason", "")
        content = choice["message"]["content"]
        try:
            return _extract_json(content)
        except ValueError as exc:
            if finish_reason == "length":
                # Make the truncation cause explicit for whoever reads the log.
                raise ValueError(
                    f"[finish_reason=length] 输出被 max_tokens={self.max_tokens} 截断,"
                    f"模型={self.model},已尝试 JSON 修复但仍失败:{exc}"
                ) from exc
            raise

    def chat_text(self, system: str, user: str) -> str:
        """Run one chat completion and return the reply text unchanged.

        Args:
            system: System prompt.
            user: User message.

        Returns:
            The first choice's message content as returned by the server.

        Raises:
            RuntimeError: If the client is not configured.
            httpx.HTTPStatusError: If the response status is an error.
        """
        if not self.available():
            raise RuntimeError("LLM not configured")
        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "temperature": 0.3,
            "max_tokens": self.max_tokens,
        }
        resp = self._post_chat(body)
        return resp.json()["choices"][0]["message"]["content"]