#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Map marker clustering and boundary GeoJSON endpoints.

Routes defined here:

* ``/get_info``        -- grid-cluster video and point markers inside a bbox.
* ``/get_details``     -- list the raw markers in a square around a centre.
* ``/get_geojson``     -- boundary rows whose geometry intersects a bbox.
* ``/get_geojson_new`` -- stream the child boundaries of an area code.
"""
from fastapi import APIRouter, Request, Depends, Query, HTTPException, status, Path
from common.security import valid_access_token
from fastapi.responses import JSONResponse, Response
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, text, literal, bindparam
from sqlalchemy.sql import func
from sqlalchemy.future import select
from common.auth_user import *
from pydantic import BaseModel
from typing import Any, Dict
# import contextily as ctx
# import geopandas as gpd
# from matplotlib import pyplot as plt
import io
from database import get_db
from typing import List
from models import *
from utils import *
from utils.ry_system_util import *
from utils.video_util import *
from collections import defaultdict
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import Pool, cpu_count
import json
import logging
import time
import math

router = APIRouter()

_logger = logging.getLogger(__name__)

# Clustering distance threshold (the unit calculate_grid_size expects, nominally
# km), keyed by zoom bucket. Callers look up int(zoom_level - 1), so key k
# serves map zoom level k + 1.
_ZOOM_DISTANCE_THRESHOLDS = {
    3: 10000,  # nationwide view
    4: 5000,
    5: 2500,
    6: 1250,
    7: 825,
    8: 412.5,
    9: 256.25,
    10: 178.125,
    11: 69.0625,
    12: 29.53125,
    13: 13.765625,
    14: 5.8828125,
    15: 2.44140625,
    16: 1.220703125,
    17: 0.6103515625,
    18: 0.30517578125,
}
_MIN_ZOOM_BUCKET = min(_ZOOM_DISTANCE_THRESHOLDS)
_MAX_ZOOM_BUCKET = max(_ZOOM_DISTANCE_THRESHOLDS)


def _distance_threshold_for_zoom(zoom_level):
    """Return the clustering distance threshold for a (possibly fractional) zoom.

    The bucket is ``int(zoom_level - 1)`` as before. It is clamped to the
    defined buckets, so zooms outside 4..19 use the coarsest/finest threshold.
    Previously those zooms raised ``KeyError``, which surfaced as HTTP 500.
    """
    bucket = int(zoom_level - 1)
    bucket = max(_MIN_ZOOM_BUCKET, min(bucket, _MAX_ZOOM_BUCKET))
    return _ZOOM_DISTANCE_THRESHOLDS[bucket]


@router.post("/get_info")
@router.get("/get_info")
async def get_infos(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db),
):
    """Return clustered map markers for a bounding box.

    ``body`` keys: ``zoom_level``; ``dict_value`` (comma-separated
    ``video_type`` dict values); ``option`` (comma-separated
    ``point_data.dataType`` values); ``latitude_min``, ``latitude_max``,
    ``longitude_min``, ``longitude_max``.

    Returns ``{"code": 200, "msg": ..., "data": cells}``, where ``cells`` comes
    from :func:`group_points`. Any error is printed and re-raised as HTTP 500
    with the exception text as ``detail``.

    NOTE(review): this is an ``async def`` that makes blocking Session calls,
    so it runs on the event loop thread. Consider a plain ``def`` route.
    """
    try:
        started = time.time()
        # Pick the grouping granularity from the zoom level.
        distance_threshold = _distance_threshold_for_zoom(float(body['zoom_level']))
        dict_value = body['dict_value'].split(',')
        latitude_min = float(body['latitude_min'])
        latitude_max = float(body['latitude_max'])
        longitude_min = float(body['longitude_min'])
        longitude_max = float(body['longitude_max'])
        option = body['option'].split(',')

        videos = get_videos(db, dict_value, latitude_min, latitude_max, longitude_min, longitude_max)
        infos = get_points(db, option, latitude_min, latitude_max, longitude_min, longitude_max)
        groups = group_points(videos + infos, distance_threshold)
        _logger.debug("get_info: %d cells in %.3fs", len(groups), time.time() - started)
        return {"code": 200, "msg": "操作成功", "data": groups}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e


@router.post("/get_details")
@router.get("/get_details")
async def get_details(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db),
):
    """Return the raw (unclustered) markers in a square around a centre point.

    ``body`` keys: ``zoom_level``; ``latitude`` and ``longitude`` (the
    centre); ``dict_value`` and ``option`` (comma-separated, as in
    ``/get_info``). The square's side is ``calculate_grid_size`` of the zoom
    threshold.

    Returns ``{"code": 200, "msg": ..., "data": videos + points}``. Errors are
    re-raised as HTTP 500.
    """
    try:
        distance_threshold = _distance_threshold_for_zoom(float(body['zoom_level']))
        grid_size = calculate_grid_size(distance_threshold)
        center_latitude = float(body['latitude'])
        center_longitude = float(body['longitude'])
        dict_value = body['dict_value'].split(',')
        option = body['option'].split(',')

        # NOTE(review): /get_info buckets points into cells aligned to
        # multiples of grid_size and reports each cell at its *first* point.
        # This box is instead centred on the supplied point, so it need not
        # match the clicked cell exactly. Confirm against the frontend.
        latitude_min, latitude_max, longitude_min, longitude_max = get_grid_bounds_from_center(
            center_latitude, center_longitude, grid_size
        )
        videos = get_videos(db, dict_value, latitude_min, latitude_max, longitude_min, longitude_max)
        infos = get_points(db, option, latitude_min, latitude_max, longitude_min, longitude_max)
        return {"code": 200, "msg": "操作成功", "data": videos + infos}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e


def calculate_grid_size(distance_threshold):
    """Convert a distance threshold into the grid cell size used for bucketing.

    Returns ``distance_threshold / 6371`` (6371 = Earth radius in km).

    NOTE(review): that quotient is an angle in *radians*, but callers use it
    directly against latitude/longitude in *degrees*. Cells are therefore about
    57x smaller than the threshold suggests. The zoom thresholds may have been
    tuned against this behaviour, so it is left unchanged. Confirm before
    "fixing" it.
    """
    earth_radius = 6371  # km
    return distance_threshold / earth_radius


def get_grid_key(latitude, longitude, grid_size):
    """Return the integer (row, col) cell containing the point for ``grid_size``."""
    return (math.floor(latitude / grid_size), math.floor(longitude / grid_size))


def get_grid_bounds_from_center(center_latitude, center_longitude, grid_size):
    """Return ``(lat_min, lat_max, lon_min, lon_max)`` of a square of side ``grid_size``.

    The square is centred on the given point.
    """
    half_grid_size = grid_size / 2
    return (
        center_latitude - half_grid_size,
        center_latitude + half_grid_size,
        center_longitude - half_grid_size,
        center_longitude + half_grid_size,
    )


def calculate_distance(point1, point2):
    """Return the haversine great-circle distance in km between two points.

    Each point must expose ``latitude`` and ``longitude`` in degrees.
    (Not referenced by the routes in this module.)
    """
    from math import radians, sin, cos, sqrt, atan2

    R = 6371  # Earth radius, km
    lat1, lon1 = radians(point1.latitude), radians(point1.longitude)
    lat2, lon2 = radians(point2.latitude), radians(point2.longitude)
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    return R * c


def group_points(points, distance_threshold):
    """Cluster map markers into square grid cells.

    Each point must expose ``latitude``, ``longitude``, ``dataType``,
    ``infoType`` and ``name``. It must also expose ``gbIndexCode`` when
    ``dataType == 'video'``, or ``id`` otherwise. The rows from
    :func:`get_videos` and :func:`get_points` have this shape.

    Returns one dict per non-empty cell, in first-seen order. ``count`` is the
    number of points in the cell. The remaining keys depend on the cell:

    * One point: ``type='2'`` plus that point's ``id``, ``dataType``,
      ``infoType``, ``name``, ``latitude`` and ``longitude``.
    * Several points at a single location: ``type='1'``,
      ``name='多数据点位'``, and empty ``id``/``dataType``.
    * Points at two or more distinct locations: ``type='3'``,
      ``name='聚合点位'``, and empty ``id``/``dataType``.

    ``latitude``, ``longitude`` and ``infoType`` always stay those of the
    cell's first point.
    """
    grid_size = calculate_grid_size(distance_threshold)
    cells = {}
    # Distinct locations seen per cell. A set gives O(1) membership (the old
    # list was O(n) per point). A (lat, lon) tuple replaces the old
    # str(lat) + str(lon) key, which could make two locations collide.
    locations = defaultdict(set)

    for point in points:
        key = get_grid_key(float(point.latitude), float(point.longitude), grid_size)
        distinct = locations[key]
        distinct.add((str(point.latitude), str(point.longitude)))

        cell = cells.get(key)
        if cell is None:
            cell = cells[key] = {"count": 0}
        cell["count"] += 1

        if cell["count"] == 1:
            # First point in the cell: expose it as a single marker.
            cell["id"] = point.gbIndexCode if point.dataType == 'video' else point.id
            cell["dataType"] = point.dataType
            cell["infoType"] = point.infoType
            cell["name"] = point.name
            cell["type"] = '2'
            cell["latitude"] = float(point.latitude)
            cell["longitude"] = float(point.longitude)
        elif len(distinct) < 3:
            # From 3 distinct locations on, the cell already reads as an
            # aggregate, so nothing more needs updating.
            cell["dataType"] = ''
            cell["id"] = ""
            if len(distinct) < 2:
                cell["name"] = '多数据点位'
                cell["type"] = '1'
            else:
                cell["name"] = '聚合点位'
                cell["type"] = '3'

    return list(cells.values())


def get_videos(db: Session, dict_value, latitude_min, latitude_max, longitude_min, longitude_max):
    """Fetch video markers inside a bounding box, optionally filtered by video type.

    ``dict_value`` is a list of ``video_type`` dict values:

    * If any value resolves to the label '全量视频' (all videos), no type
      filter is applied.
    * Otherwise, only videos whose ``gbIndexCode`` appears in the tag lists
      of the resolvable values are kept. If no value resolves, nothing is
      kept.

    Videos with non-positive latitude/longitude are skipped. Rows are ordered
    by ``status`` ascending. Columns are ``gbIndexCode``, ``latitude``,
    ``longitude``, ``name``, ``status``, plus ``dataType`` and ``infoType``,
    both the constant ``'video'``.
    """
    type_filter = True  # inside and_() this imposes no restriction
    if len(dict_value) > 0:
        video_codes = []
        for value in dict_value:
            tag_info = get_dict_data_info(db, 'video_type', value)
            if not tag_info:
                continue
            if tag_info.dict_label == '全量视频':
                break  # "all videos" requested: keep type_filter = True
            video_codes += [i.video_code for i in tag_get_video_tag_list(db, value)]
        else:
            # for-else: runs only when the loop finished without `break`.
            type_filter = TPVideoInfo.gbIndexCode.in_(video_codes)

    query = (
        select(
            TPVideoInfo.gbIndexCode,
            TPVideoInfo.latitude,
            TPVideoInfo.longitude,
            TPVideoInfo.name,
            TPVideoInfo.status,
            literal('video').label("dataType"),
            literal('video').label("infoType"),
        )
        .select_from(TPVideoInfo)
        .where(
            and_(
                TPVideoInfo.latitude >= latitude_min,
                TPVideoInfo.latitude <= latitude_max,
                TPVideoInfo.longitude >= longitude_min,
                TPVideoInfo.longitude <= longitude_max,
                TPVideoInfo.longitude > 0,
                TPVideoInfo.latitude > 0,
                type_filter,
            )
        )
        .order_by(TPVideoInfo.status.asc())
    )
    return db.execute(query).fetchall()


def get_points(db: Session, option, latitude_min, latitude_max, longitude_min, longitude_max):
    """Fetch ``point_data`` markers inside a bounding box whose ``dataType`` is in ``option``.

    ``option`` is a list/tuple of ``dataType`` values; a single string is also
    accepted. Rows are de-duplicated on ``(longitude, latitude, name)``.
    Rows with non-positive longitude are skipped. Columns are ``name``,
    ``id``, ``dataType``, ``longitude``, ``latitude``, ``infoType``.
    """
    options = [option] if isinstance(option, str) else list(option)
    # All values are bound. The expanding bindparam renders IN (...) itself
    # instead of relying on the DB driver to expand a tuple, and it copes
    # with an empty list.
    query = text("""
        SELECT A.`name`, A.`id`, A.dataType, A.longitude, A.latitude, A.infoType
        FROM (
            SELECT *,
                   ROW_NUMBER() OVER (PARTITION BY longitude, latitude, `name`
                                      ORDER BY longitude, latitude, `name`) AS rn
            FROM `point_data`
            WHERE longitude > 0
              AND latitude BETWEEN :latitude_min AND :latitude_max
              AND longitude BETWEEN :longitude_min AND :longitude_max
              AND dataType IN :option
        ) AS A
        WHERE rn = 1
    """).bindparams(bindparam("option", expanding=True))

    result = db.execute(query, {
        'latitude_min': latitude_min,
        'latitude_max': latitude_max,
        'longitude_min': longitude_min,
        'longitude_max': longitude_max,
        'option': options,
    })
    return result.fetchall()


@router.post("/get_geojson")
async def get_geojson(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db),
):
    """Return boundary rows whose geometry intersects a bounding box.

    ``body`` keys: ``latitude_min``, ``latitude_max``, ``longitude_min``,
    ``longitude_max`` and ``option``. An ``option`` of ``'cj'`` selects
    ``tp_geojson_data_cj_sq``; anything else selects ``tp_geojson_data_zj``.

    Returns ``{"code": 200, "msg": ..., "type": "FeatureCollection",
    "features": rows}``. Each row has ``id``, ``name``, ``pac``, ``geojson``
    (ST_AsGeoJSON text) and ``properties``. Out-of-range coordinates get a
    JSON error body with status 500. These are client errors, but the status
    is kept for existing clients. Other errors are re-raised as HTTP 500.
    """
    try:
        latitude_min = float(body['latitude_min'])
        latitude_max = float(body['latitude_max'])
        longitude_min = float(body['longitude_min'])
        longitude_max = float(body['longitude_max'])
        if latitude_min < -90 or latitude_max > 90:
            return JSONResponse(status_code=500, content={"code": 500, "msg": "Latitude must be within [-90.000000, 90.000000]"})
        if longitude_min < -180 or longitude_max > 180:
            return JSONResponse(status_code=500, content={"code": 500, "msg": "Longitude must be within [-180.000000, 180.000000]"})

        table_name = 'tp_geojson_data_cj_sq' if body['option'] == 'cj' else 'tp_geojson_data_zj'

        # Vertices are written "latitude longitude". NOTE(review): this
        # presumably relies on MySQL 8 reading SRID 4326 WKT latitude-first
        # (its default axis order). Confirm before swapping the order.
        wkt = (
            f"POLYGON(({latitude_min} {longitude_min}, "
            f"{latitude_max} {longitude_min}, "
            f"{latitude_max} {longitude_max}, "
            f"{latitude_min} {longitude_max}, "
            f"{latitude_min} {longitude_min}))"
        )
        # Only table_name (one of two constants above) is formatted into the
        # SQL; the polygon goes in as a bound parameter. text() is required
        # for textual SQL on SQLAlchemy 2.0 (deprecated without it in 1.4).
        sql = text(
            f"SELECT id, name, pac, ST_AsGeoJSON(geometry) AS geojson, properties "
            f"FROM {table_name} "
            f"WHERE ST_Intersects(geometry, ST_GeomFromText(:wkt, 4326));"
        )
        features = db.execute(sql, {"wkt": wkt}).fetchall()
        return {"code": 200, "msg": "操作成功", "type": "FeatureCollection", "features": features}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e


@router.post("/get_geojson_new")
async def get_geojson(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db),
):
    """Stream the child boundaries of an area as a GeoJSON FeatureCollection.

    ``body`` keys: ``area_code`` and ``option``.

    * ``option == 'cj'``: reads ``tp_geojson_data_cj_sq`` with the first 9
      characters of ``area_code`` as ``parent_pac``.
    * Any other ``option``: reads ``tp_geojson_data_zj`` with the first 6
      characters.

    Each streamed feature carries its geometry plus ``PAC`` and ``NAME``
    taken from the row's JSON ``properties``.

    NOTE(review): this re-binds the module-level name ``get_geojson`` and
    shadows the ``/get_geojson`` handler above. Both routes still work,
    because FastAPI registers them at decoration time. Renaming would change
    the OpenAPI operationId.
    """
    try:
        pac = body['area_code']
        table_name = 'tp_geojson_data_zj'
        if body['option'] == 'cj':
            table_name = 'tp_geojson_data_cj_sq'
            pac = pac[:9]
        else:
            pac = pac[:6]

        # area_code comes from the request body, so it must be bound, not
        # formatted into the SQL (the old f-string allowed SQL injection).
        sql = text(
            f"SELECT ST_AsGeoJSON(geometry) AS geometry, properties "
            f"FROM {table_name} WHERE parent_pac = :pac;"
        )

        def gen():
            # NOTE(review): this runs while the response streams, after the
            # surrounding try/except has returned. Errors raised here truncate
            # the body instead of producing an HTTP 500.
            yield '{"type":"FeatureCollection","features":['
            first = True
            # Iterate the result directly instead of calling fetchall().
            for geom, prop_json in db.execute(sql, {"pac": pac}):
                if not first:
                    yield ","
                props = json.loads(prop_json)
                feature = {
                    "type": "Feature",  # required on every feature by RFC 7946
                    "geometry": json.loads(geom),
                    "properties": {
                        "PAC": props['PAC'],
                        "NAME": props['NAME'],
                    },
                }
                yield json.dumps(feature, ensure_ascii=False)
                first = False
            yield "]}"

        return StreamingResponse(
            gen(),
            media_type="application/json",
            headers={"Content-Disposition": "attachment; filename=data.geojson"},
        )
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e


# ---------------------------------------------------------------------------
# Disabled experimental endpoints (kept for reference).
# NOTE(review): the get_map_img draft writes polygon vertices longitude-first,
# whereas /get_geojson writes them latitude-first. The get_tile draft passes
# EPSG:4326 coordinates to mapbox_vector_tile.encode, which likely expects
# tile-space coordinates. Confirm both before re-enabling.
# ---------------------------------------------------------------------------
# @router.post("/get_map_img")
# async def get_map_img(
#     body: Dict[str, Any] = Depends(remove_xss_json),
#     db: Session = Depends(get_db),
# ):
#     """
#     Input:
#     {
#         "latitude_min": 27.0,
#         "latitude_max": 30.0,
#         "longitude_min": 118.0,
#         "longitude_max": 121.0,
#         "option": "zj"  // "zj" or "cj"
#     }
#     Returns: PNG image
#     """
#     try:
#         # 1. Extract and validate parameters
#         lat_min = float(body["latitude_min"])
#         lat_max = float(body["latitude_max"])
#         lon_min = float(body["longitude_min"])
#         lon_max = float(body["longitude_max"])
#         option = body.get("option", "zj")
#
#         if not (-90 <= lat_min <= 90 and -90 <= lat_max <= 90):
#             raise ValueError("Latitude must be within [-90, 90]")
#         if not (-180 <= lon_min <= 180 and -180 <= lon_max <= 180):
#             raise ValueError("Longitude must be within [-180, 180]")
#
#         table_name = "tp_geojson_data_zj" if option != "cj" else "tp_geojson_data_cj_sq"
#
#         # 2. Build the SQL
#         sql = text(
#             f"""
#             SELECT id, name, pac, ST_AsGeoJSON(geometry) AS geojson, properties
#             FROM {table_name}
#             WHERE ST_Intersects(
#                 geometry,
#                 ST_GeomFromText(
#                     'POLYGON(({lon_min} {lat_min},
#                               {lon_max} {lat_min},
#                               {lon_max} {lat_max},
#                               {lon_min} {lat_max},
#                               {lon_min} {lat_min}))',
#                     4326
#                 )
#             );
#             """
#         )
#
#         rows = db.execute(sql).fetchall()
#         if not rows:
#             raise HTTPException(
#                 status_code=status.HTTP_404_NOT_FOUND,
#                 detail="No data within the given bbox."
#             )
#
#         # 3. Assemble the GeoDataFrame
#         features = [
#             {**json.loads(r.geojson), "properties": json.loads(r.properties)}
#             for r in rows
#         ]
#         gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
#
#         # 4. Plot
#         fig, ax = plt.subplots(figsize=(6, 6), dpi=150)
#         gdf.to_crs(epsg=3857).plot(ax=ax, alpha=0.5, edgecolor="black")
#         ctx.add_basemap(ax, source=ctx.providers.Stamen.TonerLite, crs=gdf.to_crs(epsg=3857).crs)
#         ax.set_axis_off()
#         plt.tight_layout(pad=0)
#
#         # 5. Save to a byte stream
#         buf = io.BytesIO()
#         fig.savefig(buf, format="png")
#         buf.seek(0)
#         plt.close(fig)
#
#         # 6. Return the image
#         return StreamingResponse(buf, media_type="image/png")
#
#     except Exception as e:
#         raise HTTPException(
#             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
#             detail=str(e)
#         )

# TILE_EXTENT = 4096  # standard MVT precision
#
# # ---------- Pure-Python coordinate conversion ----------
# def lonlat2xy(lon: float, lat: float) -> tuple[float, float]:
#     """4326 -> 3857"""
#     x = lon * 20037508.34 / 180
#     y = math.log(math.tan((90 + lat) * math.pi / 360)) / (math.pi / 180)
#     y = y * 20037508.34 / 180
#     return x, y
#
# import mercantile
# from mapbox_vector_tile import encode
# from shapely.geometry import shape
# @router.get("/tile/{option}/{z}/{x}/{y}.pbf")
# async def get_tile(
#     option: str = Path(..., regex="^(zj|cj)$"),
#     z: int = Path(..., ge=0, le=22),
#     x: int = Path(..., ge=0),
#     y: int = Path(..., ge=0),
#     db: Session = Depends(get_db),
# ):
#     """
#     Return an MVT tile for standard Slippy Map XYZ coordinates.
#     Frontend: layer.url = '/tile/zj/{z}/{x}/{y}.pbf'
#     """
#     table = "tp_geojson_data_zj" if option == "zj" else "tp_geojson_data_cj_sq"
#
#     # 1. Tile bbox (EPSG:4326)
#     tile_bounds = mercantile.bounds(mercantile.Tile(x, y, z))
#     xmin, ymin, xmax, ymax = tile_bounds.west, tile_bounds.south, tile_bounds.east, tile_bounds.north
#     print(xmin, ymin, xmax, ymax)
#     # 2. Query intersecting features
#     sql = text(
#         f"""
#         SELECT id, name, pac, properties,
#                ST_AsGeoJSON(geometry) AS geojson
#         FROM {table}
#         WHERE ST_Intersects(
#             geometry,
#             ST_GeomFromText(
#                 'POLYGON((
#                     {ymin} {xmin},
#                     {ymin} {xmax},
#                     {ymax} {xmax},
#                     {ymax} {xmin},
#                     {ymin} {xmin}))',
#                 4326
#             )
#         );
#         """
#     )
#     rows = db.execute(sql).fetchall()
#     if not rows:
#         raise HTTPException(status_code=204)
#
#     # 3. Build MVT features
#     features: List[Dict[str, Any]] = []
#     bounds_3857 = mercantile.xy_bounds(mercantile.Tile(x, y, z))
#     bx, by, bw, bh = bounds_3857.left, bounds_3857.bottom, \
#         bounds_3857.right - bounds_3857.left, \
#         bounds_3857.top - bounds_3857.bottom
#
#     for r in rows:
#         # Use the raw GeoJSON as-is, no coordinate transform
#         geo = json.loads(r.geojson)
#         features.append({
#             "geometry": geo,  # must be EPSG:4326 coordinates
#             "properties": json.loads(r.properties),
#         })
#
#     # 4. Encode the MVT
#     mvt_bytes = encode([
#         {
#             "name": "layer",
#             "features": features,
#             "extent": TILE_EXTENT,
#         }
#     ])
#
#     return Response(content=mvt_bytes, media_type="application/x-protobuf")