248 lines
9.9 KiB
Python
248 lines
9.9 KiB
Python
"""Super Agent — P1 deterministic tools (preview only, no write yet)."""
|
||
import asyncio
import logging

from fastapi import APIRouter, HTTPException

from app.auth import CurrentUser
from app.agents.gaode_connector import available_types, search_pois
from app.agents.super_ingest import ingest_gaode
from app.agents.super_orchestrator import (
    schedule_super_agent, _coverage, _HARD_MAX_STEPS,
)
from app.db import (
    sa_create_run, sa_get_run, sa_latest_run, sa_request_stop, sa_list_tasks,
    grid_counts, get_conn,
)
|
||
|
||
# Shared router: every /super-agent/* endpoint in this module registers on it.
router: APIRouter = APIRouter()
|
||
|
||
# Graph name that scopes every H3 spatial-collector query in this module.
SPATIAL_GRAPH_NAME = "guiyang_spatial_v1"

# Display order for spatial coverage items (labels not listed here are
# sorted after all of these by the coverage builder).
SPATIAL_TYPE_ORDER = (
    "景点 美食 酒店 商场 医疗保健 交通设施 生活服务 "
    "科教文化 政府机构 公共设施 体育休闲 商务住宅 公司企业 "
    "金融保险 汽车服务 汽车维修 汽车销售 摩托车服务 "
    "地名地址 道路附属"
).split()

# place_type used for a label when its POI rows carry none; labels absent
# from this map fall back to "poi" where it is consulted.
SPATIAL_TYPE_FALLBACK = {
    "景点": "sight",
    "美食": "eat",
    "酒店": "hotel",
    "商场": "mall",
    "医疗保健": "medical",
    "交通设施": "transit",
    "生活服务": "life",
    "科教文化": "education",
    "政府机构": "government",
    "公共设施": "facility",
    "商务住宅": "residential",
    "公司企业": "enterprise",
}
|
||
|
||
|
||
_UNCLASSIFIED_LABEL = "未分类"

# Per-type_label counters produced by the resolution-6 task query in
# _spatial_h3_coverage; this order is also the key order of each item's
# grid_* fields in the response.
_GRID_FIELDS = (
    "grid_total",
    "grid_done",
    "grid_pending",
    "grid_running",
    "grid_error",
    "grid_quota_limited",
)


def _build_spatial_coverage(
    counts: list[dict],
    grid_rows: list[dict],
    status_counts: dict[str, int],
) -> dict:
    """Fold the raw H3 collector rows into the coverage payload.

    Args:
        counts: POI rows with ``type_label``, ``place_type`` and ``cnt``.
        grid_rows: task rows with ``type_label`` and every name in
            ``_GRID_FIELDS`` (resolution-6 counters).
        status_counts: task status -> task count across all resolutions.

    Returns:
        The coverage dict served by ``/super-agent/coverage``: one item per
        label that has at least one POI row, ordered by ``SPATIAL_TYPE_ORDER``
        (unlisted labels last), then by descending count, then by label.
    """
    # NULL labels become "未分类" in BOTH row sets so the unclassified bucket
    # is matched with its own grid counters; rows that end up sharing a
    # normalized label are summed.
    grid_by_label: dict[str, dict[str, int]] = {}
    for row in grid_rows:
        label = row["type_label"] or _UNCLASSIFIED_LABEL
        acc = grid_by_label.setdefault(label, dict.fromkeys(_GRID_FIELDS, 0))
        for field in _GRID_FIELDS:
            acc[field] += int(row.get(field) or 0)

    items_by_label: dict[str, dict] = {}
    for row in counts:
        label = row["type_label"] or _UNCLASSIFIED_LABEL
        item = items_by_label.setdefault(
            label,
            {
                "cat": label,
                # The first row seen for a label decides its place_type.
                "place_type": row["place_type"]
                or SPATIAL_TYPE_FALLBACK.get(label, "poi"),
                "current": 0,
            },
        )
        item["current"] += int(row["cnt"] or 0)

    # NOTE(review): labels that have grid tasks but no POIs yet produce no
    # item, so their cells are also left out of grid_overall_pct — confirm
    # this is intended.
    items: list[dict] = []
    for label, item in items_by_label.items():
        grid = grid_by_label.get(label) or dict.fromkeys(_GRID_FIELDS, 0)
        item.update(grid)
        total, done = grid["grid_total"], grid["grid_done"]
        item["grid_pct"] = round(done / total * 100) if total else 0
        items.append(item)

    order_index = {label: i for i, label in enumerate(SPATIAL_TYPE_ORDER)}
    items.sort(key=lambda x: (order_index.get(x["cat"], 999), -x["current"], x["cat"]))

    grid_total_sum = sum(i["grid_total"] for i in items)
    grid_done_sum = sum(i["grid_done"] for i in items)
    return {
        "source": "amap_spatial_h3",
        "graph_name": SPATIAL_GRAPH_NAME,
        "items": items,
        "total": sum(i["current"] for i in items),
        "grid_overall_pct": (
            round(grid_done_sum / grid_total_sum * 100) if grid_total_sum else 0
        ),
        "status_counts": status_counts,
        "error_tasks": int(status_counts.get("error", 0)),
        "quota_limited": int(status_counts.get("quota_limited", 0)),
    }


async def _spatial_h3_coverage() -> dict | None:
    """Return coverage from the H3 spatial collector, or ``None`` to fall back.

    Queries ``{db_schema}.amap_spatial_pois`` and
    ``{db_schema}.amap_spatial_collect_tasks`` scoped to ``SPATIAL_GRAPH_NAME``.
    Only resolution-6 task rows feed the per-type grid counters; the
    ``status_counts`` summary covers tasks at every resolution.

    Returns:
        The payload built by ``_build_spatial_coverage``, or ``None`` when
        either table is missing, no POIs exist for the graph yet, or any step
        of the lookup fails (logged as a warning). A ``None`` result makes the
        coverage endpoint fall back to the legacy report.
    """
    from app.config import settings

    # Schema name comes from configuration, not from the request, so it is
    # interpolated into the SQL text; all row filters stay parameterized.
    s = settings.db_schema
    try:
        async with get_conn() as conn:
            async with conn.cursor() as cur:
                # to_regclass yields NULL for a missing table instead of raising.
                await cur.execute(
                    "SELECT to_regclass(%s) AS pois_table, to_regclass(%s) AS tasks_table",
                    (f"{s}.amap_spatial_pois", f"{s}.amap_spatial_collect_tasks"),
                )
                tables = await cur.fetchone()
                if not tables or not tables["pois_table"] or not tables["tasks_table"]:
                    return None

                await cur.execute(
                    f"""SELECT type_label, place_type, COUNT(*) AS cnt
                        FROM {s}.amap_spatial_pois
                        WHERE graph_name=%s
                        GROUP BY type_label, place_type""",
                    (SPATIAL_GRAPH_NAME,),
                )
                counts = await cur.fetchall()
                if not counts:
                    return None

                await cur.execute(
                    f"""SELECT type_label,
                          COUNT(*) FILTER (WHERE resolution=6) AS grid_total,
                          COUNT(*) FILTER (
                            WHERE resolution=6
                              AND status IN ('done','saturated','saturated_max_res')
                          ) AS grid_done,
                          COUNT(*) FILTER (WHERE resolution=6 AND status='pending') AS grid_pending,
                          COUNT(*) FILTER (WHERE resolution=6 AND status='running') AS grid_running,
                          COUNT(*) FILTER (WHERE resolution=6 AND status='error') AS grid_error,
                          COUNT(*) FILTER (WHERE resolution=6 AND status='quota_limited') AS grid_quota_limited
                        FROM {s}.amap_spatial_collect_tasks
                        WHERE graph_name=%s
                        GROUP BY type_label""",
                    (SPATIAL_GRAPH_NAME,),
                )
                grid_rows = await cur.fetchall()

                await cur.execute(
                    f"""SELECT status, COUNT(*) AS cnt
                        FROM {s}.amap_spatial_collect_tasks
                        WHERE graph_name=%s
                        GROUP BY status""",
                    (SPATIAL_GRAPH_NAME,),
                )
                status_counts = {r["status"]: int(r["cnt"]) for r in await cur.fetchall()}

        return _build_spatial_coverage(counts, grid_rows, status_counts)
    except Exception:
        # Deliberate best-effort: any failure falls back to the legacy report,
        # but is logged so a broken query no longer hides the H3 data silently.
        logging.getLogger(__name__).warning(
            "H3 spatial coverage lookup failed; falling back to legacy coverage",
            exc_info=True,
        )
        return None
|
||
|
||
|
||
@router.get("/super-agent/gaode/types")
async def _types(_user: CurrentUser = None):
    """Return the 高德 (Gaode) connector's ``available_types()`` under ``"types"``."""
    supported = available_types()
    return {"types": supported}
|
||
|
||
|
||
@router.post("/super-agent/gaode/preview")
async def _preview(body: dict, _user: CurrentUser):
    """Fetch and normalize 高德 (Gaode/AMap) POIs for preview; writes nothing.

    Body keys:
        poi_type / keyword: at least one is required.
        max_pages: integer, clamped to 1..5 (default 1).
        limit: integer, clamped to 1..100 (default 40).

    Returns:
        ``{"count": <number of rows>, "rows": <rows from search_pois>}``.

    Raises:
        HTTPException(400): both ``poi_type`` and ``keyword`` missing,
            ``max_pages``/``limit`` not integer-convertible, or any failure
            raised by ``search_pois``.
    """
    poi_type = body.get("poi_type")
    keyword = body.get("keyword")
    try:
        max_pages = max(1, min(int(body.get("max_pages", 1)), 5))
        limit = max(1, min(int(body.get("limit", 40)), 100))
    except (TypeError, ValueError):
        # Previously escaped as an unhandled error (HTTP 500).
        raise HTTPException(400, "max_pages 和 limit 必须是整数") from None
    if not poi_type and not keyword:
        raise HTTPException(400, "需要 poi_type 或 keyword")
    try:
        # search_pois is a plain (non-async) call that presumably performs
        # blocking HTTP I/O; run it in a worker thread so it doesn't stall
        # the event loop (same approach as the legacy coverage path).
        rows = await asyncio.to_thread(
            search_pois,
            poi_type=poi_type,
            keyword=keyword,
            max_pages=max_pages,
            limit=limit,
        )
    except FileNotFoundError as e:
        raise HTTPException(400, str(e)) from e
    except Exception as e:
        raise HTTPException(400, f"高德采集失败:{str(e)[:200]}") from e
    return {"count": len(rows), "rows": rows}
|
||
|
||
|
||
@router.post("/super-agent/gaode/ingest")
async def _ingest(body: dict, _user: CurrentUser):
    """Fetch 高德 POIs and write into BOTH stores (idempotent).

    valid → approved (PG + FalkorDB); incomplete → pending_review (PG only).

    Body keys:
        poi_type / keyword: at least one is required.
        max_pages: integer, clamped to 1..10 (default 2).
        limit: integer, clamped to 1..200 (default 60).

    Returns:
        Whatever ``ingest_gaode`` returns.

    Raises:
        HTTPException(400): both ``poi_type`` and ``keyword`` missing,
            ``max_pages``/``limit`` not integer-convertible, or any failure
            raised by ``ingest_gaode``.
    """
    poi_type = body.get("poi_type")
    keyword = body.get("keyword")
    try:
        max_pages = max(1, min(int(body.get("max_pages", 2)), 10))
        limit = max(1, min(int(body.get("limit", 60)), 200))
    except (TypeError, ValueError):
        # Previously escaped as an unhandled error (HTTP 500).
        raise HTTPException(400, "max_pages 和 limit 必须是整数") from None
    if not poi_type and not keyword:
        raise HTTPException(400, "需要 poi_type 或 keyword")
    try:
        return await ingest_gaode(poi_type, keyword, max_pages, limit)
    except FileNotFoundError as e:
        raise HTTPException(400, str(e)) from e
    except Exception as e:
        raise HTTPException(400, f"高德入库失败:{str(e)[:200]}") from e
|
||
|
||
|
||
# ── Autonomous loop (guarded) ────────────────────────────────────────────────
|
||
|
||
@router.post("/super-agent/run")
async def _run(body: dict | None = None, _user: CurrentUser = None):
    """Start the curator's autonomous completion run.

    No manual step count or budget is needed; the run drives itself from the
    collection blueprint. The run record is created via ``sa_create_run`` with
    ``_HARD_MAX_STEPS`` and ``0`` as its second and third arguments
    (presumably the step cap and an unset budget — TODO confirm). It is then
    handed to ``schedule_super_agent``.

    Body (optional): ``goal``, a free-text goal. A default goal is used when
    it is missing or empty.

    Returns:
        ``{"run_id": <new run id>, "goal": <goal used>}``.
    """
    # NOTE(review): `_user` defaults to None — confirm that `CurrentUser`
    # still enforces authentication here, since this endpoint launches a run.
    payload = body if body else {}
    goal = payload.get("goal") or "馆长常驻:按数据源饱和度自治补全并驻守"
    new_run_id = await sa_create_run(goal, _HARD_MAX_STEPS, 0)
    schedule_super_agent(new_run_id)
    return {"run_id": new_run_id, "goal": goal}
|
||
|
||
|
||
@router.get("/super-agent/coverage")
async def _cov(_user: CurrentUser = None):
    """Collection inventory: current item count per category plus city-wide
    grid coverage (actual progress).

    If ``_spatial_h3_coverage`` produces a report, that report is returned
    as-is. Otherwise the legacy ``_coverage`` runs in a worker thread. Each of
    its items is then annotated with ``grid_total``, ``grid_done`` and
    ``grid_pct`` from ``grid_counts()``, and an overall ``grid_overall_pct``
    is added.
    """
    spatial_report = await _spatial_h3_coverage()
    if spatial_report:
        return spatial_report

    report = await asyncio.to_thread(_coverage)
    per_cat = await grid_counts()
    items = report["items"]
    for entry in items:
        stats = per_cat.get(entry["cat"], {})
        total = stats.get("total", 0)
        done = stats.get("done", 0)
        entry["grid_total"] = total
        entry["grid_done"] = done
        entry["grid_pct"] = 0 if not total else round(done / total * 100)

    done_sum = sum(entry["grid_done"] for entry in items)
    total_sum = sum(entry["grid_total"] for entry in items)
    # max(1, ...) guards against division by zero when there are no grid cells.
    report["grid_overall_pct"] = round(done_sum / max(1, total_sum) * 100)
    return report
|
||
|
||
|
||
@router.get("/super-agent/tasks")
async def _tasks(run_id: int | None = None, _user: CurrentUser = None):
    """Work-order ledger: what the AI did at each step, the outcome, and
    whether it was escalated.

    ``run_id`` is forwarded unchanged to ``sa_list_tasks``. It is ``None`` when
    the query parameter is omitted, which presumably means all runs (confirm).
    """
    ledger = await sa_list_tasks(run_id)
    return {"tasks": ledger}
|
||
|
||
|
||
# Registered before /super-agent/runs/{run_id}: routes match in order, so
# that path-parameter route would otherwise capture "latest".
@router.get("/super-agent/runs/latest")
async def _latest(_user: CurrentUser = None):
    """Return the most recent run from ``sa_latest_run``, or ``{}`` when it
    returns a falsy value."""
    latest_run = await sa_latest_run()
    if not latest_run:
        return {}
    return latest_run
|
||
|
||
|
||
@router.get("/super-agent/runs/{run_id}")
async def _run_get(run_id: int, _user: CurrentUser = None):
    """Return run ``run_id`` as given by ``sa_get_run``.

    Raises:
        HTTPException(404): when ``sa_get_run`` returns a falsy value.
    """
    run = await sa_get_run(run_id)
    if run:
        return run
    raise HTTPException(404, "run not found")
|
||
|
||
|
||
@router.post("/super-agent/runs/{run_id}/stop")
async def _run_stop(run_id: int, _user: CurrentUser):
    """Ask run ``run_id`` to stop via ``sa_request_stop``.

    Always answers ``{"ok": True, "run_id": run_id}``. This handler itself does
    not check whether the run exists.
    """
    await sa_request_stop(run_id)
    ack = {"ok": True, "run_id": run_id}
    return ack
|