Files
bxh/app/agents/distill_gate.py

139 lines
6.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""多模型知识蒸馏富集 —— 独立数据来源,不是高德的质检闸门。
职责:高德网格法负责"快采骨架";本模块向多个大模型问其脑内知识,
跨模型蒸馏共识,再与图谱既有数据对齐,把高德给不了的"知识层"
(简介/历史/特色/适合人群/最佳季节/门票提示)安全写回。
隐私红线(用户硬约束):高德数据里**经纬度坐标、电话**绝不外发给任何大模型;
名称/地址/区县/类别可作为锚点发出(用于确认实体、消同名歧义)。
蒸馏只写"软知识字段",电话/营业时间/坐标等结构化字段永远以高德为准、不碰。
蒸馏 = 多模型一致即可信;并且与图谱既有值对齐:
既有为空+共识 → adopt 写 / 既有一致 → keep 不动
既有矛盾 → conflict 不覆盖、转人工 / 无共识 → uncertain 不写
"""
from __future__ import annotations
import asyncio
import json
from app.db import get_agent_settings
from app.llm_client import LlmClient
# Soft "knowledge-layer" fields owned by distillation. Structured fields
# (phone / opening hours / coordinates) always follow Amap and are never
# touched by this module.
ATTR_FIELDS = [
    "summary",
    "history",
    "features",
    "suitable_for",
    "best_season",
    "ticket_hint",
]
# System prompt for each distill model. It is passed to LlmClient.chat_json
# together with the JSON anchor built in distill_entity (name / city /
# category / Amap address / district -- never coordinates or phone).
# Step 1 asks the model to decide whether its knowledge really refers to the
# place at that address (entity_match); step 2 asks it to fill only the soft
# fields it reliably knows and leave the rest empty. distill_entity keeps
# only answers whose entity_match is exactly True.
# NOTE: the exact text is the prompt -- any edit changes model behavior.
_ASK_SYS = """你是贵阳本地文旅知识专家。下面给你一个地点的【已知信息】
(名称 / 高德详细地址 / 区县 / 类别)。严格两步:
1) 先判断你掌握的知识是否就是"该地址指向的这个地点":贵阳若有多个同名、
或你无法确认是不是这个地址的那一个、或你根本不了解 → entity_match 置 false
2) 仅当 entity_match=true才仅凭可靠知识补"知识层"软信息;不知道的字段留空,
绝不编造,绝不臆改任何已知信息。
只输出 JSON
{"entity_match": true|false,
"summary":"", "history":"", "features":"",
"suitable_for":"", "best_season":"", "ticket_hint":""}"""
_AGG_SYS = """你是严谨的知识蒸馏与图谱对齐器。给你:
①【多模型资料】多个模型对同一地点(均已确认是该地址的地点)给出的软信息;
②【图谱现有软字段】该地点知识图谱里已存在的软字段值(可能来自上一次蒸馏)。
逐字段处理:
- 现有为空 且 ≥2 个模型相互印证 → adopt给出综合后的准确简洁值
- 现有已有 且 与多模型共识一致/兼容 → keep保持不动
- 现有已有 且 与共识实质矛盾 → conflict给出 existing/distilled/简短理由,绝不擅改)
- 模型间无共识/孤证/都不知道 → uncertain
只输出 JSON
{"adopt":{"字段":""},"keep":["字段"],
"conflict":[{"field":"","existing":"","distilled":"","note":""}],
"uncertain":["字段"],"confidence":0~1}"""
def _distill_clients(dcfg: dict) -> list[tuple[str, LlmClient]]:
out: list[tuple[str, LlmClient]] = []
if not dcfg.get("enabled", True):
return out
for k, m in (dcfg.get("models") or {}).items():
if m.get("enabled") and m.get("api_key") and m.get("base_url"):
out.append((k, LlmClient(
api_base=m["base_url"], api_key=m["api_key"],
model=m.get("model") or "",
timeout=int(dcfg.get("timeout") or 45))))
return out
async def _ask_one(key: str, client: LlmClient, q: str) -> dict:
    """Query one distill model with the anchor JSON ``q``; never raises.

    The blocking ``client.chat_json`` call runs in a worker thread. Returns
    ``{"model": key, "data": reply}`` where ``reply`` is kept only if it is a
    dict (otherwise ``None``). On any exception the result is
    ``{"model": key, "data": None, "error": <first 80 chars of the message>}``,
    so one failing model cannot abort the ``asyncio.gather`` in distill_entity.
    """
    try:
        reply = await asyncio.to_thread(client.chat_json, _ASK_SYS, q)
    except Exception as exc:  # noqa: BLE001
        return {"model": key, "data": None, "error": str(exc)[:80]}
    payload = reply if isinstance(reply, dict) else None
    return {"model": key, "data": payload}
# Result template for "fewer than 2 models confirmed the entity": nothing is
# adopted and every soft field is reported as uncertain. The field list is
# copied so that this template never aliases the module-wide ATTR_FIELDS
# whitelist; a caller mutating result["uncertain"] must not change which
# fields the module accepts. (Spreading it with {**_EMPTY} is still a shallow
# copy, so code that spreads it should build its own containers.)
_EMPTY = {"ok": True, "adopt": {}, "keep": [], "conflict": [],
          "uncertain": list(ATTR_FIELDS)}
def _as_list(value: object) -> list:
    """Return ``value`` if it is a list, else [] (guards malformed LLM output)."""
    return value if isinstance(value, list) else []


def _sanitize_merged(merged: dict, existing: dict) -> tuple[dict, list, list, list]:
    """Whitelist the aggregator verdict and forbid overwriting existing values.

    ``merged`` is the aggregator's JSON object; ``existing`` maps soft fields
    to their current graph values. Returns ``(adopt, keep, conflict,
    uncertain)`` restricted to ATTR_FIELDS. A field whose current value is
    non-blank is never returned in ``adopt``: an identical proposal becomes
    ``keep``, a different one becomes a ``conflict`` for human review.
    """
    raw_adopt = merged.get("adopt")
    adopt: dict[str, str] = {}
    if isinstance(raw_adopt, dict):
        for field, value in raw_adopt.items():
            # JSON null must not be adopted as the literal text "None".
            text = "" if value is None else str(value).strip()
            if field in ATTR_FIELDS and text:
                adopt[field] = text
    keep = [k for k in _as_list(merged.get("keep")) if k in ATTR_FIELDS]
    conflict = [c for c in _as_list(merged.get("conflict"))
                if isinstance(c, dict) and c.get("field") in ATTR_FIELDS]
    uncertain = [u for u in _as_list(merged.get("uncertain")) if u in ATTR_FIELDS]

    # The prompt only allows adopt for empty graph fields; enforce it in code
    # so a misbehaving aggregator can never overwrite existing graph data.
    for field in [f for f in adopt if str(existing.get(f) or "").strip()]:
        distilled = adopt.pop(field)
        current = str(existing[field]).strip()
        if current == distilled:
            if field not in keep:
                keep.append(field)
        elif not any(c.get("field") == field for c in conflict):
            conflict.append({"field": field, "existing": current,
                             "distilled": distilled,
                             "note": "聚合结果欲覆盖已有值,按规则转人工"})
    return adopt, keep, conflict, uncertain


async def distill_entity(entity: dict) -> dict:
    """Distill soft knowledge for one place and align it with the graph.

    ``entity``: {name, place_type, district, address, existing: {soft field:
    current value}}. Only name / address / district / category are sent to
    the LLMs -- coordinates and phone numbers are never sent.

    Returns {ok, adopt, keep, conflict, uncertain, confidence, n, summary}
    on success. Degradation paths:
      * fewer than 2 enabled distill models, no global aggregator configured,
        aggregation raised, or the aggregator did not return a JSON object
        -> {"ok": False, "summary": ...}: write nothing, retry after the
        configuration is fixed;
      * fewer than 2 models confirm the entity -> ok=True with every soft
        field in ``uncertain`` (plus ``n``/``summary``): skipped.
    ``adopt`` never contains a field whose existing value is non-blank.
    """
    cfg = await get_agent_settings()
    clients = _distill_clients(cfg.get("distill") or {})
    if len(clients) < 2:
        return {"ok": False,
                "summary": f"启用蒸馏模型不足2个(当前{len(clients)}个)"}
    # Check the aggregator before fanning out: without it nothing can be
    # written, so querying the distill models would only waste paid calls.
    g = cfg.get("global") or {}
    if not g.get("api_key") or not g.get("base_url"):
        return {"ok": False, "summary": "未配全局聚合模型"}

    # Anchors sent to the models: never coordinates or phone numbers.
    q = json.dumps({
        "名称": entity.get("name"),
        "城市": "贵阳",
        "类别": entity.get("place_type") or "",
        "高德详细地址": entity.get("address") or "",
        "区县": entity.get("district") or "",
    }, ensure_ascii=False)
    answers = await asyncio.gather(*(_ask_one(k, c, q) for k, c in clients))
    # Keep only answers that explicitly confirm this exact address
    # (guards against same-name places).
    confirmed = [a["data"] for a in answers
                 if a.get("data") and a["data"].get("entity_match") is True]
    n = len(confirmed)
    if n < 2:
        # {**_EMPTY} is a shallow copy: give every result its own containers
        # so callers cannot mutate shared or module-level state.
        return {**_EMPTY, "adopt": {}, "keep": [], "conflict": [],
                "uncertain": list(ATTR_FIELDS), "n": n,
                "summary": f"{n}个模型确认实体,跳过(留待其它途径)"}

    agg = LlmClient(api_base=g["base_url"], api_key=g["api_key"],
                    model=g.get("model") or "",
                    timeout=int(g.get("timeout") or 90))
    existing = {k: v for k, v in (entity.get("existing") or {}).items() if v}
    body = json.dumps({
        "地点": entity.get("name"),
        "锚点": {"地址": entity.get("address") or "",
                 "区县": entity.get("district") or "",
                 "类别": entity.get("place_type") or ""},
        # Forward only whitelisted soft fields, not extra keys a model invented.
        "多模型资料": [{k: v for k, v in c.items() if k in ATTR_FIELDS}
                       for c in confirmed],
        "图谱现有软字段": existing,
    }, ensure_ascii=False)
    try:
        merged = await asyncio.to_thread(agg.chat_json, _AGG_SYS, body)
    except Exception as e:  # noqa: BLE001
        return {"ok": False, "summary": f"共识聚合失败:{str(e)[:60]}"}
    if not isinstance(merged, dict):
        return {"ok": False, "summary": "共识聚合失败:返回的不是 JSON 对象"}

    adopt, keep, conflict, uncertain = _sanitize_merged(merged, existing)
    return {"ok": True, "adopt": adopt, "keep": keep, "conflict": conflict,
            "uncertain": uncertain, "confidence": merged.get("confidence"),
            "n": n,
            "summary": f"{n}模型确认 → 采纳{len(adopt)}"
                       f"·一致{len(keep)}·矛盾{len(conflict)}·存疑{len(uncertain)}"}