# Source: bxh/app/api/super_agent.py (248 lines, 9.9 KiB, Python), captured from a
# Git web view ("Files" / "Raw · Permalink · Blame · History").
# NOTE: the web view flagged this file as containing ambiguous Unicode characters,
# i.e. characters that may be confused with similar-looking ones (presumably the
# Chinese text and its punctuation); review before normalizing any of them.

"""Super Agent HTTP API.

Endpoints to preview and ingest 高德 (Gaode/AMap) POIs, to start, inspect and
stop autonomous super-agent runs, and to report collection coverage (preferring
the H3 spatial collector's tables when they exist).
"""
import asyncio
import logging

from fastapi import APIRouter, HTTPException

from app.auth import CurrentUser
from app.agents.gaode_connector import available_types, search_pois
from app.agents.super_ingest import ingest_gaode
from app.agents.super_orchestrator import (
    schedule_super_agent, _coverage, _HARD_MAX_STEPS,
)
from app.db import (
    sa_create_run, sa_get_run, sa_latest_run, sa_request_stop, sa_list_tasks,
    grid_counts, get_conn,
)
# Router that every /super-agent/* endpoint in this module registers on.
router = APIRouter()
# Graph whose rows in the H3 spatial collector tables are reported on.
SPATIAL_GRAPH_NAME = "guiyang_spatial_v1"

# Preferred display order of POI type labels in the coverage report.
SPATIAL_TYPE_ORDER = [
    # Everyday destinations.
    "景点",
    "美食",
    "酒店",
    "商场",
    "医疗保健",
    "交通设施",
    "生活服务",
    # Institutions, venues and businesses.
    "科教文化",
    "政府机构",
    "公共设施",
    "体育休闲",
    "商务住宅",
    "公司企业",
    "金融保险",
    # Vehicle-related services.
    "汽车服务",
    "汽车维修",
    "汽车销售",
    "摩托车服务",
    # Addresses and road features.
    "地名地址",
    "道路附属",
]

# English place_type slug per type label, used when collector rows carry no
# place_type of their own.
SPATIAL_TYPE_FALLBACK = {
    "景点": "sight",
    "美食": "eat",
    "酒店": "hotel",
    "商场": "mall",
    "医疗保健": "medical",
    "交通设施": "transit",
    "生活服务": "life",
    "科教文化": "education",
    "政府机构": "government",
    "公共设施": "facility",
    "商务住宅": "residential",
    "公司企业": "enterprise",
}
# Label used for POIs / tasks whose type_label is NULL or empty.
_UNLABELLED_TYPE = "未分类"

# Per-label grid counters selected by the resolution-6 task query, in the
# order they appear in each coverage item.
_GRID_FIELDS = (
    "grid_total",
    "grid_done",
    "grid_pending",
    "grid_running",
    "grid_error",
    "grid_quota_limited",
)


def _build_spatial_coverage(
    counts: list[dict],
    grid_rows: dict[str, dict],
    status_counts: dict[str, int],
) -> dict:
    """Aggregate raw H3 collector rows into the coverage payload.

    Args:
        counts: POI count rows with ``type_label``, ``place_type`` and ``cnt``.
        grid_rows: resolution-6 grid counters (the ``_GRID_FIELDS`` columns),
            keyed by type label with NULL labels already mapped to
            ``_UNLABELLED_TYPE``.
        status_counts: number of collect tasks per status, across all
            resolutions.

    Returns:
        A dict with per-label ``items`` (ordered by ``SPATIAL_TYPE_ORDER``, then
        by descending POI count, then by label), the POI ``total``, the overall
        grid completion percentage, and task status summaries.
    """
    count_by_label: dict[str, dict] = {}
    for r in counts:
        label = r["type_label"] or _UNLABELLED_TYPE
        item = count_by_label.setdefault(
            label, {"cat": label, "place_type": None, "current": 0}
        )
        # GROUP BY rows arrive in arbitrary order: keep the first real
        # place_type rather than letting an earlier NULL row pin the fallback.
        if item["place_type"] is None and r["place_type"]:
            item["place_type"] = r["place_type"]
        item["current"] += int(r["cnt"] or 0)

    order_index = {label: i for i, label in enumerate(SPATIAL_TYPE_ORDER)}
    items: list[dict] = []
    for label, item in count_by_label.items():
        if item["place_type"] is None:
            item["place_type"] = SPATIAL_TYPE_FALLBACK.get(label, "poi")
        grid = grid_rows.get(label, {})
        stats = {field: int(grid.get(field) or 0) for field in _GRID_FIELDS}
        grid_total = stats["grid_total"]
        stats["grid_pct"] = round(stats["grid_done"] / grid_total * 100) if grid_total else 0
        item.update(stats)
        items.append(item)
    # Known labels first in their configured order; unknown labels (999) last.
    items.sort(key=lambda x: (order_index.get(x["cat"], 999), -x["current"], x["cat"]))

    grid_total_sum = sum(i["grid_total"] for i in items)
    grid_done_sum = sum(i["grid_done"] for i in items)
    return {
        "source": "amap_spatial_h3",
        "graph_name": SPATIAL_GRAPH_NAME,
        "items": items,
        "total": sum(i["current"] for i in items),
        "grid_overall_pct": round(grid_done_sum / grid_total_sum * 100) if grid_total_sum else 0,
        "status_counts": status_counts,
        "error_tasks": int(status_counts.get("error", 0)),
        "quota_limited": int(status_counts.get("quota_limited", 0)),
    }


async def _spatial_h3_coverage() -> dict | None:
    """Return coverage for the new H3 spatial collector, or ``None`` to fall back.

    Reads ``{schema}.amap_spatial_pois`` and
    ``{schema}.amap_spatial_collect_tasks`` for ``SPATIAL_GRAPH_NAME``.

    Returns:
        The payload built by ``_build_spatial_coverage``, or ``None`` when
        either table is missing, the graph has no POIs yet, or any query or
        aggregation step fails. Failures are logged with a traceback instead of
        being silently swallowed, so the caller's fallback no longer hides a
        broken collector.
    """
    from app.config import settings

    s = settings.db_schema
    try:
        async with get_conn() as conn:
            async with conn.cursor() as cur:
                # to_regclass() yields NULL (instead of raising) for a missing table.
                await cur.execute(
                    "SELECT to_regclass(%s) AS pois_table, to_regclass(%s) AS tasks_table",
                    (f"{s}.amap_spatial_pois", f"{s}.amap_spatial_collect_tasks"),
                )
                tables = await cur.fetchone()
                if not tables or not tables["pois_table"] or not tables["tasks_table"]:
                    return None
                # The schema name comes from settings, not from the request, so
                # it is interpolated into the SQL text; values stay parameterized.
                await cur.execute(
                    f"""SELECT type_label, place_type, COUNT(*) AS cnt
                        FROM {s}.amap_spatial_pois
                        WHERE graph_name=%s
                        GROUP BY type_label, place_type""",
                    (SPATIAL_GRAPH_NAME,),
                )
                counts = await cur.fetchall()
                if not counts:
                    return None
                # Per-label grid progress, counted on resolution-6 cells only.
                await cur.execute(
                    f"""SELECT type_label,
                            COUNT(*) FILTER (WHERE resolution=6) AS grid_total,
                            COUNT(*) FILTER (
                                WHERE resolution=6
                                AND status IN ('done','saturated','saturated_max_res')
                            ) AS grid_done,
                            COUNT(*) FILTER (WHERE resolution=6 AND status='pending') AS grid_pending,
                            COUNT(*) FILTER (WHERE resolution=6 AND status='running') AS grid_running,
                            COUNT(*) FILTER (WHERE resolution=6 AND status='error') AS grid_error,
                            COUNT(*) FILTER (WHERE resolution=6 AND status='quota_limited') AS grid_quota_limited
                        FROM {s}.amap_spatial_collect_tasks
                        WHERE graph_name=%s
                        GROUP BY type_label""",
                    (SPATIAL_GRAPH_NAME,),
                )
                # Normalize NULL labels the same way the POI counts are, so the
                # unlabelled bucket also receives its grid stats.
                grid_rows = {
                    (r["type_label"] or _UNLABELLED_TYPE): r
                    for r in await cur.fetchall()
                }
                # Task status histogram across all resolutions.
                await cur.execute(
                    f"""SELECT status, COUNT(*) AS cnt
                        FROM {s}.amap_spatial_collect_tasks
                        WHERE graph_name=%s
                        GROUP BY status""",
                    (SPATIAL_GRAPH_NAME,),
                )
                status_counts = {r["status"]: int(r["cnt"]) for r in await cur.fetchall()}
        return _build_spatial_coverage(counts, grid_rows, status_counts)
    except Exception:
        # Best-effort: the caller falls back to the legacy coverage report,
        # but the failure must stay visible in the logs.
        logging.getLogger(__name__).warning(
            "H3 spatial coverage unavailable; falling back to legacy coverage",
            exc_info=True,
        )
        return None
@router.get("/super-agent/gaode/types")
async def _types(_user: CurrentUser = None):
    """Return the POI types reported by ``available_types()``."""
    type_list = available_types()
    return {"types": type_list}
@router.post("/super-agent/gaode/preview")
async def _preview(body: dict, _user: CurrentUser):
    """Fetch and normalize 高德 POIs for preview. Does NOT write anywhere.

    Body keys:
        poi_type / keyword: at least one is required.
        max_pages: clamped to 1..5 (default 1).
        limit: clamped to 1..100 (default 40).

    Returns:
        ``{"count": <number of rows>, "rows": <rows from search_pois>}``.

    Raises:
        HTTPException(400): non-integer paging values, no query given, or any
            error raised by ``search_pois``.
    """
    poi_type = body.get("poi_type")
    keyword = body.get("keyword")
    try:
        max_pages = max(1, min(int(body.get("max_pages", 1)), 5))
        limit = max(1, min(int(body.get("limit", 40)), 100))
    except (TypeError, ValueError, OverflowError) as e:
        # Otherwise a value such as "abc" or null would surface as a 500.
        raise HTTPException(400, "max_pages 和 limit 必须是整数") from e
    if not poi_type and not keyword:
        raise HTTPException(400, "需要 poi_type 或 keyword")
    try:
        # search_pois is synchronous (its result is used directly) and performs
        # the remote fetch, so run it off the event loop, like _coverage is.
        rows = await asyncio.to_thread(
            search_pois,
            poi_type=poi_type,
            keyword=keyword,
            max_pages=max_pages,
            limit=limit,
        )
    except FileNotFoundError as e:
        raise HTTPException(400, str(e)) from e
    except Exception as e:
        raise HTTPException(400, f"高德采集失败:{str(e)[:200]}") from e
    return {"count": len(rows), "rows": rows}
@router.post("/super-agent/gaode/ingest")
async def _ingest(body: dict, _user: CurrentUser):
    """Fetch 高德 POIs and write into BOTH stores (idempotent).

    valid → approved (PG + FalkorDB); incomplete → pending_review (PG only).

    Body keys:
        poi_type / keyword: at least one is required.
        max_pages: clamped to 1..10 (default 2).
        limit: clamped to 1..200 (default 60).

    Returns:
        Whatever ``ingest_gaode`` returns.

    Raises:
        HTTPException(400): non-integer paging values, no query given, or any
            error raised by ``ingest_gaode``.
    """
    poi_type = body.get("poi_type")
    keyword = body.get("keyword")
    try:
        max_pages = max(1, min(int(body.get("max_pages", 2)), 10))
        limit = max(1, min(int(body.get("limit", 60)), 200))
    except (TypeError, ValueError, OverflowError) as e:
        # Otherwise a value such as "abc" or null would surface as a 500.
        raise HTTPException(400, "max_pages 和 limit 必须是整数") from e
    if not poi_type and not keyword:
        raise HTTPException(400, "需要 poi_type 或 keyword")
    try:
        return await ingest_gaode(poi_type, keyword, max_pages, limit)
    except FileNotFoundError as e:
        raise HTTPException(400, str(e)) from e
    except Exception as e:
        raise HTTPException(400, f"高德入库失败:{str(e)[:200]}") from e
# ── Autonomous loop (guarded) ────────────────────────────────────────────────
@router.post("/super-agent/run")
async def _run(body: dict | None = None, _user: CurrentUser = None):
    """Start the curator's autonomous backfill run.

    The caller does not supply a step count or budget: the run is created with
    ``_HARD_MAX_STEPS`` and ``0`` as the remaining arguments to
    ``sa_create_run`` (presumably the step cap and budget; confirm against
    ``sa_create_run``), and is then handed to ``schedule_super_agent``. An
    optional ``goal`` in the body overrides the default standing goal.
    """
    payload = body if body else {}
    goal = payload.get("goal")
    if not goal:
        goal = "馆长常驻:按数据源饱和度自治补全并驻守"
    new_run_id = await sa_create_run(goal, _HARD_MAX_STEPS, 0)
    schedule_super_agent(new_run_id)
    return {"run_id": new_run_id, "goal": goal}
@router.get("/super-agent/coverage")
async def _cov(_user: CurrentUser = None):
    """Collection inventory: current item count per category plus city-wide
    grid coverage (the real progress measure).

    Returns the H3 spatial collector report when one is available; otherwise
    annotates the legacy ``_coverage()`` items with grid totals from
    ``grid_counts()`` and adds an overall grid percentage.
    """
    h3_report = await _spatial_h3_coverage()
    if h3_report:
        return h3_report

    report = await asyncio.to_thread(_coverage)
    grid_by_cat = await grid_counts()
    done_sum = 0
    total_sum = 0
    for entry in report["items"]:
        cell = grid_by_cat.get(entry["cat"], {})
        cell_total = cell.get("total", 0)
        cell_done = cell.get("done", 0)
        entry["grid_total"] = cell_total
        entry["grid_done"] = cell_done
        entry["grid_pct"] = round(cell_done / cell_total * 100) if cell_total else 0
        done_sum += cell_done
        total_sum += cell_total
    # max(1, ...) keeps the ratio defined (0%) when no grid cells exist.
    report["grid_overall_pct"] = round(done_sum / max(1, total_sum) * 100)
    return report
@router.get("/super-agent/tasks")
async def _tasks(run_id: int | None = None, _user: CurrentUser = None):
    """Task ledger: what the AI did at each step, its outcome, and whether it
    was escalated.

    ``run_id`` (``None`` when omitted) is passed straight to ``sa_list_tasks``.
    """
    task_list = await sa_list_tasks(run_id)
    return {"tasks": task_list}
@router.get("/super-agent/runs/latest")
async def _latest(_user: CurrentUser = None):
    """Return the record from ``sa_latest_run()``, or ``{}`` when it is falsy."""
    latest_run = await sa_latest_run()
    return latest_run if latest_run else {}
@router.get("/super-agent/runs/{run_id}")
async def _run_get(run_id: int, _user: CurrentUser = None):
    """Return run ``run_id``; respond 404 when ``sa_get_run`` finds nothing."""
    run = await sa_get_run(run_id)
    if run:
        return run
    raise HTTPException(404, "run not found")
@router.post("/super-agent/runs/{run_id}/stop")
async def _run_stop(run_id: int, _user: CurrentUser):
    """Request that run ``run_id`` stop, then acknowledge.

    The run's existence is not checked here; the request is forwarded to
    ``sa_request_stop`` as-is.
    """
    await sa_request_stop(run_id)
    acknowledgement = {"ok": True, "run_id": run_id}
    return acknowledgement