"""通用多模型知识抽取 —— 3 LLM 抽取 + 1 LLM 决策。 设计原则: - **主 agent (opus / global) 不下场**, 只调度; 抽取走这里的独立模型池, 避免一处欠费全瘫 - **独立 API 池**(与 distill 分开): cfg["extract"]["models"] 自己一份 - **启用即用, 不启用即跳过**: 用户在系统设置 → 知识抽取 卡里勾选 - 抽取器 < 1 或没决策器 → 这个实体跳过, 不偷偷回退到全局 LLM 落地形态: - cfg["extract"]["models"][key] = {label, enabled, base_url, api_key, model} ← API 真值 - cfg["extract"]["aggregator"] = "" ← 决策器单选 - 抽取器=池中所有 enabled 且非 aggregator 的 model; 决策器=aggregator 指定那个 """ from __future__ import annotations import asyncio import time from app.llm_client import LlmClient # --------------------------------------------------------------------------- # Per-model output-token caps (2026 confirmed limits). # Matched as substrings of the lowercased model name; first match wins. # --------------------------------------------------------------------------- _MODEL_MAX_OUTPUT: list[tuple[str, int]] = [ # Doubao / Seed 2.0 ------------------------------------------------------- ("seed-2.0-pro", 131072), # 128 K ("seed-2-0-pro", 131072), ("seed-2.0-lite", 32768), # 32 K ("seed-2-0-lite", 32768), ("seed-2.0", 65536), # 64 K safe default for other Seed 2.0 variants ("seed-2-0", 65536), ("doubao-1.5", 32768), ("doubao", 16384), # conservative for unlabelled doubao endpoints # DeepSeek ---------------------------------------------------------------- ("deepseek-v4", 384000), # DeepSeek V4 series: 384 K max output ("v4-pro", 384000), ("v4-flash", 384000), ("deepseek-v3", 16384), ("deepseek-r1", 32768), ("deepseek", 16384), # Qwen / 通义 --------------------------------------------------------------- ("qwen3.7", 32768), ("qwen3-max", 32768), ("qwen3-plus", 32768), ("qwen3", 32768), ("qwen2.5", 8192), ("qwen-max", 8192), ("qwen-plus", 8192), ("qwen-turbo", 8192), ("qwen", 8192), # GLM / 智谱 ---------------------------------------------------------------- ("glm-5", 131072), # GLM-5.1: 128 K ("glm-4.7", 131072), # GLM-4.7: 128 K ("glm-4.6", 32768), ("glm-4-long", 4096), # old long-ctx variant: 4 K output cap ("glm-4", 4096), # classic GLM-4: 4 K hard cap ("glm", 4096), ] _DEFAULT_MAX_OUTPUT = 16384 # fallback for unknown/unlabelled models def _model_max_output_tokens(model: str) -> int: """Return the known safe max output tokens for *model*. Matches on lowercased substrings so partial names like Volcengine endpoint IDs still resolve to a reasonable default rather than 4 K. """ m = (model or "").lower() for pattern, limit in _MODEL_MAX_OUTPUT: if pattern in m: return limit return _DEFAULT_MAX_OUTPUT def _configured_max_output_tokens(api_cfg: dict, fallback: int) -> int: """Resolve user-configured max_tokens, falling back to model defaults.""" for key in ("max_tokens", "output_tokens", "tokens"): raw = api_cfg.get(key) if raw in (None, ""): continue try: value = int(raw) except (TypeError, ValueError): continue if value > 0: return value return fallback def _make_client(api_cfg: dict, timeout: int, max_tokens: int | None = None) -> LlmClient | None: if not api_cfg or not api_cfg.get("api_key") or not api_cfg.get("base_url"): return None model = api_cfg.get("model") or "" fallback_max = max_tokens if max_tokens is not None else _model_max_output_tokens(model) resolved_max = _configured_max_output_tokens(api_cfg, fallback_max) return LlmClient( api_base=api_cfg["base_url"], api_key=api_cfg["api_key"], model=model, timeout=timeout, max_tokens=resolved_max, ) def _model_fingerprint(api_cfg: dict | None) -> tuple[str, str]: """Use provider endpoint + model to avoid duplicate votes without exposing keys.""" if not api_cfg: return ("", "") return ( str(api_cfg.get("base_url") or "").rstrip("/").lower(), str(api_cfg.get("model") or "").strip().lower(), ) def build_extract_pool( cfg: dict, *, skip_duplicate_aggregator: bool = True, ) -> tuple[list[tuple[str, LlmClient]], tuple[str, LlmClient] | None, str]: """从 agent_settings 解出 (extractors, aggregator, status_msg)。 extractors: 启用且能建客户端的抽取器列表 [(key, client), ...] aggregator: (key, client) 或 None status_msg: 一句话状态(用于 summary 回显) """ ex = cfg.get("extract") or {} if not ex.get("enabled", True): return [], None, "知识抽取(extract)已停用" pool: dict[str, dict] = ex.get("models") or {} timeout = int(ex.get("timeout") or 90) agg_key = (ex.get("aggregator") or "").strip() agg_fp = _model_fingerprint(pool.get(agg_key)) skipped_duplicates: list[str] = [] extractors: list[tuple[str, LlmClient]] = [] for k, m in pool.items(): if not m.get("enabled"): continue if k == agg_key: # 决策器不同时扮抽取器, 避免投票偏置 continue if skip_duplicate_aggregator and agg_fp != ("", "") and _model_fingerprint(m) == agg_fp: # 同 provider + 同 model 的抽取器和决策器本质上是一票, 跳过可减少等待和偏置。 skipped_duplicates.append(k) continue # max_tokens=None → _make_client 自动按模型名识别上限 client = _make_client(m, timeout) if client: extractors.append((k, client)) agg: tuple[str, LlmClient] | None = None if agg_key and agg_key in pool: # 决策器合并多家结果,输出通常比单次抽取更长,保持模型自身上限 client = _make_client(pool[agg_key], timeout) if client: agg = (agg_key, client) msg = (f"抽取器={len(extractors)}({','.join(k for k, _ in extractors)})" f" · 决策器={agg[0] if agg else '无'}") if skipped_duplicates: msg += f" · 已跳过重复模型({','.join(skipped_duplicates)})" return extractors, agg, msg async def fan_out(system: str, user: str, extractors: list[tuple[str, LlmClient]], min_valid: int | None = None, max_wait_seconds: int | None = None) -> list[dict]: """并行让每个抽取器输出 JSON; 单个失败不阻断。返回 [{model, data, error?}]。""" fan_started = time.perf_counter() async def _one(k: str, c: LlmClient) -> dict: started = time.perf_counter() try: r = await asyncio.to_thread(c.chat_json, system, user) return { "model": k, "data": r if isinstance(r, dict) else None, "seconds": round(time.perf_counter() - started, 2), } except Exception as e: # noqa: BLE001 return { "model": k, "data": None, "error": str(e)[:120], "seconds": round(time.perf_counter() - started, 2), } if not extractors: return [] if min_valid is None and max_wait_seconds is None: return list(await asyncio.gather(*[_one(k, c) for k, c in extractors])) target_valid = max(1, min(int(min_valid or len(extractors)), len(extractors))) deadline = time.monotonic() + max_wait_seconds if max_wait_seconds else None pending = { asyncio.create_task(_one(k, c)): k for k, c in extractors } results: list[dict] = [] while pending: timeout = None if deadline is not None: timeout = max(0.0, deadline - time.monotonic()) if timeout <= 0: break done, _ = await asyncio.wait( pending.keys(), timeout=timeout, return_when=asyncio.FIRST_COMPLETED, ) if not done: break for task in done: pending.pop(task, None) try: results.append(task.result()) except asyncio.CancelledError: continue except Exception as e: # noqa: BLE001 results.append({"model": "unknown", "data": None, "error": str(e)[:120]}) valid_count = sum(1 for r in results if isinstance(r.get("data"), dict)) if valid_count >= target_valid: break enough = sum(1 for r in results if isinstance(r.get("data"), dict)) >= target_valid for task, model_key in list(pending.items()): task.cancel() results.append({ "model": model_key, "data": None, "error": "skipped_after_quorum" if enough else "timeout_waiting_for_model", "seconds": round(time.perf_counter() - fan_started, 2), "skipped": bool(enough), }) return results async def decide(system: str, user: str, agg: tuple[str, LlmClient], max_wait_seconds: int | None = None) -> tuple[dict | None, str]: """决策器单次裁决, 返回 (result, err_msg)。失败时 result=None, err_msg 含原因。""" _, c = agg try: call = asyncio.to_thread(c.chat_json, system, user) r = await asyncio.wait_for(call, timeout=max_wait_seconds) if max_wait_seconds else await call if isinstance(r, dict): return r, "" return None, f"返回非 JSON dict (type={type(r).__name__})" except asyncio.TimeoutError: return None, f"决策器超过 {max_wait_seconds} 秒未返回" except Exception as e: # noqa: BLE001 return None, str(e)[:200]