123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- from fastapi import APIRouter, Request, Depends, Query, HTTPException, status,Path
- from common.security import valid_access_token
- from fastapi.responses import JSONResponse,Response
- from fastapi.responses import StreamingResponse
- from sqlalchemy.orm import Session
- from sqlalchemy import and_, or_,text,literal
- from sqlalchemy.sql import func
- from sqlalchemy.future import select
- from common.auth_user import *
- from pydantic import BaseModel
- from typing import Any, Dict
- # import contextily as ctx
- # import geopandas as gpd
- # from matplotlib import pyplot as plt
- import io
- from database import get_db
- from typing import List
- from models import *
- from utils import *
- from utils.ry_system_util import *
- from utils.video_util import *
- from collections import defaultdict
- import traceback
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from multiprocessing import Pool, cpu_count
- import json
- import time
- import math
router = APIRouter()


@router.post("/get_info")
@router.get("/get_info")
async def get_infos(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db)
):
    """Return clustered map markers (videos and data points) inside a bounding box.

    ``body`` (sanitised by ``remove_xss_json``) must provide:

    * ``zoom_level`` - current map zoom, numeric or numeric string;
    * ``latitude_min`` / ``latitude_max`` / ``longitude_min`` / ``longitude_max``;
    * ``dict_value`` - comma-separated video-type values passed to ``get_videos``;
    * ``option`` - comma-separated ``dataType`` values passed to ``get_points``.

    Returns ``{"code": 200, "msg": ..., "data": [...]}`` where ``data`` is the
    list of cell summaries built by ``group_points``. Any error (missing key,
    non-numeric value, database failure) is printed and re-raised as HTTP 500
    carrying the error text.

    NOTE(review): this is an ``async def`` that runs synchronous Session queries,
    so the event loop is blocked while they execute.
    """
    try:
        # Clustering distance (input of calculate_grid_size) per zoom-table key.
        zoom_levels = {
            3: 10000,  # whole-country view
            4: 5000,
            5: 2500,
            6: 1250,
            7: 825,
            8: 412.5,
            9: 256.25,
            10: 178.125,
            11: 69.0625,
            12: 29.53125,
            13: 13.765625,
            14: 5.8828125,
            15: 2.44140625,
            16: 1.220703125,
            17: 0.6103515625,
            18: 0.30517578125,
        }
        zoom_level = float(body['zoom_level'])
        # The table is looked up with zoom_level - 1. Clamp to its key range so
        # zoom levels below 4 or from 20 up (frequent with fractional zooms) no
        # longer raise KeyError and surface as an HTTP 500.
        zoom_key = min(max(int(zoom_level - 1), min(zoom_levels)), max(zoom_levels))
        distance_threshold = zoom_levels[zoom_key]

        dict_value = body['dict_value'].split(',')
        latitude_min = float(body['latitude_min'])
        latitude_max = float(body['latitude_max'])
        longitude_min = float(body['longitude_min'])
        longitude_max = float(body['longitude_max'])
        option = body['option'].split(',')

        videos = get_videos(db, dict_value, latitude_min, latitude_max, longitude_min, longitude_max)
        infos = get_points(db, option, latitude_min, latitude_max, longitude_min, longitude_max)
        # Videos and data points are clustered together on the same grid.
        groups = group_points(videos + infos, distance_threshold)
        return {"code": 200,
                "msg": "操作成功",
                "data": groups}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
@router.post("/get_details")
@router.get("/get_details")
async def get_details(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db)
):
    """Return every video and data point in the clustering cell containing a location.

    ``body`` must provide ``zoom_level``, ``latitude`` and ``longitude`` (numeric
    or numeric strings) plus the comma-separated ``dict_value`` / ``option``
    filters, with the same meaning as for ``/get_info``.

    Returns ``{"code": 200, "msg": ..., "data": videos + points}``. Any error is
    printed and re-raised as HTTP 500 carrying the error text.
    """
    try:
        # Must stay identical to the table in get_infos so both endpoints agree
        # on the grid.
        zoom_levels = {
            3: 10000,  # whole-country view
            4: 5000,
            5: 2500,
            6: 1250,
            7: 825,
            8: 412.5,
            9: 256.25,
            10: 178.125,
            11: 69.0625,
            12: 29.53125,
            13: 13.765625,
            14: 5.8828125,
            15: 2.44140625,
            16: 1.220703125,
            17: 0.6103515625,
            18: 0.30517578125,
        }
        zoom_level = float(body['zoom_level'])
        # Clamp to the table's key range instead of raising KeyError (HTTP 500)
        # for zoom levels below 4 or from 20 up.
        zoom_key = min(max(int(zoom_level - 1), min(zoom_levels)), max(zoom_levels))
        distance_threshold = zoom_levels[zoom_key]
        grid_size = calculate_grid_size(distance_threshold)

        center_latitude = float(body['latitude'])
        center_longitude = float(body['longitude'])
        dict_value = body['dict_value'].split(',')
        option = body['option'].split(',')

        # group_points() buckets points into cells aligned on multiples of
        # grid_size and reports each cluster at its first point's coordinates,
        # not at the cell centre. A box centred on that point would straddle
        # neighbouring cells, so resolve the enclosing cell and use its exact
        # bounds. This works for both a cell centre and any point inside the cell.
        lat_cell, lon_cell = get_grid_key(center_latitude, center_longitude, grid_size)
        latitude_min = lat_cell * grid_size
        latitude_max = (lat_cell + 1) * grid_size
        longitude_min = lon_cell * grid_size
        longitude_max = (lon_cell + 1) * grid_size

        videos = get_videos(db, dict_value, latitude_min, latitude_max, longitude_min, longitude_max)
        infos = get_points(db, option, latitude_min, latitude_max, longitude_min, longitude_max)
        return {"code": 200,
                "msg": "操作成功",
                "data": videos + infos}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
def calculate_grid_size(distance_threshold):
    """Convert a clustering distance in kilometres into a grid cell size.

    The distance is divided by the Earth's radius (6371 km), so the result is
    the corresponding central angle in radians.

    NOTE(review): callers (get_grid_key, get_grid_bounds_from_center) apply the
    value directly to latitude/longitude in degrees, so a cell spans roughly
    distance_threshold / 57 km. The zoom tables may have been tuned around this;
    confirm before "fixing" the unit.
    """
    earth_radius_km = 6371
    return distance_threshold / earth_radius_km
def get_grid_key(latitude, longitude, grid_size):
    """Return the ``(row, column)`` index of the grid cell containing a coordinate.

    Cells are half-open squares ``[k * grid_size, (k + 1) * grid_size)`` on
    each axis. Indices are floored, so negative coordinates map to negative
    indices.
    """
    row, column = (math.floor(coordinate / grid_size) for coordinate in (latitude, longitude))
    return row, column
def get_grid_bounds_from_center(center_latitude, center_longitude, grid_size):
    """Return the square of side ``grid_size`` centred on a coordinate.

    The result is ``(min_latitude, max_latitude, min_longitude, max_longitude)``.
    """
    half_side = grid_size / 2
    return (
        center_latitude - half_side,
        center_latitude + half_side,
        center_longitude - half_side,
        center_longitude + half_side,
    )
def calculate_distance(point1, point2):
    """Return the great-circle distance in kilometres between two points.

    Uses the haversine formula on a sphere of radius 6371 km. Each point must
    expose ``latitude`` and ``longitude`` attributes in degrees (any value
    ``math.radians`` accepts).
    """
    earth_radius_km = 6371
    lat1, lon1 = math.radians(point1.latitude), math.radians(point1.longitude)
    lat2, lon2 = math.radians(point2.latitude), math.radians(point2.longitude)
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points; math.sqrt(1 - a) would then raise a domain error.
    a = min(1.0, max(0.0, a))
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return earth_radius_km * c
def group_points(points, distance_threshold):
    """Cluster map points into grid cells and return one summary dict per cell.

    Args:
        points: iterable of row-like objects with ``latitude``, ``longitude``,
            ``name``, ``dataType`` and ``infoType`` attributes. Rows whose
            ``dataType`` is ``'video'`` must also have ``gbIndexCode``; all
            other rows must have ``id``.
        distance_threshold: clustering distance, converted to a cell size by
            ``calculate_grid_size`` and bucketed with ``get_grid_key``.

    Returns:
        A list with one dict per non-empty cell, in order of first appearance,
        keyed ``count``, ``id``, ``dataType``, ``infoType``, ``name``, ``type``,
        ``latitude``, ``longitude``:

        * a single point -> ``type '2'``, carrying that point's id, dataType,
          infoType and name;
        * several points at one location -> ``type '1'`` ('多数据点位');
        * points at two or more distinct locations -> ``type '3'`` ('聚合点位').

        Multi-point cells get empty ``id``/``dataType`` but keep the first
        point's ``infoType``, ``latitude`` and ``longitude``.
    """
    grid_size = calculate_grid_size(distance_threshold)
    cells = {}
    # Distinct locations per cell. A set keeps the membership test O(1); a list
    # made dense cells (e.g. the whole country at low zoom) quadratic.
    cell_locations = defaultdict(set)
    for point in points:
        cell_key = get_grid_key(float(point.latitude), float(point.longitude), grid_size)
        # Keep the two coordinates separate: concatenating their strings could
        # make different locations collide ("1.21" + "3.4" == "1.2" + "13.4").
        locations = cell_locations[cell_key]
        locations.add((str(point.latitude), str(point.longitude)))

        cell = cells.get(cell_key)
        if cell is None:
            # First point in the cell: describe it as a single marker.
            cells[cell_key] = {
                "count": 1,
                "id": point.gbIndexCode if point.dataType == 'video' else point.id,
                "dataType": point.dataType,
                "infoType": point.infoType,
                "name": point.name,
                "type": '2',
                "latitude": float(point.latitude),
                "longitude": float(point.longitude),
            }
            continue

        # Two or more points: the cell becomes an aggregate marker.
        cell["count"] += 1
        cell["dataType"] = ''
        cell["id"] = ""
        if len(locations) == 1:
            # Every point so far shares one location.
            cell["name"] = '多数据点位'
            cell["type"] = '1'
        else:
            cell["name"] = '聚合点位'
            cell["type"] = '3'
    return list(cells.values())
def get_videos(db:Session,dict_value,latitude_min,latitude_max,longitude_min,longitude_max):
    """Fetch videos (``TPVideoInfo`` rows) located inside a bounding box.

    ``dict_value`` holds 'video_type' dictionary values:

    * if it is empty, no video-code filter is applied;
    * if any value resolves to the label '全量视频', no video-code filter is
      applied;
    * otherwise only videos whose ``gbIndexCode`` appears in the tag lists of
      the resolved values are kept. Values that resolve to nothing contribute
      no codes.

    Only rows with positive latitude and longitude are returned, ordered by
    ``status`` ascending. Each row carries the literal 'video' as both
    ``dataType`` and ``infoType``.
    """
    code_filter = True  # no restriction unless the tags below narrow it down
    if len(dict_value) > 0:
        wanted_codes = []
        show_everything = False
        for tag_value in dict_value:
            tag_info = get_dict_data_info(db, 'video_type', tag_value)
            if not tag_info:
                continue
            if tag_info.dict_label == '全量视频':
                show_everything = True
                break
            wanted_codes += [row.video_code for row in tag_get_video_tag_list(db, tag_value)]
        if not show_everything:
            code_filter = TPVideoInfo.gbIndexCode.in_(wanted_codes)

    columns = (
        TPVideoInfo.gbIndexCode,
        TPVideoInfo.latitude,
        TPVideoInfo.longitude,
        TPVideoInfo.name,
        TPVideoInfo.status,
        literal('video').label("dataType"),
        literal('video').label("infoType"),
    )
    conditions = and_(
        TPVideoInfo.latitude >= latitude_min,
        TPVideoInfo.latitude <= latitude_max,
        TPVideoInfo.longitude >= longitude_min,
        TPVideoInfo.longitude <= longitude_max,
        TPVideoInfo.longitude > 0,
        TPVideoInfo.latitude > 0,
        code_filter,
    )
    statement = (
        select(*columns)
        .select_from(TPVideoInfo)
        .where(conditions)
        .order_by(TPVideoInfo.status.asc())
    )
    return db.execute(statement).fetchall()
def get_points(db:Session,option,latitude_min,latitude_max,longitude_min,longitude_max):
    """Fetch de-duplicated ``point_data`` rows inside a bounding box.

    ``option`` lists the ``dataType`` values to include. It may be a list, a
    tuple, or a single string. Rows sharing the same longitude, latitude and
    name are collapsed to one (the one with the smallest ``id``). Only rows
    with positive longitude are returned.

    Returns the fetched rows (``name, id, dataType, longitude, latitude,
    infoType``), or an empty list when ``option`` is empty.
    """
    # Expanding bind parameters let SQLAlchemy render "IN (...)" itself instead
    # of relying on the DB driver to turn a bound tuple into a parenthesised list.
    from sqlalchemy import bindparam

    if isinstance(option, str):
        # A bare string used to be bound as a scalar, producing "IN 'x'" (invalid SQL).
        option = [option]
    option = list(option)
    if not option:
        # No dataType requested: nothing can match, and "IN ()" is invalid SQL.
        return []

    query = text("""
        SELECT
            A.`name`,A.`id`,A.dataType,A.longitude,A.latitude,A.infoType
        FROM (
            SELECT
                *,
                ROW_NUMBER() OVER (PARTITION BY longitude, latitude, `name`
                                   ORDER BY `id`) AS rn
            FROM
                `point_data`
            WHERE
                longitude > 0
                AND latitude BETWEEN :latitude_min AND :latitude_max
                AND longitude BETWEEN :longitude_min AND :longitude_max
                AND dataType IN :option
        ) AS A
        WHERE rn = 1
    """).bindparams(bindparam('option', expanding=True))
    result = db.execute(query, {
        'latitude_min': latitude_min,
        'latitude_max': latitude_max,
        'longitude_min': longitude_min,
        'longitude_max': longitude_max,
        'option': option,
    })
    return result.fetchall()
@router.post("/get_geojson")
async def get_geojson(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db)
):
    """Return the region geometries that intersect a bounding box.

    ``body`` must provide ``latitude_min``, ``latitude_max``, ``longitude_min``
    and ``longitude_max`` (numeric or numeric strings) and ``option``: ``'cj'``
    selects table ``tp_geojson_data_cj_sq``, anything else
    ``tp_geojson_data_zj``.

    Out-of-range (or NaN) coordinates yield a 500 JSONResponse with an
    explanatory ``msg``. On success returns ``{"code": 200, ..., "type":
    "FeatureCollection", "features": rows}``, where rows carry
    ``id, name, pac, geojson, properties``. Other errors are re-raised as
    HTTP 500.
    """
    try:
        latitude_min = float(body['latitude_min'])
        latitude_max = float(body['latitude_max'])
        longitude_min = float(body['longitude_min'])
        longitude_max = float(body['longitude_max'])
        # Chained comparisons check both ends of each range and also reject
        # NaN, which previously slipped through and reached the DB as "nan".
        if not (-90 <= latitude_min <= 90 and -90 <= latitude_max <= 90):
            return JSONResponse(status_code=500, content={"code": 500, "msg": "Latitude must be within [-90.000000, 90.000000]"})
        if not (-180 <= longitude_min <= 180 and -180 <= longitude_max <= 180):
            return JSONResponse(status_code=500, content={"code": 500, "msg": "Longitude must be within [-180.000000, 180.000000]"})

        # Only one of these two fixed names is ever interpolated into the SQL.
        table_name = 'tp_geojson_data_cj_sq' if body['option'] == 'cj' else 'tp_geojson_data_zj'
        # Vertices are written "latitude longitude", matching the original
        # query (presumably MySQL's SRID 4326 axis order - confirm).
        bbox_wkt = (
            f"POLYGON(({latitude_min} {longitude_min}, "
            f"{latitude_max} {longitude_min}, "
            f"{latitude_max} {longitude_max}, "
            f"{latitude_min} {longitude_max}, "
            f"{latitude_min} {longitude_min}))"
        )
        # Wrapped in text() and bound: executing a plain string is deprecated
        # in SQLAlchemy 1.4 and rejected by 2.0.
        sql = text(f"""SELECT id,
                   name,
                   pac,
                   ST_AsGeoJSON(geometry) AS geojson,
                   properties
            FROM {table_name}
            WHERE ST_Intersects(geometry, ST_GeomFromText(:bbox, 4326))""")
        result = db.execute(sql, {"bbox": bbox_wkt})
        features = result.fetchall()
        return {"code": 200,
                "msg": "操作成功", "type": "FeatureCollection",
                "features": features}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
@router.post("/get_geojson_new")
async def get_geojson(
    body = Depends(remove_xss_json),
    db: Session = Depends(get_db)
):
    """Return the region geometries whose ``pac`` code starts with an area-code prefix.

    ``body`` must provide ``area_code`` and ``option``. Trailing "000" groups
    are stripped from the area code to form the prefix (e.g. "330100000000"
    becomes "330100"). ``option == 'cj'`` selects table
    ``tp_geojson_data_cj_sq``; anything else selects ``tp_geojson_data_zj``.

    Returns ``{"code": 200, ..., "type": "FeatureCollection", "features": [...]}``,
    where each feature has ``name``, ``pac``, ``properties`` and ``geometry``
    parsed from ``ST_AsGeoJSON``. Errors are re-raised as HTTP 500.

    NOTE(review): this function reuses the name ``get_geojson`` and rebinds the
    module attribute over the "/get_geojson" handler. Both routes still work,
    since each is registered by its decorator, but a distinct name would be clearer.
    """
    try:
        pac = str(body['area_code'])
        # Strip only *trailing* "000" groups. The previous str.replace('000', '')
        # also removed interior zero runs and produced an unrelated prefix.
        while pac.endswith('000'):
            pac = pac[:-3]
        # Only one of these two fixed names is ever interpolated into the SQL.
        table_name = 'tp_geojson_data_cj_sq' if body['option'] == 'cj' else 'tp_geojson_data_zj'
        # area_code comes from the request: bind it instead of formatting it
        # into the SQL text (previously an SQL-injection vector).
        sql = text(f"""SELECT name,
                   pac,
                   ST_AsGeoJSON(geometry) AS geometry,
                   properties
            FROM {table_name}
            WHERE pac LIKE :pac_pattern""")
        result = db.execute(sql, {"pac_pattern": pac + '%'})
        features = [
            {**dict(row._mapping), "geometry": json.loads(row.geometry)}
            for row in result.fetchall()
        ]
        return {"code": 200,
                "msg": "操作成功", "type": "FeatureCollection",
                "features": features}
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
- # @router.post("/get_map_img")
- # async def get_map_img(
- # body: Dict[str, Any] = Depends(remove_xss_json),
- # db: Session = Depends(get_db),
- # ):
- # """
- # 输入:
- # {
- # "latitude_min": 27.0,
- # "latitude_max": 30.0,
- # "longitude_min": 118.0,
- # "longitude_max": 121.0,
- # "option": "zj" // "zj" 或 "cj"
- # }
- # 返回:PNG 图片
- # """
- # try:
- # # 1. 参数提取与合法性校验
- # lat_min = float(body["latitude_min"])
- # lat_max = float(body["latitude_max"])
- # lon_min = float(body["longitude_min"])
- # lon_max = float(body["longitude_max"])
- # option = body.get("option", "zj")
- #
- # if not (-90 <= lat_min <= 90 and -90 <= lat_max <= 90):
- # raise ValueError("Latitude must be within [-90, 90]")
- # if not (-180 <= lon_min <= 180 and -180 <= lon_max <= 180):
- # raise ValueError("Longitude must be within [-180, 180]")
- #
- # table_name = "tp_geojson_data_zj" if option != "cj" else "tp_geojson_data_cj_sq"
- #
- # # 2. 构造 SQL
- # sql = text(
- # f"""
- # SELECT id, name, pac, ST_AsGeoJSON(geometry) AS geojson, properties
- # FROM {table_name}
- # WHERE ST_Intersects(
- # geometry,
- # ST_GeomFromText(
- # 'POLYGON(({lon_min} {lat_min},
- # {lon_max} {lat_min},
- # {lon_max} {lat_max},
- # {lon_min} {lat_max},
- # {lon_min} {lat_min}))',
- # 4326
- # )
- # );
- # """
- # )
- #
- # rows = db.execute(sql).fetchall()
- # if not rows:
- # raise HTTPException(
- # status_code=status.HTTP_404_NOT_FOUND,
- # detail="No data within the given bbox."
- # )
- #
- # # 3. 组装 GeoDataFrame
- # features = [
- # {**json.loads(r.geojson), "properties": json.loads(r.properties)}
- # for r in rows
- # ]
- # gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
- #
- # # 4. 绘图
- # fig, ax = plt.subplots(figsize=(6, 6), dpi=150)
- # gdf.to_crs(epsg=3857).plot(ax=ax, alpha=0.5, edgecolor="black")
- # ctx.add_basemap(ax, source=ctx.providers.Stamen.TonerLite, crs=gdf.to_crs(epsg=3857).crs)
- # ax.set_axis_off()
- # plt.tight_layout(pad=0)
- #
- # # 5. 保存成字节流
- # buf = io.BytesIO()
- # fig.savefig(buf, format="png")
- # buf.seek(0)
- # plt.close(fig)
- #
- # # 6. 返回图片
- # return StreamingResponse(buf, media_type="image/png")
- #
- # except Exception as e:
- # raise HTTPException(
- # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- # detail=str(e)
- # )
- # TILE_EXTENT = 4096 # 标准 MVT 精度
- # # ---------- 纯 Python 坐标转换 ----------
- # def lonlat2xy(lon: float, lat: float) -> tuple[float, float]:
- # """4326 -> 3857"""
- # x = lon * 20037508.34 / 180
- # y = math.log(math.tan((90 + lat) * math.pi / 360)) / (math.pi / 180)
- # y = y * 20037508.34 / 180
- # return x, y
- #
- # import mercantile
- # from mapbox_vector_tile import encode
- # from shapely.geometry import shape
- # @router.get("/tile/{option}/{z}/{x}/{y}.pbf")
- # async def get_tile(
- # option: str = Path(..., regex="^(zj|cj)$"),
- # z: int = Path(..., ge=0, le=22),
- # x: int = Path(..., ge=0),
- # y: int = Path(..., ge=0),
- # db: Session = Depends(get_db),
- # ):
- # """
- # 根据 Slippy Map 标准 XYZ 返回 MVT 二进制。
- # 前端 layer.url = '/tile/zj/{z}/{x}/{y}.pbf'
- # """
- # table = "tp_geojson_data_zj" if option == "zj" else "tp_geojson_data_cj_sq"
- #
- # # 1. 计算瓦片 bbox (4326)
- # tile_bounds = mercantile.bounds(mercantile.Tile(x, y, z))
- # xmin, ymin, xmax, ymax = tile_bounds.west, tile_bounds.south, tile_bounds.east, tile_bounds.north
- # print(xmin, ymin, xmax, ymax)
- # # 2. 查相交要素
- # sql = text(
- # f"""
- # SELECT id, name, pac, properties,
- # ST_AsGeoJSON(geometry) AS geojson
- # FROM {table}
- # WHERE ST_Intersects(
- # geometry,
- # ST_GeomFromText(
- # 'POLYGON((
- # {ymin} {xmin},
- # {ymin} {xmax},
- # {ymax} {xmax},
- # {ymax} {xmin},
- # {ymin} {xmin}))',
- # 4326
- # )
- # );
- # """
- # )
- # rows = db.execute(sql).fetchall()
- # if not rows:
- # raise HTTPException(status_code=204)
- #
- # # 3. 构造 MVT features
- # features: List[Dict[str, Any]] = []
- # bounds_3857 = mercantile.xy_bounds(mercantile.Tile(x, y, z))
- # bx, by, bw, bh = bounds_3857.left, bounds_3857.bottom, \
- # bounds_3857.right - bounds_3857.left, \
- # bounds_3857.top - bounds_3857.bottom
- #
- # for r in rows:
- # # 直接用原始 GeoJSON,不做任何坐标变换
- # geo = json.loads(r.geojson)
- # features.append({
- # "geometry": geo, # 必须是 4326 坐标
- # "properties": json.loads(r.properties),
- # })
- #
- # # 4. 生成 MVT
- # mvt_bytes = encode([
- # {
- # "name": "layer",
- # "features": features,
- # "extent": TILE_EXTENT,
- # }
- # ])
- #
- # return Response(content=mvt_bytes, media_type="application/x-protobuf")
|