Files
bxh/scripts/enrich_travel_graph_amap_driving_metrics.py

353 lines
13 KiB
Python

#!/usr/bin/env python3
"""Fetch AMap driving distance metrics for travel_graph route/resource pairs.
The graph stores only cooperative business resources. This script does not add
online hotels or restaurants; it enriches existing attraction-to-attraction and
same-region attraction-to-resource pairs with driving distance/duration.
"""
from __future__ import annotations
import csv
import importlib.util
import json
import math
import os
import re
import time
from datetime import datetime
from pathlib import Path
from typing import Any
import requests
import urllib3
from common_paths import GAODE_CRAWLER_PATH, PROJECT_ROOT, TRAVEL_KG_EXPORT_ROOT
# The AMap request in query_distance_batch() is sent with verify=False, which
# would otherwise emit an InsecureRequestWarning on every call.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Graph build script; loaded at import time so its helpers can be reused here.
BUILD_SCRIPT = PROJECT_ROOT / "scripts/build_travel_graph_existing_product_project.py"
# Export directory that holds the node file read below and this script's outputs.
OUT_DIR = TRAVEL_KG_EXPORT_ROOT / "travel_graph_旅行社线路制定"
# Node list (JSON) read by main().
NODES_PATH = OUT_DIR / "抽取结果_nodes.json"
# JSON cache of per-pair results; read at start-up and rewritten by main().
CACHE_PATH = OUT_DIR / "amap_driving_distance_cache.json"
# CSV report written by main(), one row per pair that has a cache entry.
REPORT_CSV = OUT_DIR / "amap_driving_distance_report.csv"
# AMap Web Service distance endpoint queried by query_distance_batch().
AMAP_DISTANCE_URL = "https://restapi.amap.com/v3/distance"
# Maximum number of origins sent in a single distance request.
MAX_ORIGINS_PER_REQUEST = 80
def load_build_module():
    """Load the graph build script as a module and return it.

    The file at ``BUILD_SCRIPT`` is executed under the module name
    ``travel_build`` so that its helper functions can be reused by this
    script. Executing the module runs its top-level code.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for the script.
    """
    spec = importlib.util.spec_from_file_location("travel_build", BUILD_SCRIPT)
    # Raise explicitly instead of asserting: ``assert`` is stripped under
    # ``python -O`` and never covered a ``None`` spec in the first place.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load build script: {BUILD_SCRIPT}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
# Build-script module; this file uses its compact, number and haversine_km helpers.
b = load_build_module()
def load_key() -> str:
    """Return the AMap Web Service key used for distance requests.

    Lookup order:
      1. the ``AMAP_WEB_KEY`` environment variable,
      2. the ``AMAP_KEY`` environment variable,
      3. ``CONFIG["key"]`` of the crawler script at ``GAODE_CRAWLER_PATH``,
         if that file exists (the script is executed to read it).

    Returns:
        The first non-empty key found.

    Raises:
        ImportError: If the crawler script exists but cannot be loaded.
        RuntimeError: If none of the sources yields a non-empty key.
    """
    for env_name in ("AMAP_WEB_KEY", "AMAP_KEY"):
        key = os.environ.get(env_name)
        if key:
            return key

    crawl_path = GAODE_CRAWLER_PATH
    if crawl_path.exists():
        spec = importlib.util.spec_from_file_location("crawl_guiyan", crawl_path)
        # Explicit check instead of ``assert``: assertions vanish under
        # ``python -O`` and a ``None`` spec would fail with an opaque error.
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load crawler script: {crawl_path}")
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        # CONFIG may be absent or None; both are treated as "no key".
        key = (getattr(mod, "CONFIG", {}) or {}).get("key")
        if key:
            return key

    raise RuntimeError("未找到可用的高德 Web 服务 Key。请配置 AMAP_WEB_KEY 或保留 crawl_guiyan.py 中的 CONFIG['key']。")
def clean(value: Any) -> str:
    """Normalize *value* by delegating to the build module's ``compact`` helper."""
    compacted = b.compact(value)
    return compacted
def number(value: Any) -> float | None:
    """Convert *value* to a number via the build module's ``number`` helper.

    Callers in this file treat a ``None`` result as "no usable number".
    """
    parsed = b.number(value)
    return parsed
def node_location(node: dict[str, Any]) -> tuple[float, float] | None:
    """Return the node's ``(lng, lat)`` pair, or ``None`` if either part is missing.

    For each axis ``location_*`` is preferred; ``amap_*`` is used when the
    former is absent or falsy.
    """
    lng = number(node.get("location_lng") or node.get("amap_lng"))
    lat = number(node.get("location_lat") or node.get("amap_lat"))
    if lng is not None and lat is not None:
        return lng, lat
    return None
def location_text(loc: tuple[float, float]) -> str:
    """Format a coordinate pair as ``"<first>,<second>"`` with six decimals each."""
    first = format(loc[0], ".6f")
    second = format(loc[1], ".6f")
    return first + "," + second
def chunked(items: list[Any], size: int) -> list[list[Any]]:
    """Split *items* into consecutive slices of at most *size* elements.

    The last slice may be shorter. An empty input yields an empty list.
    """
    slices: list[list[Any]] = []
    for start in range(0, len(items), size):
        stop = start + size
        slices.append(items[start:stop])
    return slices
def pair_key(source_key: str, target_key: str) -> str:
    """Build the directed cache key ``"<source_key>=><target_key>"``."""
    return "{}=>{}".format(source_key, target_key)
def resource_type(label: str) -> str:
    """Map a node label to its resource-type name.

    Labels other than the four known ones map to ``"resource"``.
    """
    known_types = dict(
        ScenicArea="scenic_area",
        ScenicAttraction="attraction",
        HotelResource="hotel",
        RestaurantResource="restaurant",
    )
    if label in known_types:
        return known_types[label]
    return "resource"
def same_admin_region(source: dict[str, Any], target: dict[str, Any]) -> str:
    """Return ``"same_admin_region"`` when both nodes share a region name.

    Both ``admin_region_name`` values are normalized with ``clean`` first.
    If either is empty, or they differ, the result is ``""``.
    """
    source_region, target_region = (
        clean(node.get("admin_region_name")) for node in (source, target)
    )
    if not (source_region and target_region):
        return ""
    return "same_admin_region" if source_region == target_region else ""
def straight_distance_km(source: dict[str, Any], target: dict[str, Any]) -> float | None:
    """Return the haversine distance between two nodes, rounded to 2 decimals.

    Returns ``None`` when either node has no usable location.
    """
    origin = node_location(source)
    destination = node_location(target)
    if not (origin and destination):
        return None
    # NOTE(review): coordinates are passed as (lng, lat, lng, lat) — confirm
    # this matches the parameter order of the build module's haversine_km.
    raw_distance = b.haversine_km(origin[0], origin[1], destination[0], destination[1])
    return round(raw_distance, 2)
def nearby_match(source: dict[str, Any], target: dict[str, Any]) -> str:
    """Classify how *target* relates spatially to *source*.

    Returns, in order of precedence:
      * ``""`` when the straight-line distance cannot be computed,
      * ``"same_admin_region"`` when both nodes share a region name
        (regardless of distance),
      * ``"nearby_distance"`` when the distance is within the threshold
        (35.0 for ``HotelResource`` targets, 15.0 otherwise),
      * ``"same_city_nearby_distance"`` when one city name contains the other
        and the distance is within 1.6 times the threshold,
      * ``""`` otherwise.
    """
    def city_of(node: dict[str, Any]) -> str:
        # First non-empty of the three city-like fields, normalized.
        return clean(node.get("amap_cityname") or node.get("city") or node.get("city_or_area"))

    distance = straight_distance_km(source, target)
    if distance is None:
        return ""

    if target.get("label") == "HotelResource":
        threshold = 35.0
    else:
        threshold = 15.0

    region_level = same_admin_region(source, target)
    if region_level:
        return region_level

    source_city = city_of(source)
    target_city = city_of(target)
    # Substring containment in either direction counts as the same city.
    same_city = source_city and target_city and (source_city in target_city or target_city in source_city)

    if distance <= threshold:
        return "nearby_distance"
    if same_city and distance <= threshold * 1.6:
        return "same_city_nearby_distance"
    return ""
def load_existing_cache() -> dict[str, dict[str, Any]]:
    """Load previously stored per-pair results from ``CACHE_PATH``.

    Two layouts are accepted: a wrapper object whose ``"items"`` value is the
    mapping, or a bare mapping keyed by pair key.

    Returns:
        The cached mapping, or ``{}`` when the file is missing, unreadable,
        not valid JSON, or not a JSON object. Loading is best-effort: an
        unusable cache never aborts the run, it only forces a re-fetch.
    """
    import sys  # local import: only needed for the warning below

    try:
        payload = json.loads(CACHE_PATH.read_text(encoding="utf-8"))
    except FileNotFoundError:
        # No cache yet: the normal first-run case, nothing to report.
        return {}
    except (OSError, ValueError, RecursionError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        # Report the problem instead of hiding it: silently dropping the cache
        # means every pair is fetched from the API again.
        print(f"warning: ignoring unreadable cache {CACHE_PATH}: {exc}", file=sys.stderr)
        return {}

    if not isinstance(payload, dict):
        return {}
    items = payload.get("items")
    if isinstance(items, dict):
        return items
    # No "items" wrapper: treat the top-level object itself as the mapping.
    return payload
def build_pairs(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Build the directed source/target pairs whose driving metrics are wanted.

    Sources are "route anchors": located ``ScenicArea`` nodes, plus located
    ``ScenicAttraction`` nodes that are not flagged
    ``is_independent_destination = False`` and have no parent scenic area name.

    For every anchor, in this order, pairs are produced to:
      1. every other anchor (scope ``route_anchor_to_route_anchor``),
      2. every located hotel for which ``nearby_match`` returns a level,
      3. every located restaurant for which ``nearby_match`` returns a level.

    Each pair key (``source=>target``) appears at most once in the result.
    """
    def is_route_anchor(node: dict[str, Any]) -> bool:
        if not node_location(node):
            return False
        if node.get("label") == "ScenicArea":
            return True
        return bool(
            node.get("label") == "ScenicAttraction"
            and node.get("is_independent_destination") is not False
            and not clean(node.get("parent_scenic_area_name"))
        )

    def located_with_label(label: str) -> list[dict[str, Any]]:
        return [node for node in nodes if node.get("label") == label and node_location(node)]

    route_anchors = [node for node in nodes if is_route_anchor(node)]
    hotels = located_with_label("HotelResource")
    restaurants = located_with_label("RestaurantResource")

    pairs: list[dict[str, Any]] = []
    seen: set[str] = set()

    def register(source: dict[str, Any], target: dict[str, Any], scope: str, region_match: str) -> None:
        # Append one pair record unless its key was already emitted or a
        # location is missing.
        key = pair_key(source["natural_key"], target["natural_key"])
        if key in seen:
            return
        origin = node_location(source)
        destination = node_location(target)
        if not (origin and destination):
            return
        seen.add(key)
        record = {
            "key": key,
            "source_key": source["natural_key"],
            "target_key": target["natural_key"],
            "source_name": source.get("name"),
            "target_name": target.get("name"),
            "source_label": source.get("label"),
            "target_label": target.get("label"),
            "target_resource_type": resource_type(target.get("label", "")),
            "source_region": source.get("admin_region_name"),
            "target_region": target.get("admin_region_name"),
            "metric_scope": scope,
            "region_match_level": region_match,
            # NOTE(review): every call below passes a non-empty region_match,
            # so this flag is True even for "cross_region_route_metric" and
            # nearby-distance matches — confirm whether consumers expect
            # region_match == "same_admin_region" instead.
            "same_admin_region": bool(region_match),
            "origin_location": location_text(origin),
            "destination_location": location_text(destination),
            "straight_distance_km": straight_distance_km(source, target),
        }
        pairs.append(record)

    # (targets, scope for a same-region match, scope for any other match)
    resource_groups = (
        (hotels, "scenic_anchor_to_same_region_hotel", "scenic_anchor_to_nearby_hotel"),
        (restaurants, "scenic_anchor_to_same_region_restaurant", "scenic_anchor_to_nearby_restaurant"),
    )

    for anchor in route_anchors:
        anchor_key = anchor["natural_key"]
        for other in route_anchors:
            if other["natural_key"] == anchor_key:
                continue
            level = same_admin_region(anchor, other) or "cross_region_route_metric"
            register(anchor, other, "route_anchor_to_route_anchor", level)
        for resources, same_region_scope, nearby_scope in resource_groups:
            for resource in resources:
                level = nearby_match(anchor, resource)
                if not level:
                    continue
                scope = same_region_scope if level == "same_admin_region" else nearby_scope
                register(anchor, resource, scope, level)
    return pairs
def cache_matches_pair(cached: dict[str, Any], pair: dict[str, Any]) -> bool:
    """Tell whether a cached entry can be reused for *pair*.

    Reuse requires the cached entry to have status ``"matched"`` and to agree
    with the current pair on both keys, both labels, the metric scope and both
    location strings.
    """
    if cached.get("status") != "matched":
        return False
    identity_fields = (
        "source_key",
        "target_key",
        "source_label",
        "target_label",
        "metric_scope",
        "origin_location",
        "destination_location",
    )
    differs = any(cached.get(name) != pair.get(name) for name in identity_fields)
    return not differs
def query_distance_batch(key: str, destination: str, origin_rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Query the distance endpoint for several origins and one destination.

    Args:
        key: AMap Web Service key.
        destination: Destination location string shared by all rows.
        origin_rows: Pair records; each supplies ``origin_location`` and ``key``.

    Returns:
        A mapping from pair key to the pair record extended with the driving
        metrics and status ``"matched"``. Results whose ``origin_id`` cannot be
        mapped back to a row, or that lack a distance or duration, are omitted.

    Raises:
        requests.HTTPError: On a non-success HTTP status.
        RuntimeError: When the response payload's ``status`` is not ``"1"``.
    """
    query = dict(
        key=key,
        origins="|".join(row["origin_location"] for row in origin_rows),
        destination=destination,
        type="1",
        output="json",
    )
    # NOTE(review): verify=False turns off TLS certificate validation while the
    # API key travels in the query string — confirm this is still required.
    response = requests.get(AMAP_DISTANCE_URL, params=query, timeout=(3, 12), verify=False)
    response.raise_for_status()
    payload = response.json()
    if payload.get("status") != "1":
        raise RuntimeError(f"{payload.get('infocode')} {payload.get('info')}")

    matched: dict[str, dict[str, Any]] = {}
    row_count = len(origin_rows)
    for result in payload.get("results") or []:
        raw_position = clean(result.get("origin_id"))
        try:
            # origin_id is treated as 1-based; convert to a list index.
            position = int(raw_position) - 1
        except Exception:
            continue
        if not 0 <= position < row_count:
            continue
        origin_row = origin_rows[position]
        distance_m = number(result.get("distance"))
        duration_s = number(result.get("duration"))
        if distance_m is None or duration_s is None:
            continue
        entry = dict(origin_row)
        entry.update(
            status="matched",
            amap_distance_m=int(distance_m),
            amap_duration_s=int(duration_s),
            drive_distance_km=round(distance_m / 1000, 2),
            # Partial minutes are rounded up.
            drive_duration_min=int(math.ceil(duration_s / 60)),
            provider="amap",
            api="v3/distance",
            route_type="driving",
            updated_at=datetime.now().isoformat(timespec="seconds"),
        )
        matched[origin_row["key"]] = entry
    return matched
def main() -> None:
    """Fetch missing driving metrics and write the cache and the CSV report.

    Steps:
      1. Read the node file and build the candidate pairs.
      2. Reuse cached entries that still match their pair; queue the rest.
      3. Query the queued pairs grouped by destination, in batches of at most
         ``MAX_ORIGINS_PER_REQUEST`` origins. A failing batch is recorded per
         pair with status ``"api_error"`` instead of aborting the run.
      4. Rewrite the cache file and the CSV report, then print a summary.

    Raises:
        RuntimeError: If the node file is missing or no AMap key is available.
    """
    from urllib.parse import quote_plus  # local import: only used for redaction

    if not NODES_PATH.exists():
        raise RuntimeError(f"缺少节点文件：{NODES_PATH}，请先构建一次 travel_graph。")
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    amap_key = load_key()
    nodes = json.loads(NODES_PATH.read_text(encoding="utf-8"))
    pairs = build_pairs(nodes)
    existing_cache = load_existing_cache()

    cache: dict[str, dict[str, Any]] = {}
    missing: list[dict[str, Any]] = []
    for row in pairs:
        cached = existing_cache.get(row["key"])
        if isinstance(cached, dict) and cache_matches_pair(cached, row):
            cache[row["key"]] = cached
        else:
            missing.append(row)
    total_missing = len(missing)
    print(json.dumps({
        "candidate_pairs": len(pairs),
        "cached_matched": len(pairs) - total_missing,
        "missing": total_missing,
    }, ensure_ascii=False), flush=True)

    # One request carries many origins but a single destination.
    by_destination: dict[str, list[dict[str, Any]]] = {}
    for row in missing:
        by_destination.setdefault(row["destination_location"], []).append(row)

    # Request errors raised by ``requests`` embed the full URL, including the
    # ``key=`` query parameter. Both the raw and the URL-encoded key are masked
    # before the message is persisted to the cache and the report.
    key_text = str(amap_key)
    key_forms = [form for form in (quote_plus(key_text), key_text) if form]

    report_step = 200
    next_report = report_step
    processed = 0
    for destination, rows in by_destination.items():
        for batch in chunked(rows, MAX_ORIGINS_PER_REQUEST):
            try:
                cache.update(query_distance_batch(amap_key, destination, batch))
            except Exception as exc:
                now = datetime.now().isoformat(timespec="seconds")
                error_text = str(exc)
                for form in key_forms:
                    error_text = error_text.replace(form, "***")
                error_text = error_text[:180]
                for row in batch:
                    cache[row["key"]] = {
                        **row,
                        "status": "api_error",
                        "error": error_text,
                        "provider": "amap",
                        "api": "v3/distance",
                        "route_type": "driving",
                        "updated_at": now,
                    }
            processed += len(batch)
            # Batch sizes vary, so ``processed`` rarely lands exactly on a
            # multiple of the step; report whenever a threshold is crossed.
            if processed >= next_report or processed == total_missing:
                print(f"processed {processed}/{total_missing}", flush=True)
                next_report = (processed // report_step + 1) * report_step
            # Short pause between requests.
            time.sleep(0.10)

    # NOTE(review): pairs for which a successful batch returned no usable
    # result have no cache entry, so they appear neither in the report nor in
    # the "failed" count below.
    report_rows = [cache[row["key"]] for row in pairs if row["key"] in cache]
    CACHE_PATH.write_text(json.dumps({
        "generated_at": datetime.now().isoformat(timespec="seconds"),
        "source": "amap_v3_distance_driving",
        "note": "仅对业务资料库已有资源补充驾车距离/耗时；不新增线上随机商家。",
        "items": cache,
    }, ensure_ascii=False, indent=2), encoding="utf-8")

    fieldnames = [
        "status", "metric_scope", "source_label", "source_name", "source_region",
        "target_label", "target_name", "target_region", "region_match_level",
        "straight_distance_km", "drive_distance_km", "drive_duration_min",
        "amap_distance_m", "amap_duration_s", "origin_location", "destination_location",
        "error",
    ]
    # utf-8-sig writes a BOM at the start of the CSV file.
    with REPORT_CSV.open("w", newline="", encoding="utf-8-sig") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(report_rows)

    matched = sum(1 for row in report_rows if row.get("status") == "matched")
    print(json.dumps({
        "pairs": len(pairs),
        "matched": matched,
        "failed": len(report_rows) - matched,
        "cache": str(CACHE_PATH),
        "report": str(REPORT_CSV),
    }, ensure_ascii=False, indent=2))
# Run the enrichment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()