point.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from fastapi import APIRouter, Request, Depends, Query, HTTPException, status,Path
  4. from common.security import valid_access_token
  5. from fastapi.responses import JSONResponse,Response
  6. from fastapi.responses import StreamingResponse
  7. from sqlalchemy.orm import Session
  8. from sqlalchemy import and_, or_,text,literal
  9. from sqlalchemy.sql import func
  10. from sqlalchemy.future import select
  11. from common.auth_user import *
  12. from pydantic import BaseModel
  13. from typing import Any, Dict
  14. # import contextily as ctx
  15. # import geopandas as gpd
  16. # from matplotlib import pyplot as plt
  17. import io
  18. from database import get_db
  19. from typing import List
  20. from models import *
  21. from utils import *
  22. from utils.ry_system_util import *
  23. from utils.video_util import *
  24. from collections import defaultdict
  25. import traceback
  26. from concurrent.futures import ThreadPoolExecutor, as_completed
  27. from multiprocessing import Pool, cpu_count
  28. import json
  29. import time
  30. import math
# APIRouter that collects the map endpoints declared in this module
# (/get_info, /get_details, /get_geojson, /get_geojson_new).
router = APIRouter()
  32. @router.post("/get_info")
  33. @router.get("/get_info")
  34. async def get_infos(
  35. body = Depends(remove_xss_json),
  36. # zoom_level: float = Query(..., description="Zoom level for clustering"),
  37. # latitude_min: float = Query(..., description="Minimum latitude"),
  38. # latitude_max: float = Query(..., description="Maximum latitude"),
  39. # longitude_min: float = Query(..., description="Minimum longitude"),
  40. # longitude_max: float = Query(..., description="Maximum longitude"),
  41. # dict_value: str = Query(None),
  42. # option:str = Query(None),
  43. db: Session = Depends(get_db)
  44. ):
  45. try:
  46. # 根据缩放级别动态调整分组粒度
  47. zoom_level = float(body['zoom_level'])
  48. zoom_levels = {
  49. 3: 10000, # 全国范围
  50. 4: 5000,
  51. 5: 2500,
  52. 6: 1250,
  53. 7: 825,
  54. 8: 412.5,
  55. 9: 256.25,
  56. 10: 178.125,
  57. 11: 69.0625,
  58. 12: 29.53125,
  59. 13: 13.765625,
  60. 14: 5.8828125,
  61. 15: 2.44140625,
  62. 16: 1.220703125,
  63. 17: 0.6103515625,
  64. 18: 0.30517578125
  65. }
  66. distance_threshold=zoom_levels[int(zoom_level-1)]
  67. # distance_threshold = 100000 / (2.2 ** zoom_level) # 例如:每缩放一级,距离阈值减半
  68. dict_value= body['dict_value'].split(',')
  69. latitude_min = float(body['latitude_min'])
  70. latitude_max = float(body['latitude_max'])
  71. longitude_min = float(body['longitude_min'])
  72. longitude_max = float(body['longitude_max'])
  73. option = body['option'].split(',')
  74. print("1",time.time())
  75. videos = get_videos(db,dict_value,latitude_min,latitude_max,longitude_min,longitude_max)
  76. infos = get_points(db,option,latitude_min,latitude_max,longitude_min,longitude_max)
  77. # 动态分组逻辑
  78. groups = group_points(videos+infos, distance_threshold)
  79. print("4",time.time())
  80. return {"code": 200,
  81. "msg": "操作成功",
  82. "data": groups}
  83. except Exception as e:
  84. # 处理异常
  85. traceback.print_exc()
  86. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
  87. @router.post("/get_details")
  88. @router.get("/get_details")
  89. async def get_details(
  90. body = Depends(remove_xss_json),
  91. # center_latitude: float = Query(..., description="网格中心点的纬度"),
  92. # center_longitude: float = Query(..., description="网格中心点的经度"),
  93. # zoom_level: float = Query(..., description="缩放级别"),
  94. db: Session = Depends(get_db)
  95. ):
  96. try:
  97. # 计算网格大小
  98. zoom_level = float(body['zoom_level'])
  99. zoom_levels = {
  100. 3: 10000, # 全国范围
  101. 4: 5000,
  102. 5: 2500,
  103. 6: 1250,
  104. 7: 825,
  105. 8: 412.5,
  106. 9: 256.25,
  107. 10: 178.125,
  108. 11: 69.0625,
  109. 12: 29.53125,
  110. 13: 13.765625,
  111. 14: 5.8828125,
  112. 15: 2.44140625,
  113. 16: 1.220703125,
  114. 17: 0.6103515625,
  115. 18: 0.30517578125
  116. }
  117. distance_threshold=zoom_levels[int(zoom_level-1)]
  118. # distance_threshold = 1000 / (1.5 ** zoom_level) # 例如:每缩放一级,距离阈值减半
  119. grid_size = calculate_grid_size(distance_threshold) # 地球半径为6371公里
  120. center_latitude = float(body['latitude'])
  121. center_longitude = float(body['longitude'])
  122. dict_value = body['dict_value'].split(',')
  123. option = body['option'].split(',')
  124. # 计算网格的经纬度范围
  125. latitude_min, latitude_max, longitude_min, longitude_max = get_grid_bounds_from_center(center_latitude, center_longitude, grid_size)
  126. videos = get_videos(db,dict_value,latitude_min,latitude_max,longitude_min,longitude_max)
  127. infos = get_points(db,option,latitude_min,latitude_max,longitude_min,longitude_max)
  128. return {"code": 200,
  129. "msg": "操作成功",
  130. "data": videos+infos }#{"videos":videos,"points":infos}}
  131. except Exception as e:
  132. # 处理异常
  133. traceback.print_exc()
  134. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
  135. def calculate_grid_size(distance_threshold):
  136. # 假设地球半径为6371公里,将距离阈值转换为经纬度的差值
  137. # 这里假设纬度变化对距离的影响较小,仅根据经度计算网格大小
  138. earth_radius = 6371 # 地球半径,单位为公里
  139. grid_size = distance_threshold / earth_radius
  140. return grid_size
  141. def get_grid_key(latitude, longitude, grid_size):
  142. # 根据经纬度和网格大小计算网格键
  143. return (math.floor(latitude / grid_size), math.floor(longitude / grid_size))
  144. def get_grid_bounds_from_center(center_latitude, center_longitude, grid_size):
  145. half_grid_size = grid_size / 2
  146. min_latitude = center_latitude - half_grid_size
  147. max_latitude = center_latitude + half_grid_size
  148. min_longitude = center_longitude - half_grid_size
  149. max_longitude = center_longitude + half_grid_size
  150. return min_latitude, max_latitude, min_longitude, max_longitude
  151. def calculate_distance(point1, point2):
  152. # 使用 Haversine 公式计算两点之间的距离
  153. from math import radians, sin, cos, sqrt, atan2
  154. R = 6371 # 地球半径(公里)
  155. lat1, lon1 = radians(point1.latitude), radians(point1.longitude)
  156. lat2, lon2 = radians(point2.latitude), radians(point2.longitude)
  157. dlat = lat2 - lat1
  158. dlon = lon2 - lon1
  159. a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
  160. c = 2 * atan2(sqrt(a), sqrt(1 - a))
  161. return R * c
  162. def group_points(points, distance_threshold):
  163. grid_size = calculate_grid_size(distance_threshold)
  164. grid = defaultdict(lambda:{"count":0}) #,"list":[]
  165. groups = []
  166. tmp = defaultdict(list)
  167. for point in points:
  168. grid_key = get_grid_key(float(point.latitude), float(point.longitude), grid_size)
  169. lovalue = str(point.latitude)+str(point.longitude)
  170. if lovalue not in tmp['%s-%s'%grid_key]:
  171. tmp['%s-%s'%grid_key].append(lovalue)
  172. grid['%s-%s'%grid_key]['count']+=1
  173. if grid['%s-%s'%grid_key]['count']>1 and len(tmp['%s-%s'%grid_key])<3:
  174. grid['%s-%s'%grid_key]['dataType'] = ''
  175. grid['%s-%s'%grid_key]['id'] = ""
  176. if len(tmp['%s-%s'%grid_key])<2:
  177. grid['%s-%s'%grid_key]['name'] = '多数据点位'
  178. grid['%s-%s' % grid_key]['type'] ='1'
  179. else:
  180. grid['%s-%s' % grid_key]['name'] = '聚合点位'
  181. grid['%s-%s' % grid_key]['type'] = '3'
  182. # grid['%s-%s'%grid_key]['latitude'] = float(point.latitude) #(grid_key[0] + 0.5) * grid_size
  183. # grid['%s-%s'%grid_key]['longitude'] = float(point.longitude) #(grid_key[1] + 0.5) * grid_size
  184. elif grid['%s-%s'%grid_key]['count']==1:
  185. if point.dataType=='video':
  186. grid['%s-%s' % grid_key]['id'] = point.gbIndexCode
  187. else:
  188. grid['%s-%s' % grid_key]['id'] = point.id
  189. grid['%s-%s'%grid_key]['dataType'] = point.dataType
  190. grid['%s-%s'%grid_key]['infoType'] = point.infoType
  191. grid['%s-%s'%grid_key]['name'] = point.name
  192. grid['%s-%s'%grid_key]['type'] ='2'
  193. grid['%s-%s'%grid_key]['latitude'] = float(point.latitude)
  194. grid['%s-%s'%grid_key]['longitude'] = float(point.longitude)
  195. groups = list(grid.values())
  196. return groups
  197. def get_videos(db:Session,dict_value,latitude_min,latitude_max,longitude_min,longitude_max):
  198. que = True
  199. if len(dict_value)>0:
  200. videolist = []
  201. for value in dict_value:
  202. tag_info = get_dict_data_info(db, 'video_type', value)
  203. if tag_info:
  204. if tag_info.dict_label == '全量视频':
  205. break
  206. else:
  207. videolist += [i.video_code for i in tag_get_video_tag_list(db, value)]
  208. else:
  209. que = TPVideoInfo.gbIndexCode.in_(videolist)
  210. # 查询分组
  211. query = (
  212. select(
  213. TPVideoInfo.gbIndexCode,
  214. TPVideoInfo.latitude,
  215. TPVideoInfo.longitude,
  216. TPVideoInfo.name,
  217. TPVideoInfo.status,
  218. literal('video').label("dataType"),
  219. literal('video').label("infoType")
  220. )
  221. .select_from(TPVideoInfo).where(
  222. and_(
  223. TPVideoInfo.latitude >= latitude_min,
  224. TPVideoInfo.latitude <= latitude_max,
  225. TPVideoInfo.longitude >= longitude_min,
  226. TPVideoInfo.longitude <= longitude_max,
  227. TPVideoInfo.longitude > 0,
  228. TPVideoInfo.latitude > 0, que
  229. )
  230. )
  231. .order_by(TPVideoInfo.status.asc())
  232. )
  233. result = db.execute(query)
  234. videos = result.fetchall()
  235. return videos
  236. def get_points(db:Session,option,latitude_min,latitude_max,longitude_min,longitude_max):
  237. # 使用参数化查询避免 SQL 注入
  238. if isinstance(option, list):
  239. option = tuple(option)
  240. query = text("""
  241. SELECT
  242. A.`name`,A.`id`,A.dataType,A.longitude,A.latitude,A.infoType
  243. FROM (
  244. SELECT
  245. *,
  246. ROW_NUMBER() OVER (PARTITION BY longitude, latitude, `name`
  247. ORDER BY longitude, latitude, `name`) AS rn
  248. FROM
  249. `point_data`
  250. WHERE
  251. longitude > 0
  252. AND latitude BETWEEN :latitude_min AND :latitude_max
  253. AND longitude BETWEEN :longitude_min AND :longitude_max
  254. AND dataType IN :option
  255. ) AS A
  256. WHERE rn = 1
  257. """)
  258. # 执行查询并传递参数
  259. result = db.execute(query, {
  260. 'latitude_min': latitude_min,
  261. 'latitude_max': latitude_max,
  262. 'longitude_min': longitude_min,
  263. 'longitude_max': longitude_max,
  264. 'option': option
  265. })
  266. infos = result.fetchall()
  267. return infos
  268. @router.post("/get_geojson")
  269. async def get_geojson(
  270. body = Depends(remove_xss_json),
  271. db: Session = Depends(get_db)
  272. ):
  273. try:
  274. # 根据缩放级别动态调整分组粒度
  275. latitude_min = float(body['latitude_min'])
  276. latitude_max = float(body['latitude_max'])
  277. longitude_min = float(body['longitude_min'])
  278. longitude_max = float(body['longitude_max'])
  279. if latitude_min<-90 or latitude_max>90 :
  280. return JSONResponse(status_code=500, content={"code": 500, "msg": "Latitude must be within [-90.000000, 90.000000]"})
  281. if longitude_min<-180 or longitude_max>180 :
  282. return JSONResponse(status_code=500, content={"code": 500, "msg": "Longitude must be within [-180.000000, 180.000000]"})
  283. table_name = 'tp_geojson_data_zj'
  284. option = body['option']
  285. if 'cj' == option:
  286. table_name = 'tp_geojson_data_cj_sq'
  287. sql = f"""SELECT id,
  288. name,
  289. pac,
  290. ST_AsGeoJSON(geometry) AS geojson,
  291. properties
  292. FROM {table_name}
  293. WHERE ST_Intersects(
  294. geometry,
  295. ST_GeomFromText(
  296. 'POLYGON(({latitude_min} {longitude_min},
  297. {latitude_max} {longitude_min},
  298. {latitude_max} {longitude_max},
  299. {latitude_min} {longitude_max},
  300. {latitude_min} {longitude_min}))',
  301. 4326
  302. )
  303. );"""
  304. result = db.execute(sql)
  305. features = result.fetchall()
  306. return {"code": 200,
  307. "msg": "操作成功","type":"FeatureCollection",
  308. "features": features}
  309. except Exception as e:
  310. # 处理异常
  311. traceback.print_exc()
  312. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
  313. @router.post("/get_geojson_new")
  314. async def get_geojson(
  315. body = Depends(remove_xss_json),
  316. db: Session = Depends(get_db)
  317. ):
  318. try:
  319. # 根据缩放级别动态调整分组粒度
  320. pac = body['area_code']
  321. if pac[-3:]=='000':
  322. pac=pac.replace('000','')
  323. table_name = 'tp_geojson_data_zj'
  324. option = body['option']
  325. if 'cj' == option:
  326. # print(1111)
  327. table_name = 'tp_geojson_data_cj_sq'
  328. sql = f"""SELECT name,
  329. pac,
  330. ST_AsGeoJSON(geometry) AS geometry,
  331. properties
  332. FROM {table_name}
  333. WHERE pac like '{pac}%';"""
  334. result = db.execute(sql)
  335. features = [
  336. {**dict(r), "geometry": json.loads(r.geometry)}
  337. for r in result.fetchall()
  338. ]
  339. # features = result.fetchall()
  340. # for info in features:
  341. # info['geometry']= {**dict(info), "geometry": json.loads(info.geometry)}
  342. # pass
  343. return {"code": 200,
  344. "msg": "操作成功","type":"FeatureCollection",
  345. "features": features}
  346. except Exception as e:
  347. # 处理异常
  348. traceback.print_exc()
  349. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
  350. # @router.post("/get_map_img")
  351. # async def get_map_img(
  352. # body: Dict[str, Any] = Depends(remove_xss_json),
  353. # db: Session = Depends(get_db),
  354. # ):
  355. # """
  356. # 输入:
  357. # {
  358. # "latitude_min": 27.0,
  359. # "latitude_max": 30.0,
  360. # "longitude_min": 118.0,
  361. # "longitude_max": 121.0,
  362. # "option": "zj" // "zj" 或 "cj"
  363. # }
  364. # 返回:PNG 图片
  365. # """
  366. # try:
  367. # # 1. 参数提取与合法性校验
  368. # lat_min = float(body["latitude_min"])
  369. # lat_max = float(body["latitude_max"])
  370. # lon_min = float(body["longitude_min"])
  371. # lon_max = float(body["longitude_max"])
  372. # option = body.get("option", "zj")
  373. #
  374. # if not (-90 <= lat_min <= 90 and -90 <= lat_max <= 90):
  375. # raise ValueError("Latitude must be within [-90, 90]")
  376. # if not (-180 <= lon_min <= 180 and -180 <= lon_max <= 180):
  377. # raise ValueError("Longitude must be within [-180, 180]")
  378. #
  379. # table_name = "tp_geojson_data_zj" if option != "cj" else "tp_geojson_data_cj_sq"
  380. #
  381. # # 2. 构造 SQL
  382. # sql = text(
  383. # f"""
  384. # SELECT id, name, pac, ST_AsGeoJSON(geometry) AS geojson, properties
  385. # FROM {table_name}
  386. # WHERE ST_Intersects(
  387. # geometry,
  388. # ST_GeomFromText(
  389. # 'POLYGON(({lon_min} {lat_min},
  390. # {lon_max} {lat_min},
  391. # {lon_max} {lat_max},
  392. # {lon_min} {lat_max},
  393. # {lon_min} {lat_min}))',
  394. # 4326
  395. # )
  396. # );
  397. # """
  398. # )
  399. #
  400. # rows = db.execute(sql).fetchall()
  401. # if not rows:
  402. # raise HTTPException(
  403. # status_code=status.HTTP_404_NOT_FOUND,
  404. # detail="No data within the given bbox."
  405. # )
  406. #
  407. # # 3. 组装 GeoDataFrame
  408. # features = [
  409. # {**json.loads(r.geojson), "properties": json.loads(r.properties)}
  410. # for r in rows
  411. # ]
  412. # gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326")
  413. #
  414. # # 4. 绘图
  415. # fig, ax = plt.subplots(figsize=(6, 6), dpi=150)
  416. # gdf.to_crs(epsg=3857).plot(ax=ax, alpha=0.5, edgecolor="black")
  417. # ctx.add_basemap(ax, source=ctx.providers.Stamen.TonerLite, crs=gdf.to_crs(epsg=3857).crs)
  418. # ax.set_axis_off()
  419. # plt.tight_layout(pad=0)
  420. #
  421. # # 5. 保存成字节流
  422. # buf = io.BytesIO()
  423. # fig.savefig(buf, format="png")
  424. # buf.seek(0)
  425. # plt.close(fig)
  426. #
  427. # # 6. 返回图片
  428. # return StreamingResponse(buf, media_type="image/png")
  429. #
  430. # except Exception as e:
  431. # raise HTTPException(
  432. # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  433. # detail=str(e)
  434. # )
  435. # TILE_EXTENT = 4096 # 标准 MVT 精度
  436. # # ---------- 纯 Python 坐标转换 ----------
  437. # def lonlat2xy(lon: float, lat: float) -> tuple[float, float]:
  438. # """4326 -> 3857"""
  439. # x = lon * 20037508.34 / 180
  440. # y = math.log(math.tan((90 + lat) * math.pi / 360)) / (math.pi / 180)
  441. # y = y * 20037508.34 / 180
  442. # return x, y
  443. #
  444. # import mercantile
  445. # from mapbox_vector_tile import encode
  446. # from shapely.geometry import shape
  447. # @router.get("/tile/{option}/{z}/{x}/{y}.pbf")
  448. # async def get_tile(
  449. # option: str = Path(..., regex="^(zj|cj)$"),
  450. # z: int = Path(..., ge=0, le=22),
  451. # x: int = Path(..., ge=0),
  452. # y: int = Path(..., ge=0),
  453. # db: Session = Depends(get_db),
  454. # ):
  455. # """
  456. # 根据 Slippy Map 标准 XYZ 返回 MVT 二进制。
  457. # 前端 layer.url = '/tile/zj/{z}/{x}/{y}.pbf'
  458. # """
  459. # table = "tp_geojson_data_zj" if option == "zj" else "tp_geojson_data_cj_sq"
  460. #
  461. # # 1. 计算瓦片 bbox (4326)
  462. # tile_bounds = mercantile.bounds(mercantile.Tile(x, y, z))
  463. # xmin, ymin, xmax, ymax = tile_bounds.west, tile_bounds.south, tile_bounds.east, tile_bounds.north
  464. # print(xmin, ymin, xmax, ymax)
  465. # # 2. 查相交要素
  466. # sql = text(
  467. # f"""
  468. # SELECT id, name, pac, properties,
  469. # ST_AsGeoJSON(geometry) AS geojson
  470. # FROM {table}
  471. # WHERE ST_Intersects(
  472. # geometry,
  473. # ST_GeomFromText(
  474. # 'POLYGON((
  475. # {ymin} {xmin},
  476. # {ymin} {xmax},
  477. # {ymax} {xmax},
  478. # {ymax} {xmin},
  479. # {ymin} {xmin}))',
  480. # 4326
  481. # )
  482. # );
  483. # """
  484. # )
  485. # rows = db.execute(sql).fetchall()
  486. # if not rows:
  487. # raise HTTPException(status_code=204)
  488. #
  489. # # 3. 构造 MVT features
  490. # features: List[Dict[str, Any]] = []
  491. # bounds_3857 = mercantile.xy_bounds(mercantile.Tile(x, y, z))
  492. # bx, by, bw, bh = bounds_3857.left, bounds_3857.bottom, \
  493. # bounds_3857.right - bounds_3857.left, \
  494. # bounds_3857.top - bounds_3857.bottom
  495. #
  496. # for r in rows:
  497. # # 直接用原始 GeoJSON,不做任何坐标变换
  498. # geo = json.loads(r.geojson)
  499. # features.append({
  500. # "geometry": geo, # 必须是 4326 坐标
  501. # "properties": json.loads(r.properties),
  502. # })
  503. #
  504. # # 4. 生成 MVT
  505. # mvt_bytes = encode([
  506. # {
  507. # "name": "layer",
  508. # "features": features,
  509. # "extent": TILE_EXTENT,
  510. # }
  511. # ])
  512. #
  513. # return Response(content=mvt_bytes, media_type="application/x-protobuf")