__init__.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from fastapi import APIRouter, Request, Depends,Response,HTTPException,status,Query
  4. from fastapi.responses import StreamingResponse
  5. from fastapi.responses import JSONResponse
  6. from cachetools.keys import hashkey
  7. from sqlalchemy.orm import Session
  8. from typing import List, Optional
  9. from cachetools import LRUCache
  10. from pydantic import BaseModel
  11. from database import get_db
  12. from urllib import parse
  13. from models import *
  14. from utils import *
  15. import requests
  16. import hashlib
  17. import random
  18. import string
  19. import json
  20. import time
  21. import io
# Router that collects the gateway-proxy endpoint and the service-definition
# CRUD endpoints (create / update / delete / list) declared in this module.
router = APIRouter()
  23. def GetSign(signTime,nonce,passtoken):
  24. data = signTime+passtoken+nonce+signTime
  25. return hashlib.sha256(data.encode('utf-8')).hexdigest()
  26. def GetNonce(length):
  27. characters = string.ascii_letters + string.digits
  28. # 随机选择字符集的长度个字符
  29. random_string = ''.join(random.choice(characters) for _ in range(length))
  30. return random_string
  31. def GetTime():
  32. return int(time.time()*1000)
  33. # 设置缓存大小,例如 100 个缓存项
  34. cache_size = 10000
  35. cache = LRUCache(maxsize=cache_size)
  36. # 缓存过期时间,例如 10 分钟
  37. cache.ttl = 86400 # ttl(Time to Live) 以秒为单位
  38. @router.post('/proxyHandler/{city_code:path}/{service_code:path}')
  39. @router.get('/proxyHandler/{city_code:path}/{service_code:path}')
  40. async def mine(request: Request,body = Depends(remove_xss_json),db: Session = Depends(get_db)):
  41. target = str(request.url) # 获取接口地址
  42. cache_key = hashkey(target)
  43. cached_response = cache.get(cache_key)
  44. if cached_response:
  45. if 'image/png' in cached_response.headers.get('Content-Type', ''):
  46. return StreamingResponse(content=io.BytesIO(cached_response.content), media_type='image/png')
  47. return cached_response
  48. # 获取接口id 判断接口是否存在
  49. service_code = request.path_params.get('service_code')
  50. service_info = db.query(OneShareApiEntity).filter(OneShareApiEntity.servercode==service_code).first()
  51. if service_info is None:
  52. return JSONResponse(status_code=410, content={
  53. 'code': 410,
  54. 'msg': f'server_code{service_code}服务不存在'
  55. })
  56. city_code = request.path_params.get('city_code')
  57. if city_code == "mm":
  58. url = 'https://19.155.242.125/GatewayMsg/http/api/proxy/invoke'
  59. elif city_code == "gd":
  60. url = 'https://19.15.75.180:8581/GatewayMsg/http/api/proxy/invoke'
  61. else:
  62. return JSONResponse(status_code=410, content={
  63. 'code': 410,
  64. 'msg': f'city_code{city_code}不存在'
  65. })
  66. # 获取请求方式
  67. method = request.method
  68. # 获取请求体
  69. # body = await request.body()
  70. # body = body.decode(encoding='utf-8')
  71. # if len(body) > 0:
  72. # body = json.loads(body)
  73. # 获取默认params 1
  74. params_default = service_info.params_default
  75. if len(params_default)>0:
  76. print(params_default)
  77. params_default = json.loads(params_default)
  78. # 获取params
  79. query_list = parse.parse_qsl(str(request.query_params))
  80. params = {}
  81. for (key, val) in query_list:
  82. params[key]=val
  83. params_default.update(params)
  84. params = params_default
  85. # 生成请求头主体
  86. signTime = str(GetTime()//1000)
  87. nonce = GetNonce(5)
  88. sign = GetSign(signTime,nonce,service_info.passtoken)
  89. # 初始请求头
  90. headers = {
  91. # 'Content-Type': 'application/json',
  92. 'x-tif-signature': sign,
  93. 'x-tif-timestamp': signTime,
  94. 'x-tif-nonce': nonce,
  95. 'x-tif-paasid': service_info.passid,
  96. 'x-tif-serviceId': service_code
  97. }
  98. # 加入默认请求头
  99. headers_default = service_info.headers_default
  100. if len(headers_default)>0:
  101. headers_default = json.loads(headers_default)
  102. headers.update(headers_default)
  103. # 判断接口类型
  104. # 1 普通接口 请求头请求体用默认,前端传输的请求体嵌入到请求体query中
  105. # 2 地图接口
  106. # 3 自定义接口
  107. if service_info.servertype == 1:
  108. query_timestamp = str(GetTime())
  109. data = {
  110. "system_id": service_info.passid,
  111. "vender_id": 'xx',
  112. "department_id": 'xx',
  113. "query_timestamp": query_timestamp,
  114. "UID": GetNonce(5),
  115. "query": body,
  116. "audit_info": {
  117. "operator_id": 'xx',
  118. "operator_name": 'xx',
  119. "query_object_id": 'xx',
  120. "query_object_id_type": 'xx',
  121. "item_id": 'xx',
  122. "item_code": 'xx',
  123. "item_sequence": 'xx',
  124. "terminal_info": 'xx',
  125. "query_timestamp": query_timestamp
  126. }
  127. }
  128. body=data
  129. else:
  130. body_default = service_info.body_default
  131. if len(body_default)>0:
  132. body_default = json.loads(body_default)
  133. body_default.update(body)
  134. # 根据请求方式请求获取数据
  135. if method == "GET":
  136. response = requests.get(url=url,params=params,headers=headers,json=body,verify=False)
  137. elif method == "POST":
  138. response = requests.post(url=url,params=params,headers=headers,json=body,verify=False)
  139. else:
  140. return JSONResponse(status_code=410, content={
  141. 'code': 410,
  142. 'msg': f'请求方式{method}不支持'
  143. })
  144. # 根据响应头数据类型返回对应数据类型
  145. content_type = response.headers.get('Content-Type', '')
  146. if 'application/json' in content_type:
  147. return JSONResponse(content=response.json(), media_type='application/json',status_code=response.status_code)
  148. elif 'application/xml' in content_type or 'text/xml' in content_type:
  149. return Response(content=response.text, media_type='application/xml',status_code=response.status_code)
  150. elif 'text/html' in content_type:
  151. return Response(content=response.text, media_type='application/html',status_code=response.status_code)
  152. elif 'image/png' in content_type:
  153. if response.status_code==200:
  154. cache[cache_key] = response
  155. return StreamingResponse(content=io.BytesIO(response.content), media_type='image/png')
  156. # 可以继续添加更多的条件分支来处理其他类型
  157. else:
  158. return Response(content=response.text, media_type=content_type)
class OneShareApiCreateForm(BaseModel):
    """Request body for POST /create: one proxied gateway service definition."""
    passid: str  # gateway pass id; the proxy sends it upstream as x-tif-paasid
    passtoken: str  # secret mixed into the x-tif-signature hash by the proxy
    servercode: str  # matched against the service_code path segment of the proxy route
    servertype: int  # 1 = plain, 2 = map, 3 = custom (selects how the proxy builds the body)
    params_default: str = ""  # JSON-encoded dict of default query params; "" means none
    body_default: str = ""  # JSON-encoded dict of default request body; "" means none
    headers_default: str = ""  # JSON-encoded dict of extra upstream headers; "" means none
    servername: str  # display name; substring-searched by GET /list
  168. @router.post("/create")
  169. async def create_one_share_api(
  170. api_data: OneShareApiCreateForm,
  171. db: Session = Depends(get_db)
  172. ):
  173. try:
  174. # 创建一个新的 OneShareApiEntity 实例
  175. new_api = OneShareApiEntity(
  176. passid=api_data.passid,
  177. passtoken=api_data.passtoken,
  178. servercode=api_data.servercode,
  179. servertype=api_data.servertype,
  180. params_default=api_data.params_default,
  181. body_default=api_data.body_default,
  182. headers_default=api_data.headers_default,
  183. servername=api_data.servername
  184. )
  185. # 添加到数据库会话并提交
  186. db.add(new_api)
  187. db.commit()
  188. db.refresh(new_api) # 刷新实例以包含新的 ID 等信息
  189. # 构建并返回响应
  190. return {
  191. "code": 200,
  192. "msg": "操作成功",
  193. "data": None
  194. }
  195. except Exception as e:
  196. # 处理异常
  197. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
  198. class OneShareApiUpdateForm(BaseModel):
  199. passid: str = None
  200. passtoken: str = None
  201. servercode: str = None
  202. servertype: int = None
  203. params_default: str = None
  204. body_default: str = None
  205. headers_default: str = None
  206. servername: str = None
  207. @router.put("/update/{id}") # 或者使用 @app.patch 如果你只想更新部分字段
  208. async def update_one_share_api(
  209. id: int,
  210. api_data: OneShareApiUpdateForm,
  211. db: Session = Depends(get_db)
  212. ):
  213. try:
  214. # 从数据库中获取现有的 OneShareApiEntity 实例
  215. api = db.query(OneShareApiEntity).filter(OneShareApiEntity.id == id).first()
  216. if not api:
  217. raise HTTPException(status_code=404, detail="接口不存在")
  218. # 更新字段
  219. if api_data.passid:
  220. api.passid = api_data.passid
  221. if api_data.passtoken:
  222. api.passtoken = api_data.passtoken
  223. if api_data.servercode:
  224. api.servercode = api_data.servercode
  225. if api_data.servertype:
  226. api.servertype = api_data.servertype
  227. if api_data.params_default:
  228. api.params_default = api_data.params_default
  229. if api_data.body_default:
  230. api.body_default = api_data.body_default
  231. if api_data.headers_default:
  232. api.headers_default = api_data.headers_default
  233. if api_data.servername:
  234. api.servername = api_data.servername
  235. # 更新时间
  236. api.update_time = datetime.now()
  237. # 提交更改
  238. db.commit()
  239. # 构建并返回响应
  240. return {
  241. "code": 200,
  242. "msg": "操作成功",
  243. "data": None
  244. }
  245. except Exception as e:
  246. # 处理异常
  247. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
  248. @router.delete("/delete/{id}") # 使用 ID 来标识要删除的接口
  249. async def delete_one_share_api(
  250. id: int,
  251. db: Session = Depends(get_db)
  252. ):
  253. # try:
  254. # 从数据库中获取要删除的 OneShareApiEntity 实例
  255. api = db.query(OneShareApiEntity).filter(OneShareApiEntity.id == id).first()
  256. if api is None:
  257. raise HTTPException(status_code=404, detail="接口不存在")
  258. # 删除实例
  259. db.delete(api)
  260. db.commit()
  261. # 构建并返回响应
  262. return {
  263. "code": 200,
  264. "msg": "操作成功",
  265. "data": None
  266. }
  267. # except Exception as e:
  268. # # 处理异常
  269. # print(e)
  270. # raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
class OneShareApiOut(BaseModel):
    """One service definition as returned in the GET /list response."""
    id: int
    passid: str
    # NOTE(review): the list endpoint fills this from the stored token, so the
    # signing secret is exposed to any caller of /list. Confirm this is intended.
    passtoken: str
    servercode: str
    servertype: int
    params_default: str  # raw JSON string as stored
    body_default: str  # raw JSON string as stored
    headers_default: str  # raw JSON string as stored
    servername: str
    create_time: str  # formatted by /list as '%Y-%m-%d %H:%M:%S'
    update_time: str  # formatted by /list as '%Y-%m-%d %H:%M:%S'
  283. # 定义查询接口的路由
  284. @router.get("/list")
  285. async def get_one_share_apis(
  286. page: int = Query(1, gt=0),
  287. page_size: int = Query(10, gt=0),
  288. servername: Optional[str] = Query(None),
  289. servertype:int = Query(None),
  290. servercode: Optional[str] = Query(None),
  291. db: Session = Depends(get_db)
  292. ):
  293. try:
  294. # 构建查询
  295. query = db.query(OneShareApiEntity)
  296. # 应用查询参数
  297. if servername:
  298. query = query.filter(OneShareApiEntity.servername.contains(servername))
  299. if servercode:
  300. query = query.filter(OneShareApiEntity.servercode.contains(servercode))
  301. if servertype:
  302. query = query.filter(OneShareApiEntity.servertype==servertype)
  303. # 获取总记录数
  304. total_count = query.count()
  305. # 执行分页查询
  306. items = query.offset((page - 1) * page_size).limit(page_size).all()
  307. # 将查询结果转换为 Pydantic 模型列表
  308. apis_out = [
  309. OneShareApiOut(
  310. id=item.id,
  311. passid=item.passid,
  312. passtoken=item.passtoken,
  313. servercode=item.servercode,
  314. servertype=item.servertype,
  315. params_default=item.params_default,
  316. body_default=item.body_default,
  317. headers_default=item.headers_default,
  318. servername=item.servername,
  319. create_time=item.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  320. update_time=item.update_time.strftime('%Y-%m-%d %H:%M:%S')
  321. ) for item in items
  322. ]
  323. # 构建并返回响应
  324. return {
  325. "code": 200,
  326. "msg": "查询成功",
  327. "data": {
  328. "total": total_count,
  329. "list": apis_out
  330. }
  331. }
  332. except Exception as e:
  333. # 处理异常
  334. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))