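"""Multi-model knowledge-distillation enrichment: an independent data source, not a QA gate for Amap.

Responsibility: the Amap grid crawl quickly collects the "skeleton"; this module
asks several large language models for the knowledge they hold, distills the
cross-model consensus, aligns it with the data already in the graph, and safely
writes back the "knowledge layer" that Amap cannot provide
(summary / history / features / suitable audience / best season / ticket hint).

Privacy red line (hard user constraint): **latitude/longitude coordinates and
phone numbers** from Amap data are never sent to any LLM; name / address /
district / category may be sent as anchors (to confirm the entity and to
disambiguate places that share a name).
Distillation only writes "soft knowledge fields"; structured fields such as
phone / opening hours / coordinates always follow Amap and are never touched.

Distillation = a value is trusted when multiple models agree; it is additionally
aligned with the value already stored in the graph:
  existing empty + consensus -> adopt (write)  /  existing agrees -> keep (untouched)
  existing contradicts -> conflict (no overwrite, escalate to a human)  /  no consensus -> uncertain (not written)
"""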
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
|
||
from app.db import get_agent_settings
|
||
from app.llm_client import LlmClient
|
||
|
||
# Only these "soft knowledge" fields are owned by distillation; structured
# fields (phone / opening hours / coordinates) always follow Amap.
ATTR_FIELDS = ["summary", "history", "features",
               "suitable_for", "best_season", "ticket_hint"]

# System prompt for each distillation model: first decide whether the model's
# knowledge really refers to the place at the given address (entity_match),
# and only then fill in the soft fields. Runtime text - do not translate.
_ASK_SYS = """你是贵阳本地文旅知识专家。下面给你一个地点的【已知信息】
(名称 / 高德详细地址 / 区县 / 类别)。严格两步:
1) 先判断你掌握的知识是否就是"该地址指向的这个地点":贵阳若有多个同名、
或你无法确认是不是这个地址的那一个、或你根本不了解 → entity_match 置 false;
2) 仅当 entity_match=true,才仅凭可靠知识补"知识层"软信息;不知道的字段留空,
绝不编造,绝不臆改任何已知信息。
只输出 JSON:
{"entity_match": true|false,
"summary":"", "history":"", "features":"",
"suitable_for":"", "best_season":"", "ticket_hint":""}"""

# System prompt for the aggregation model: merge the confirmed answers field by
# field and classify each field as adopt / keep / conflict / uncertain relative
# to the value already stored in the graph. Runtime text - do not translate.
_AGG_SYS = """你是严谨的知识蒸馏与图谱对齐器。给你:
①【多模型资料】多个模型对同一地点(均已确认是该地址的地点)给出的软信息;
②【图谱现有软字段】该地点知识图谱里已存在的软字段值(可能来自上一次蒸馏)。
逐字段处理:
- 现有为空 且 ≥2 个模型相互印证 → adopt(给出综合后的准确简洁值)
- 现有已有 且 与多模型共识一致/兼容 → keep(保持不动)
- 现有已有 且 与共识实质矛盾 → conflict(给出 existing/distilled/简短理由,绝不擅改)
- 模型间无共识/孤证/都不知道 → uncertain
只输出 JSON:
{"adopt":{"字段":"值"},"keep":["字段"],
"conflict":[{"field":"","existing":"","distilled":"","note":""}],
"uncertain":["字段"],"confidence":0~1}"""
|
||
|
||
|
||
def _distill_clients(dcfg: dict) -> list[tuple[str, LlmClient]]:
    """Build one LLM client per usable distillation model.

    Args:
        dcfg: The distillation settings section. Keys read here are
            ``enabled`` (treated as on when absent), ``models`` (mapping of
            model key -> per-model settings) and ``timeout`` (seconds,
            default 45).

    Returns:
        ``(model_key, client)`` pairs for every model entry that is enabled
        and has both ``api_key`` and ``base_url``. Empty when distillation is
        switched off or no entry is usable.
    """
    out: list[tuple[str, LlmClient]] = []
    # The section comes from external configuration, so a null or malformed
    # value must degrade to "no clients" instead of raising.
    if not isinstance(dcfg, dict) or not dcfg.get("enabled", True):
        return out
    models = dcfg.get("models")
    if not isinstance(models, dict):
        return out

    # Loop-invariant: parse once. A non-numeric timeout falls back to the
    # default rather than aborting the whole distillation run.
    try:
        timeout = int(dcfg.get("timeout") or 45)
    except (TypeError, ValueError):
        timeout = 45

    for key, model_cfg in models.items():
        if not isinstance(model_cfg, dict):
            continue
        if (model_cfg.get("enabled") and model_cfg.get("api_key")
                and model_cfg.get("base_url")):
            out.append((key, LlmClient(
                api_base=model_cfg["base_url"], api_key=model_cfg["api_key"],
                model=model_cfg.get("model") or "",
                timeout=timeout)))
    return out
|
||
|
||
|
||
async def _ask_one(key: str, client: LlmClient, q: str) -> dict:
    """Ask a single distillation model about the anchor; never raises.

    Args:
        key: Identifier of the model, echoed back in the result.
        client: Client whose ``chat_json(system_prompt, user_text)`` is called.
        q: JSON-encoded anchor text sent as the user message.

    Returns:
        ``{"model": key, "data": <dict or None>}``. ``data`` is ``None`` when
        the reply is not a dict. On any exception the result additionally
        carries ``"error"``: the message truncated to 80 characters, or the
        exception class name when the message is empty.
    """
    try:
        # chat_json is run through asyncio.to_thread so the event loop is not
        # blocked while waiting (presumably it is a synchronous call).
        reply = await asyncio.to_thread(client.chat_json, _ASK_SYS, q)
    except Exception as e:  # noqa: BLE001 - best-effort: one bad model must not fail the batch
        # Some exceptions (e.g. a bare TimeoutError) stringify to "", which
        # would leave nothing to diagnose; fall back to the class name.
        return {"model": key, "data": None,
                "error": str(e)[:80] or type(e).__name__}
    return {"model": key, "data": reply if isinstance(reply, dict) else None}
|
||
|
||
|
||
# Template for an "ok, but nothing to write" result: nothing adopted and every
# soft field reported as uncertain.
_EMPTY = {"ok": True, "adopt": {}, "keep": [], "conflict": [],
          # A copy, so mutating a result's "uncertain" list can never alter
          # the ATTR_FIELDS whitelist itself.
          "uncertain": list(ATTR_FIELDS)}
|
||
|
||
|
||
def _normalize_aggregate(merged: dict, existing: dict) -> tuple[dict, list, list, list]:
    """Validate the aggregator's verdict and enforce the never-overwrite rule.

    Args:
        merged: JSON object returned by the aggregation model.
        existing: Non-empty soft-field values already stored in the graph.

    Returns:
        ``(adopt, keep, conflict, uncertain)``, each restricted to
        ``ATTR_FIELDS``. An ``adopt`` entry for a field that already has a
        value is moved to ``keep`` (same text) or ``conflict`` (different).
    """
    def _as_list(value: object) -> list:
        # The model may emit null or a scalar where a list is expected.
        return value if isinstance(value, list) else []

    keep = [k for k in _as_list(merged.get("keep")) if k in ATTR_FIELDS]
    conflict = [c for c in _as_list(merged.get("conflict"))
                if isinstance(c, dict) and c.get("field") in ATTR_FIELDS]
    uncertain = [u for u in _as_list(merged.get("uncertain"))
                 if u in ATTR_FIELDS]

    raw_adopt = merged.get("adopt")
    adopt: dict = {}
    for field, value in (raw_adopt.items() if isinstance(raw_adopt, dict) else ()):
        if field not in ATTR_FIELDS:
            continue
        # Only real text/numbers are usable: str(None) would otherwise be
        # adopted as the literal "None", and containers as their repr.
        if isinstance(value, bool) or not isinstance(value, (str, int, float)):
            continue
        text = str(value).strip()
        if not text:
            continue
        current = existing.get(field)
        if not current:
            adopt[field] = text
        elif str(current).strip() == text:
            # Same value as stored: nothing to write.
            if field not in keep:
                keep.append(field)
        elif all(c.get("field") != field for c in conflict):
            # The graph already holds a different value; never overwrite it
            # silently, hand it to a human instead.
            conflict.append({"field": field, "existing": current,
                             "distilled": text,
                             "note": "聚合结果试图覆盖既有值,已拦截待人工复核"})
    return adopt, keep, conflict, uncertain


async def distill_entity(entity: dict) -> dict:
    """Distill multi-model consensus for one place and align it with the graph.

    Args:
        entity: ``{name, place_type, district, address, existing}`` where
            ``existing`` maps soft fields to their current graph values.

    The anchor sent to the models contains only name / address / district /
    category (**never coordinates, never phone**).

    Returns:
        ``{ok, adopt, keep, conflict, uncertain, confidence, n, summary}``.
        When fewer than two models confirm the entity the result is ``ok=True``
        with nothing adopted and every field uncertain (no ``confidence``).
        Safe degradation: fewer than 2 enabled distillation models / no
        aggregation model configured / aggregation failure -> ``ok=False``
        with only ``summary`` (nothing is written; retry after configuration).
    """
    cfg = (await get_agent_settings()) or {}
    # "or {}" also covers a section that is present but null.
    clients = _distill_clients(cfg.get("distill") or {})
    if len(clients) < 2:
        return {"ok": False, "summary": f"启用蒸馏模型不足2个({len(clients)})"}

    name = entity.get("name")
    address = entity.get("address") or ""
    district = entity.get("district") or ""
    place_type = entity.get("place_type") or ""

    # Anchor sent to the models: built field by field from an explicit
    # whitelist, so coordinates and phone can never be included.
    q = json.dumps({
        "名称": name,
        "城市": "贵阳",
        "类别": place_type,
        "高德详细地址": address,
        "区县": district,
    }, ensure_ascii=False)

    answers = await asyncio.gather(*(_ask_one(k, c, q) for k, c in clients))
    # Keep only answers that explicitly confirm "this is the place at this
    # address" (guards against same-name ambiguity).
    confirmed = [a["data"] for a in answers
                 if a.get("data") and a["data"].get("entity_match") is True]
    if len(confirmed) < 2:
        # Fresh containers on every call so callers may mutate the result
        # without affecting later results or the ATTR_FIELDS whitelist.
        return {"ok": True, "adopt": {}, "keep": [], "conflict": [],
                "uncertain": list(ATTR_FIELDS), "n": len(confirmed),
                "summary": f"仅{len(confirmed)}个模型确认实体,跳过(留待其它途径)"}

    g = cfg.get("global") or {}
    if not g.get("api_key") or not g.get("base_url"):
        return {"ok": False, "summary": "未配全局聚合模型"}
    agg = LlmClient(api_base=g["base_url"], api_key=g["api_key"],
                    model=g.get("model") or "", timeout=int(g.get("timeout") or 90))

    # Whitelist the existing values too: whatever else the caller put into
    # "existing" (e.g. phone, coordinates) must not reach the aggregator.
    raw_existing = entity.get("existing")
    existing = ({k: v for k, v in raw_existing.items()
                 if k in ATTR_FIELDS and v}
                if isinstance(raw_existing, dict) else {})

    body = json.dumps({
        "地点": name,
        "锚点": {"地址": address,
                 "区县": district,
                 "类别": place_type},
        "多模型资料": [{k: v for k, v in c.items() if k != "entity_match"}
                       for c in confirmed],
        "图谱现有软字段": existing,
    }, ensure_ascii=False)
    try:
        merged = await asyncio.to_thread(agg.chat_json, _AGG_SYS, body)
    except Exception as e:  # noqa: BLE001 - any aggregator failure degrades to ok=False
        return {"ok": False, "summary": f"共识聚合失败:{str(e)[:60]}"}
    if not isinstance(merged, dict):
        # A non-object reply cannot be interpreted; degrade instead of
        # crashing on merged.get below.
        return {"ok": False, "summary": "共识聚合失败:聚合模型未返回JSON对象"}

    adopt, keep, conflict, uncertain = _normalize_aggregate(merged, existing)
    return {"ok": True, "adopt": adopt, "keep": keep, "conflict": conflict,
            "uncertain": uncertain, "confidence": merged.get("confidence"),
            "n": len(confirmed),
            "summary": f"{len(confirmed)}模型确认 → 采纳{len(adopt)}"
                       f"·一致{len(keep)}·矛盾{len(conflict)}·存疑{len(uncertain)}"}
|