# Source: bxh/app/agents/multi_extract.py (258 lines, 10 KiB, Python)

"""通用多模型知识抽取 —— 3 LLM 抽取 + 1 LLM 决策。
设计原则:
- **主 agent (opus / global) 不下场**, 只调度; 抽取走这里的独立模型池, 避免一处欠费全瘫
- **独立 API 池**(与 distill 分开): cfg["extract"]["models"] 自己一份
- **启用即用, 不启用即跳过**: 用户在系统设置 → 知识抽取 卡里勾选
- 抽取器 < 1 或没决策器 → 这个实体跳过, 不偷偷回退到全局 LLM
落地形态:
- cfg["extract"]["models"][key] = {label, enabled, base_url, api_key, model} ← API 真值
- cfg["extract"]["aggregator"] = "<model_key>" ← 决策器单选
- 抽取器=池中所有 enabled 且非 aggregator 的 model; 决策器=aggregator 指定那个
"""
from __future__ import annotations
import asyncio
import time
from app.llm_client import LlmClient
# ---------------------------------------------------------------------------
# Output-token ceilings per model family (recorded by the original author as
# "2026 confirmed limits"; not re-verified here).
#
# Every pattern is tested as a substring of the lowercased model name and the
# FIRST hit wins, so a specific pattern must stay above the generic catch-all
# of its family (e.g. "glm-4.7" before "glm-4" before "glm").
# ---------------------------------------------------------------------------
_MODEL_MAX_OUTPUT: list[tuple[str, int]] = [
    # Doubao / Seed 2.0
    ("seed-2.0-pro", 131_072),   # 128 K
    ("seed-2-0-pro", 131_072),
    ("seed-2.0-lite", 32_768),   # 32 K
    ("seed-2-0-lite", 32_768),
    ("seed-2.0", 65_536),        # 64 K safe default for other Seed 2.0 variants
    ("seed-2-0", 65_536),
    ("doubao-1.5", 32_768),
    ("doubao", 16_384),          # conservative for unlabelled doubao endpoints

    # DeepSeek
    ("deepseek-v4", 384_000),    # DeepSeek V4 series: 384 K max output
    ("v4-pro", 384_000),
    ("v4-flash", 384_000),
    ("deepseek-v3", 16_384),
    ("deepseek-r1", 32_768),
    ("deepseek", 16_384),

    # Qwen / Tongyi
    ("qwen3.7", 32_768),
    ("qwen3-max", 32_768),
    ("qwen3-plus", 32_768),
    ("qwen3", 32_768),
    ("qwen2.5", 8_192),
    ("qwen-max", 8_192),
    ("qwen-plus", 8_192),
    ("qwen-turbo", 8_192),
    ("qwen", 8_192),

    # GLM / Zhipu
    ("glm-5", 131_072),          # GLM-5.1: 128 K
    ("glm-4.7", 131_072),        # GLM-4.7: 128 K
    ("glm-4.6", 32_768),
    ("glm-4-long", 4_096),       # old long-ctx variant: 4 K output cap
    ("glm-4", 4_096),            # classic GLM-4: 4 K hard cap
    ("glm", 4_096),
]

# Ceiling used when none of the patterns above occurs in the model name.
_DEFAULT_MAX_OUTPUT = 16_384
def _model_max_output_tokens(model: str) -> int:
    """Look up the output-token ceiling for *model*.

    The name is lowercased and each ``_MODEL_MAX_OUTPUT`` pattern is tried as
    a substring, in table order; the first pattern found decides the result.
    Substring matching lets partial names (the original author mentions
    Volcengine endpoint IDs) still resolve to a family default.

    Returns ``_DEFAULT_MAX_OUTPUT`` when nothing matches, which includes a
    ``None`` or empty *model*.
    """
    lowered = (model or "").lower()
    return next(
        (cap for needle, cap in _MODEL_MAX_OUTPUT if needle in lowered),
        _DEFAULT_MAX_OUTPUT,
    )
def _configured_max_output_tokens(api_cfg: dict, fallback: int) -> int:
    """Return the user-configured output-token limit, or *fallback*.

    The keys ``max_tokens``, ``output_tokens`` and ``tokens`` are consulted in
    that order; the first one holding a positive integer (or something
    ``int()`` can convert to one) wins.

    A value is ignored, and the next key tried, when it is:

    - ``None`` or an empty string (unset),
    - a boolean (``True`` would otherwise be read as a 1-token limit),
    - not convertible by ``int()``; this includes ``float("inf")``, which
      raises ``OverflowError`` rather than ``ValueError``,
    - zero or negative.

    Args:
        api_cfg: One model entry from the extract pool configuration.
        fallback: Limit to use when no key yields a usable value.
    """
    for key in ("max_tokens", "output_tokens", "tokens"):
        raw = api_cfg.get(key)
        if raw is None or raw == "" or isinstance(raw, bool):
            continue
        try:
            value = int(raw)
        except (TypeError, ValueError, OverflowError):
            continue
        if value > 0:
            return value
    return fallback
def _make_client(api_cfg: dict, timeout: int,
                 max_tokens: int | None = None) -> LlmClient | None:
    """Build an ``LlmClient`` from one model entry of the extract pool.

    Args:
        api_cfg: Model entry; ``api_key`` and ``base_url`` are required,
            ``model`` is optional and defaults to an empty string.
        timeout: Passed through to ``LlmClient`` unchanged.
        max_tokens: Default output-token ceiling. When ``None`` the ceiling
            is derived from the model name via ``_model_max_output_tokens``.

    Returns:
        The client, or ``None`` when *api_cfg* is empty or lacks a truthy
        ``api_key`` / ``base_url``.

    A positive limit configured inside *api_cfg* (see
    ``_configured_max_output_tokens``) takes precedence over the default
    ceiling described above.
    """
    if not api_cfg:
        return None
    api_key = api_cfg.get("api_key")
    base_url = api_cfg.get("base_url")
    if not (api_key and base_url):
        return None

    model_name = api_cfg.get("model") or ""
    if max_tokens is None:
        default_cap = _model_max_output_tokens(model_name)
    else:
        default_cap = max_tokens

    return LlmClient(
        api_base=base_url,
        api_key=api_key,
        model=model_name,
        timeout=timeout,
        max_tokens=_configured_max_output_tokens(api_cfg, default_cap),
    )
def _model_fingerprint(api_cfg: dict | None) -> tuple[str, str]:
    """Identify a model entry by ``(endpoint, model)`` without touching keys.

    Used to spot pool entries that would cast the same vote twice. The
    endpoint is ``base_url`` with trailing slashes removed, lowercased; the
    model is ``model`` stripped of surrounding whitespace, lowercased.
    Missing or falsy fields become ``""``; a missing or empty *api_cfg*
    yields ``("", "")``.
    """
    if not api_cfg:
        return ("", "")
    endpoint = str(api_cfg.get("base_url") or "")
    model_name = str(api_cfg.get("model") or "")
    return endpoint.rstrip("/").lower(), model_name.strip().lower()
def build_extract_pool(
    cfg: dict,
    *,
    skip_duplicate_aggregator: bool = True,
) -> tuple[list[tuple[str, LlmClient]], tuple[str, LlmClient] | None, str]:
    """Resolve ``(extractors, aggregator, status_msg)`` from agent settings.

    Args:
        cfg: Settings dict; only ``cfg["extract"]`` is read (``enabled``,
            ``models``, ``timeout``, ``aggregator``).
        skip_duplicate_aggregator: When true, an extractor whose
            ``_model_fingerprint`` equals the aggregator's non-empty
            fingerprint is left out, since it would amount to the same vote.

    Returns:
        extractors: ``[(key, client), ...]`` for every enabled pool entry
            that is not the aggregator and for which a client could be built.
        aggregator: ``(key, client)`` for the entry named by
            ``extract.aggregator``, or ``None`` if it is unset, absent from
            the pool, or no client could be built. The entry's own
            ``enabled`` flag is not consulted here.
        status_msg: One-line status string intended for summary output.
    """
    section = cfg.get("extract") or {}
    if not section.get("enabled", True):
        return [], None, "知识抽取(extract)已停用"

    models: dict[str, dict] = section.get("models") or {}
    timeout = int(section.get("timeout") or 90)
    agg_key = (section.get("aggregator") or "").strip()
    agg_fp = _model_fingerprint(models.get(agg_key))
    dedupe = skip_duplicate_aggregator and agg_fp != ("", "")

    skipped_duplicates: list[str] = []
    extractors: list[tuple[str, LlmClient]] = []
    for key, model_cfg in models.items():
        # The aggregator never doubles as an extractor, to avoid biasing the vote.
        if not model_cfg.get("enabled") or key == agg_key:
            continue
        if dedupe and _model_fingerprint(model_cfg) == agg_fp:
            # Same endpoint + same model as the aggregator: effectively one
            # vote, so skipping it cuts both waiting time and bias.
            skipped_duplicates.append(key)
            continue
        # No explicit max_tokens: _make_client picks the ceiling from the model name.
        extractor_client = _make_client(model_cfg, timeout)
        if extractor_client:
            extractors.append((key, extractor_client))

    agg: tuple[str, LlmClient] | None = None
    if agg_key and agg_key in models:
        # The aggregator merges several outputs and tends to produce longer
        # answers, so it also keeps the model's own ceiling.
        agg_client = _make_client(models[agg_key], timeout)
        if agg_client:
            agg = (agg_key, agg_client)

    extractor_names = ",".join(name for name, _ in extractors)
    agg_name = agg[0] if agg else ""
    msg = f"抽取器={len(extractors)}({extractor_names}) · 决策器={agg_name}"
    if skipped_duplicates:
        msg += f" · 已跳过重复模型({','.join(skipped_duplicates)})"
    return extractors, agg, msg
async def fan_out(system: str, user: str,
                  extractors: list[tuple[str, LlmClient]],
                  min_valid: int | None = None,
                  max_wait_seconds: int | None = None) -> list[dict]:
    """Ask every extractor for JSON concurrently; one failure never blocks the rest.

    Each extractor's blocking ``chat_json(system, user)`` runs in a worker
    thread via ``asyncio.to_thread``.

    Args:
        system: System prompt forwarded to every extractor.
        user: User prompt forwarded to every extractor.
        extractors: ``[(key, client), ...]``.
        min_valid: Stop waiting once this many extractors returned a dict.
            Falsy means "all of them"; the value is clamped to
            ``1..len(extractors)``.
        max_wait_seconds: Overall deadline for the collection loop. Falsy
            means no deadline.

    Returns:
        One dict per extractor with ``model``, ``data`` (the returned dict, or
        ``None``) and ``seconds``; failures add ``error``. When neither limit
        is given, results follow the order of *extractors*. Otherwise finished
        results come first in completion order, followed by one entry per
        unfinished extractor whose ``error`` is ``"skipped_after_quorum"``
        (``skipped`` true) or ``"timeout_waiting_for_model"`` (``skipped``
        false).

    Note:
        Cancelling an unfinished task only abandons the wait; the worker
        thread running ``chat_json`` is not interrupted.
    """
    fan_started = time.perf_counter()

    def _reason(exc: BaseException) -> str:
        # Some exceptions carry no message; fall back to the class name so
        # the "error" field is never an empty (falsy) string.
        return (str(exc) or type(exc).__name__)[:120]

    async def _one(k: str, c: LlmClient) -> dict:
        started = time.perf_counter()
        try:
            r = await asyncio.to_thread(c.chat_json, system, user)
            return {
                "model": k,
                "data": r if isinstance(r, dict) else None,
                "seconds": round(time.perf_counter() - started, 2),
            }
        except Exception as e:  # noqa: BLE001
            return {
                "model": k,
                "data": None,
                "error": _reason(e),
                "seconds": round(time.perf_counter() - started, 2),
            }

    if not extractors:
        return []

    # No quorum and no deadline: simply wait for everybody.
    if min_valid is None and max_wait_seconds is None:
        return list(await asyncio.gather(*[_one(k, c) for k, c in extractors]))

    target_valid = max(1, min(int(min_valid or len(extractors)), len(extractors)))
    deadline = time.monotonic() + max_wait_seconds if max_wait_seconds else None
    pending = {
        asyncio.create_task(_one(k, c)): k
        for k, c in extractors
    }
    results: list[dict] = []

    def _valid_count() -> int:
        return sum(1 for r in results if isinstance(r.get("data"), dict))

    try:
        while pending:
            timeout = None
            if deadline is not None:
                timeout = max(0.0, deadline - time.monotonic())
                if timeout <= 0:
                    break
            done, _ = await asyncio.wait(
                pending.keys(),
                timeout=timeout,
                return_when=asyncio.FIRST_COMPLETED,
            )
            if not done:
                break  # deadline hit with nothing new finished
            for task in done:
                model_key = pending.pop(task)
                try:
                    results.append(task.result())
                except asyncio.CancelledError:
                    continue
                except Exception as e:  # noqa: BLE001
                    # _one already converts failures into result dicts, so
                    # this is purely defensive; still report the real model.
                    results.append({
                        "model": model_key,
                        "data": None,
                        "error": _reason(e),
                        "seconds": round(time.perf_counter() - fan_started, 2),
                    })
            if _valid_count() >= target_valid:
                break
    except BaseException:
        # fan_out itself was cancelled or failed while waiting: do not leave
        # the child tasks running behind the caller's back.
        for task in pending:
            task.cancel()
        raise

    enough = _valid_count() >= target_valid
    for task, model_key in list(pending.items()):
        task.cancel()
        results.append({
            "model": model_key,
            "data": None,
            "error": "skipped_after_quorum" if enough else "timeout_waiting_for_model",
            "seconds": round(time.perf_counter() - fan_started, 2),
            "skipped": bool(enough),
        })
    return results
async def decide(system: str, user: str,
                 agg: tuple[str, LlmClient],
                 max_wait_seconds: int | None = None) -> tuple[dict | None, str]:
    """Run the aggregator once and return ``(result, err_msg)``.

    The aggregator's blocking ``chat_json(system, user)`` runs in a worker
    thread via ``asyncio.to_thread``.

    Args:
        system: System prompt for the aggregator.
        user: User prompt for the aggregator.
        agg: ``(key, client)``; only the client is used.
        max_wait_seconds: Deadline for the call. Falsy means wait without limit.

    Returns:
        ``(dict, "")`` on success. On failure ``(None, reason)`` where the
        reason is never empty: the deadline message, a note that the reply
        was not a dict, or the exception text (class name if the exception
        has no message), truncated to 200 characters.

    Note:
        On a deadline miss the wait is abandoned, but the worker thread
        running ``chat_json`` is not interrupted.
    """
    _, client = agg
    limit = max_wait_seconds if max_wait_seconds else None
    task = None
    try:
        task = asyncio.create_task(
            asyncio.to_thread(client.chat_json, system, user)
        )
        # asyncio.wait signals a deadline miss through an empty ``done`` set
        # rather than by raising. That keeps our own deadline apart from a
        # TimeoutError raised by the client itself, which on Python 3.11+ is
        # the same class as asyncio.TimeoutError.
        done, _ = await asyncio.wait({task}, timeout=limit)
        if not done:
            task.cancel()
            return None, f"决策器超过 {max_wait_seconds} 秒未返回"
        r = task.result()
    except asyncio.CancelledError:
        # We are being cancelled ourselves: stop waiting on the child too.
        if task is not None:
            task.cancel()
        raise
    except Exception as e:  # noqa: BLE001
        return None, (str(e) or type(e).__name__)[:200]
    if isinstance(r, dict):
        return r, ""
    return None, f"返回非 JSON dict (type={type(r).__name__})"