Browse Source

危化企业接口变更

libushang 3 days ago
parent
commit
ce64b8b6ff

+ 1 - 1
routers/api/__init__.py

@@ -56,7 +56,7 @@ router.include_router(login_router)
 
 router.include_router(system.router, prefix="/system", dependencies=[Depends(valid_access_token_role)])
 
-router.include_router(gateway.router, prefix="/gateway", dependencies=[Depends(valid_access_token_role)])
+router.include_router(gateway.router, prefix="/gateway")
 router.include_router(dataAnalysis.router, prefix="/dataAnalysis", dependencies=[Depends(valid_access_token_role)])
 router.include_router(resourceMonitoring.router, prefix="/resource")
 router.include_router(jobs.router, prefix="/jobs", dependencies=[Depends(valid_access_token_role)])

+ 7 - 794
routers/api/gateway/__init__.py

@@ -1,802 +1,15 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-from re import S
-from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile, BackgroundTasks
-from fastapi.responses import Response
-from fastapi.responses import JSONResponse
-from database import get_db
-from sqlalchemy.orm import Session
-from utils import *
-from models import *
-from urllib import parse
-from common.DBgetdata import dbgetdata
-from pprint import pprint
-from datetime import datetime
-# import pandas as pd
-from pydantic import BaseModel
-import sqlalchemy
-import pymysql
-import json,time,io
-import time, os
-import math
-import uuid
-import re
+from fastapi import APIRouter, Depends
+from common.security import valid_access_token_role
+from . import v1
 from . import sign_api
 from . import skwhp_api
 
 router = APIRouter()
 
-router.include_router(sign_api.router)
-router.include_router(skwhp_api.router)
-
-def contains_special_characters(input_string, special_characters=";|&|$|#|'|@"):
-    """
-    判断字符串是否包含特殊符号。
-
-    :param input_string: 需要检查的字符串
-    :param special_characters: 特殊符号的字符串,多个符号用竖线 '|' 分隔
-    :return: 如果包含特殊符号返回 True,否则返回 False
-    """
-    # 创建正则表达式模式
-    pattern = re.compile('[' + re.escape(special_characters) + ']')
-
-    # 搜索字符串中的特殊符号
-    if pattern.search(input_string):
-        return True
-    return False
-
-
-
-
-
-@router.get('/v1/demo')
-@router.post('/v1/demo')
-async def test1(request: Request):
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    # print(body)
-    if len(body) > 0:
-        body = json.loads(body)
-        print(body)
-        print([body,{'msg':'good'}])
-        return body
-    else:
-        return body
-
-
-@router.post('/create_api_service')
-async def create_api_service(
-    request: Request,
-    db: Session = Depends(get_db)
-):
-    # print(1)
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    if len(body) > 0:
-        # print(body)
-        body = json.loads(body)
-    serviceid = str(uuid.uuid4()).replace("-","")
-    serviname = body['servicename']
-    datasource = str(uuid.uuid1())
-    sqlname = body['sqlname']
-    id = str(uuid.uuid4())
-    dbtype = body['dbtype']
-    if body['dbtype'] == 'pymysql':
-        host = body['host']
-        user = body['user']
-        password = body['password']
-        database = body['database']
-        port = body['port']
-        charset = body['charset']
-        create_datasource = DatasourceEntity(id= datasource,name=sqlname,scope=serviceid,dbtype=dbtype,host=host,user=user,password=password,database=database,port=port,charset=charset)
-        # db.add()
-        # print()
-        db.add(create_datasource)
-    sqltext = body['sqltext']
-    create_command = CommandEntity(id=id,sqltext=sqltext,datasource=datasource,scope=serviceid)
-    create_api = ApiServiceEntity(id=serviceid,name=serviname)
-    db.add(create_api)
-    db.add(create_command)
-    db.commit()
-    return serviceid
-
-
-@router.get('/v1/{service_code:path}')
-@router.post('/v1/{service_code:path}')
-async def v1(request: Request, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
-    service_code = request.path_params.get('service_code')
-    query_list = parse.parse_qsl(str(request.query_params))
-    print(query_list)
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    if len(body) > 0:
-        print('[body]', body)
-        body = json.loads(body)
-    # print("1111",body)
-
-    service_info = db.query(ApiServiceEntity).filter(ApiServiceEntity.id == service_code).first()
-    if service_info is None:
-        return {
-            'errcode': 1,
-            'errmsg': '服务不存在'
-    }
-    # print(service_info)
-    print('service_name', service_info.name)
-
-    # print('database_name', database_info.name, database_info.dsn)
-
-    command_info = db.query(CommandEntity).filter(CommandEntity.scope == service_code).first()
-    if command_info is None:
-        return {
-            'errcode': 1,
-            'errmsg': '查询命令没配置'
-    }
-
-
-    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == command_info.datasource).first()
-    if database_info is None:
-        return {
-            'errcode': 1,
-            'errmsg': '数据库没配置'
-    }
-
-
-
-    sql = command_info.sqltext
-    print('sql ==== << ', sql)
-    meta_data = {}
-    # 从query获取参数
-    for (key, val) in query_list:
-        if key != 'access_token' and key !='page' and key != 'limit' and key !='create_time' and key != 'record_name'and key != 'record_cid':
-            sql = sql.replace("{" + key + "}", val.replace('全部',''))
-        elif key =='create_time':
-            if val != '全部':
-                sql = sql.replace("{" + key + "}", ('DATE_FORMAT(create_time,\'%Y-%m-%d\') =\''+val+'\' and')) #DATE_FORMAT(create_time,'%Y-%m-%d') ='{create_time}'
-            else:
-                sql = sql.replace("{" + key + "}", '')
-        elif key =='record_name' or  key =='record_cid' :
-            if val != '全部':
-                sql = sql.replace("{" + key + "}", (key+'=\''+val+'\' and')) 
-            else:
-                sql = sql.replace("{" + key + "}", '')
-        # elif key =='record_cid':
-        #     if val != '全部':
-        #         sql = sql.replace("{" + key + "}", ('record_name =\''+val+'\' and')) 
-        #     else:
-        #         print(111)
-        #         sql = sql.replace("{" + key + "}", '')
-        elif key =='page' :
-            meta_data['page'] = int(val)
-        elif key == 'limit':
-            meta_data['limit'] = int(val)
-
-    # 从json body获取参数
-    if len(body) > 0:
-        print(body)
-        for (key, val) in body.items():
-            if key == 'page1' and 'limit1' in body:
-                if val==-1:
-                    pass
-                if val == '':
-                    val = '0'
-                page1 = int(val)
-                limit1=int(body['limit1'])
-                print(str((page1-1)*limit1))
-                sql = sql.replace('{page1}',str((page1-1)*limit1))
-            # elif key == 'limit1' and 'page1' in body:
-            #     page1 = int(body['page1'])
-            #     limit1 = int(val)
-            #     sql.replace('{page1}', str((page1-1)*limit1))
-            elif key !='page' and key != 'limit' and key !='create_time':
-                if isinstance(val, int):
-                    val = str(val)
-                if isinstance(val, float):
-                    val = str(val)
-                sql = sql.replace("{" + key + "}", val)
-            elif key =='create_time':
-                if val != '全部':
-                    sql = sql.replace("{" + key + "}", ('DATE_FORMAT(create_time,\'%Y-%m-%d\') =\''+val+'\' and')) #DATE_FORMAT(create_time,'%Y-%m-%d') ='{create_time}'
-                else:
-                    sql = sql.replace("{" + key + "}", '')
-            elif key =='page' :
-                if isinstance(val, str):
-                    meta_data['page'] = int(val)
-                elif isinstance(val, int):
-                    meta_data['page'] = val
-            elif key == 'limit':
-                if isinstance(val, str):
-                    meta_data['limit'] = int(val)
-                elif isinstance(val, int):
-                    meta_data['limit'] = val
-            elif key == 'taskNo' and val=='':
-                print("他们又输入空的taskNo啦!")
-                return [{"text":"兄弟,taskNo不能为空"}]
-    print('sql ==== >> ', sql)
-
-    data = []
-    if database_info.dbtype == 'psycopg2':
-        '''
-        print(1111)
-        conn = psycopg2.connect(database_info.dsn)
-        cur = conn.cursor()
-        cur.execute(sql)
-        rows = cur.fetchall()
-
-        # 字段名列表
-        colnames = [desc[0] for desc in cur.description]
-
-        item = {}
-        for row in rows:
-            for col in range(len(colnames)):
-                field_name = colnames[col]
-                item[field_name] = row[col]
-            data.append(item) 
-
-        conn.close()
-        '''
-            
-
-    elif database_info.dbtype == 'pymysql':
-        # print(database_info)
-        conn = pymysql.connect(host=database_info.host,
-                               user=database_info.user,
-                               password=database_info.password,
-                               database=database_info.database,
-                               port=database_info.port,
-                               charset=database_info.charset)
-        cur = conn.cursor()
-        cur.execute(sql)
-        rows = cur.fetchall()
-
-        # 字段名列表
-        colnames = [desc[0] for desc in cur.description]
-        # print(colnames)
-        # print(rows)
-        pages = 1 #总页数
-        current = 1 #第几页
-        total = len(rows)
-        size = len(rows)
-        # print(size)
-        if 'page' in meta_data and 'limit' in meta_data:
-            current = meta_data['page']
-            size = meta_data['limit']
-
-            if (current == 0 or size == 0) and total != 0:
-                 current = 1
-                 size = 5
-            pages = total//size
-            if total%size !=0:
-                pages+=1
-    
-        start_index = (current-1)*size
-        end_index = current*size
-        if pages <= current :
-            # current = pages
-            if total == size :
-                end_index = (current-1)*size+total
-            elif total%size == 0:
-                end_index = current*size
-            else:
-                end_index = (current-1)*size+total%size
-            start_index = (current-1)*size
-        if total ==0:
-            start_index = end_index =0
-        # print(start_index,end_index)
-        for row in range(start_index,end_index):
-            item = {}
-            for col in range(len(colnames)):
-                field_name = colnames[col]
-                item[field_name] = rows[row][col]
-            data.append(item) 
-            # print(item)
-
-        # for row in rows:
-        #     item = {}
-        #     for col in range(len(colnames)):
-        #         field_name = colnames[col]
-        #         item[field_name] = row[col]
-        #     data.append(item) 
-            # print(item)
-        # print(data)
-        conn.close()
-
-    return {
-        'code': 0,
-        'errcode': 0,
-        'errmsg': '查询成功',
-        'data': {"list":data,
-        'pages':pages, # 总页数
-        'currentPage':current, #当前页数
-        # 'current':current,
-        # 'total' : total,
-        'total' : total, # 总数据量
-        # 'size':size,
-        'pageSize':size #页码
-         }
-    }
-    # import time
-    # print(time)
-    # if 'ispages' in body:
-    #     try:
-    #         if body['ispages']=='1':
-    #             return {
-    #                 'page': [{"total": pages,
-    #                          "page": current}],
-    #                 "data":data
-    #             }
-    #             # data['pages'] =
-    #     except:
-    #         pass
-    #
-    # #background_tasks.add_task(post_service_method, service_code, body, db)
-    #
-    # return  data
-
-
-
-
-
-class CreateApiServiceFrom(BaseModel):
-        datasource:str
-        name:str
-        sqltext:str
-        serviceid:str=None
-
-@router.post('/v2/create_api_service')
-async def create_api_service_v2(
-    form_data: CreateApiServiceFrom,
-    request: Request,
-    db: Session = Depends(get_db)
-):
-    # print(1)
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    if len(body) > 0:
-        # print(body)
-        body = json.loads(body)
-
-
-    for param in ['datasource','name','sqltext']:
-        if param not in body or body[param]=='':
-            return f'{param}为必填参数,不能不存在'
-
-    # 数据库信息校验
-    datasource = form_data.datasource
-    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == datasource).first()
-    if database_info is None:
-        return JSONResponse(status_code=410, content={
-            'errcode': 410,
-            'errmsg': f'数据库-{datasource}-不存在'
-        })
-
-    # 接口信息生成及获取
-    if form_data.serviceid is None:
-        serviceid = str(uuid.uuid4()).replace("-","")
-    else:
-        serviceid = form_data.serviceid
-    serviname = form_data.name
-    id = str(uuid.uuid4())
-    sqltext = form_data.sqltext
-
-    # 接口执行校验
-    if database_info.dbtype == 'pymysql':
-        try:
-            conn = pymysql.connect(host=database_info.host,
-                                   user=database_info.user,
-                                   password=database_info.password,
-                                   database=database_info.database,
-                                   port=database_info.port,
-                                   charset=database_info.charset)
-            cur = conn.cursor()
-
-            # 测试返回
-            sql = sqltext+' limit 0;'
-            print(sql)
-            cur.execute(sql)
-            colnames = [desc[0] for desc in cur.description]
-            cur.close()
-            conn.close()
-
-        except Exception as e:
-            # 捕获其他所有类型的异常
-            return f"未知错误: {e}"
-    else:
-        return '暂不支持除pymysql以外的驱动'
-
-    create_command = CommandEntity(id=id,sqltext=sqltext,datasource=datasource,scope=serviceid)
-    create_api = ApiServiceEntity(id=serviceid,name=serviname,status=1)
-    db.add(create_api)
-    db.add(create_command)
-    db.commit()
-    return [serviceid,colnames]
-
-
-
-@router.post('/v2/update_api_service')
-async def update_api_service_v2(
-    request: Request,
-    db: Session = Depends(get_db)
-):
-    # print(1)
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    if len(body) > 0:
-        # print(body)
-        body = json.loads(body)
-
-    for param in ['id','datasource','scope','name','sqltext','status']:
-        if param not in body or body[param]=='':
-            return f'{param}为必填参数,不能不存在'
-
-    # 数据库信息校验
-    datasource = body['datasource']
-    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == datasource).first()
-    if database_info is None:
-        return JSONResponse(status_code=410, content={
-            'errcode': 410,
-            'errmsg': f'数据库-{datasource}-不存在'
-        })
-
-
-
-    serviceid = body['scope']
-    # 判断请求接口是否存在
-    service_info = db.query(ApiServiceEntity).filter(ApiServiceEntity.id == serviceid).first()
-    if service_info is None:
-       return JSONResponse(status_code=410, content={
-           'code': 410,
-           'msg': f'servicecode{serviceid}服务不存在'
-       })
-    command_info = db.query(CommandEntity).filter(CommandEntity.scope == serviceid).first()
-    if command_info is None:
-        return JSONResponse(status_code=410, content={
-            'code': 410,
-            'msg': f'servicecode{serviceid}服务不存在'
-        })
+router.include_router(v1.router, dependencies=[Depends(valid_access_token_role)])
+router.include_router(sign_api.router, dependencies=[Depends(valid_access_token_role)])
 
-    # 接口信息生成及获取
-    id = body['id']
-    serviname = body['name']
-    sqltext = body['sqltext']
-    status = body['status']
-
-    # 接口执行校验
-    if database_info.dbtype == 'pymysql':
-        try:
-            conn = pymysql.connect(host=database_info.host,
-                                   user=database_info.user,
-                                   password=database_info.password,
-                                   database=database_info.database,
-                                   port=database_info.port,
-                                   charset=database_info.charset)
-            cur = conn.cursor()
-
-            # 测试返回
-            sql = sqltext+' limit 0;'
-            print(sql)
-            cur.execute(sql)
-            colnames = [desc[0] for desc in cur.description]
-            cur.close()
-            conn.close()
-
-        except Exception as e:
-            # 捕获其他所有类型的异常
-            return f"未知错误: {e}"
-    else:
-        return '暂不支持除pymysql以外的驱动'
-
-    # create_command = CommandEntity(id=id,sqltext=sqltext,datasource=datasource,scope=serviceid)
-    # create_api = ApiServiceEntity(id=serviceid,name=serviname,status=1)
-    # db.add(create_api)
-    # db.add(create_command)
-    service_info.id=serviceid
-    service_info.name=serviname
-    service_info.status=status
-
-    command_info.id=id
-    command_info.sqltext=sqltext
-    command_info.datasource=datasource
-    command_info.scope=serviceid
-
-
-    db.commit()
-    return [serviceid,colnames]
-
-
-
-
-@router.post('/v2/create_datasource')
-async def create_create_datasource_v2(
-    request: Request,
-    db: Session = Depends(get_db)
-):
-    # print(1)
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    if len(body) > 0:
-        # print(body)
-        body = json.loads(body)
-
-    # 检查入参
-    for param in ['datasourcename','dbtype','host','user','password','database','port','charset']:
-        if param not in body or body[param]=='':
-            return f'{param}为必填参数,不能不存在'
-    datasource = str(uuid.uuid1())
-    datasourcename = body['datasourcename']
-    dbtype = body['dbtype']
-    if body['dbtype'] == 'pymysql':
-        host = body['host']
-        user = body['user']
-        password = body['password']
-        database = body['database']
-        port = body['port']
-        charset = body['charset']
-
-        # 数据库连通性校验
-        try:
-            conn = pymysql.connect(host=host,
-                                   user=user,
-                                   password=password,
-                                   database=database,
-                                   port=port,
-                                   charset=charset)
-            cur = conn.cursor()
-
-            # 测试返回
-            sql = ' select now()'
-            print(sql)
-            cur.execute(sql)
-            testresult = cur.fetchall()
-            cur.close()
-            conn.close()
-
-        except Exception as e:
-            # 捕获其他所有类型的异常
-            return f"未知错误: {e}"
-        create_datasource = DatasourceEntity(id=datasource, name=datasourcename, scope='', dbtype=dbtype, host=host,
-                                             user=user, password=password, database=database, port=port,
-                                             charset=charset)
-        db.add(create_datasource)
-    else:
-        return f'暂不支持{dbtype}'
-    db.commit()
-    return [datasource,testresult]
-
-
-@router.post('/v2/update_datasource')
-async def create_update_datasource_v2(
-    request: Request,
-    db: Session = Depends(get_db)
-):
-    # print(1)
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    if len(body) > 0:
-        # print(body)
-        body = json.loads(body)
-
-    # 检查入参
-    for param in ['id','datasourcename','dbtype','host','user','password','database','port','charset']:
-        if param not in body or body[param]=='':
-            return f'{param}为必填参数,不能不存在'
-
-    # 根据数据库id检测是否存在
-    datasource = body['id']
-    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == datasource).first()
-    if database_info is None:
-        return JSONResponse(status_code=410, content={
-            'errcode': 410,
-            'errmsg': f'数据库{datasource}不存在'
-        })
-    datasourcename = body['datasourcename']
-    dbtype = body['dbtype']
-    if body['dbtype'] == 'pymysql':
-        host = body['host']
-        user = body['user']
-        password = body['password']
-        database = body['database']
-        port = body['port']
-        charset = body['charset']
-
-        # 数据库连通性校验
-        try:
-            conn = pymysql.connect(host=host,
-                                   user=user,
-                                   password=password,
-                                   database=database,
-                                   port=port,
-                                   charset=charset)
-            cur = conn.cursor()
-
-            # 测试返回
-            sql = ' select now()'
-            print(sql)
-            cur.execute(sql)
-            testresult = cur.fetchall()
-            cur.close()
-            conn.close()
-
-        except Exception as e:
-            # 捕获其他所有类型的异常
-            return f"未知错误: {e}"
-        # create_datasource = DatasourceEntity(id=datasource, name=datasourcename, scope='', dbtype=dbtype, host=host,
-        #                                      user=user, password=password, database=database, port=port,
-        #                                      charset=charset)
-        # db.add(create_datasource)
-        database_info.name = datasourcename
-        database_info.dbtype = dbtype
-        database_info.host = host
-        database_info.user = user
-        database_info.password = password
-        database_info.database = database
-        database_info.port = port
-        database_info.charset = charset
-
-    else:
-        return f'暂不支持{dbtype}'
-    db.commit()
-    return [datasource,testresult]
-
-
-
-
-@router.get('/v2/{serviceid}')
-@router.post('/v2/{serviceid}')
-# @router.post('/v2/{servicecode}')
-async def v2(serviceid:str,request: Request, db: Session = Depends(get_db)):
-    # 获取请求头 servicecode
-    # if service_code is None:
-    # service_code = request.headers.get('servicecode')
-    # print(serviceid)
-    if serviceid is None:
-       return JSONResponse(status_code=410, content= {
-           'code': 410,
-           'msg': "请求头servicecode未传参"
-       })
-    # 判断请求接口是否存在
-    service_info = db.query(ApiServiceEntity).filter(ApiServiceEntity.id == serviceid).first()
-    if service_info is None:
-       return JSONResponse(status_code=410, content={
-           'code': 410,
-           'msg': 'servicecode服务不存在'
-       })
-    # print(service_code)
-
-
-    # 判断请求接口对应数据库是否存在
-    command_info = db.query(CommandEntity).filter(CommandEntity.scope == serviceid).first()
-    if command_info is None:
-        return JSONResponse(status_code=410, content={
-            'code': 410,
-            'msg': '查询命令没配置'
-        })
-
-    # 判断请求接口对应数据库是否存在
-    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == command_info.datasource).first()
-    if database_info is None:
-        return JSONResponse(status_code=410, content={
-            'errcode': 410,
-            'errmsg': '数据库没配置'
-        })
-
-    # 获取接口对应的sql
-    sql = command_info.sqltext
-    print('sql ==== << ', sql)
-
-    # 从params和body获取参数
-    # query_list = parse.parse_qsl(str(request.query_params))
-    # print(query_list)
-    data = await request.body()
-    body = data.decode(encoding='utf-8')
-    size = 10
-    current = 1
-    if len(body) > 0:
-        body = json.loads(body)
-
-        # 分页器 页数和页码的设置
-
-        if "size" in body:
-            if isinstance(body['size'], str):
-                size = int(body['size'])
-            elif isinstance(body['size'], int):
-                size = body['size']
-            if size >100:
-                size = 100
-
-        if "current" in body:
-            if isinstance(body['current'], str):
-                current = int(body['current'])
-            elif isinstance(body['current'], int):
-                current = body['current']
-            if current<=0:
-                current=1
-        # 接口sql的参数替换
-        if 'query' in body:
-            for (key, val) in body['query'].items():
-                if isinstance(val, int):
-                    val = str(val)
-                if isinstance(val, float):
-                    val = str(val)
-                if contains_special_characters(val):
-                    return JSONResponse(status_code=411, content={
-                                'code': 411,
-                                'msg': f'参数{key}含特殊符号:;、&、$、#、\'、\\t、@、空格等'
-                            })
-                sql = sql.replace("{" + key + "}", val)
-
-
-    print('sql ==== >> ', sql)
-
-
-    data = []
-
-        # 数据库类型为mysql情况下
-    if database_info.dbtype == 'pymysql':
-        # 数据库连接
-
-        conn = pymysql.connect(host=database_info.host,
-                               user=database_info.user,
-                               password=database_info.password,
-                               database=database_info.database,
-                               port=database_info.port,
-                               charset=database_info.charset)
-        cur = conn.cursor()
-
-        # 查总数据量,分页数据处理
-        totalsql = f'select count(*) from ({sql})t'
-        print(totalsql)
-        cur.execute(totalsql)
-        total = cur.fetchone()[0]
-
-        pages,pagesmod = divmod(total, size)
-        print(total,pages,pagesmod)
-        if pagesmod!=0:
-            pages+=1
-        print(pages,pagesmod)
-        if total <size :
-            size = total
-
-
-        # 正式查询
-        sql = sql+f' limit {size*(current-1)}, {size};'
-        print(sql,size)
-        cur.execute(sql)
-        rows = cur.fetchall()
-        colnames = [desc[0] for desc in cur.description]
-        for row in range(len(rows)):
-            item = {}
-            for col in range(len(colnames)):
-                field_name = colnames[col]
-                item[field_name] = rows[row][col]
-            data.append(item)
-
-        # 数据库关闭
-        cur.close()
-        conn.close()
-
-
-
-    else:
-        return JSONResponse(status_code=410, content={
-        'code': 410,
-        'msg': '接口对应数据库暂不支持'
-    })
-
-    return {
-        'code': 0,
-        'msg': 'success',
-        'rows': data,
-         'pages': pages,  # 总页数
-         'currentPage': current,  # 当前页数
-         # 'current':current,
-         # 'total' : total,
-         'total': total,  # 总数据量
-         # 'size':size,
-         'pageSize': size  # 页码
-    }
-
-    # else:
-    #     return JSONResponse(status_code=410, content={
-    #         'code': 410,
-    #         'msg': 'body不能为空'
-    #     })
+# 数科代理,不用校验
+router.include_router(skwhp_api.router)

+ 802 - 0
routers/api/gateway/v1.py

@@ -0,0 +1,802 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from re import S
+from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile, BackgroundTasks
+from fastapi.responses import Response
+from fastapi.responses import JSONResponse
+from database import get_db
+from sqlalchemy.orm import Session
+from utils import *
+from models import *
+from urllib import parse
+from common.DBgetdata import dbgetdata
+from pprint import pprint
+from datetime import datetime
+# import pandas as pd
+from pydantic import BaseModel
+import sqlalchemy
+import pymysql
+import json,time,io
+import time, os
+import math
+import uuid
+import re
+from . import sign_api
+from . import skwhp_api
+
+router = APIRouter()
+
+router.include_router(sign_api.router)
+router.include_router(skwhp_api.router)
+
+def contains_special_characters(input_string, special_characters=";|&|$|#|'|@"):
+    """
+    判断字符串是否包含特殊符号。
+
+    :param input_string: 需要检查的字符串
+    :param special_characters: 特殊符号的字符串,多个符号用竖线 '|' 分隔
+    :return: 如果包含特殊符号返回 True,否则返回 False
+    """
+    # 创建正则表达式模式
+    pattern = re.compile('[' + re.escape(special_characters) + ']')
+
+    # 搜索字符串中的特殊符号
+    if pattern.search(input_string):
+        return True
+    return False
+
+
+
+
+
+@router.get('/v1/demo')
+@router.post('/v1/demo')
+async def test1(request: Request):
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    # print(body)
+    if len(body) > 0:
+        body = json.loads(body)
+        print(body)
+        print([body,{'msg':'good'}])
+        return body
+    else:
+        return body
+
+
+@router.post('/create_api_service')
+async def create_api_service(
+    request: Request,
+    db: Session = Depends(get_db)
+):
+    # print(1)
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    if len(body) > 0:
+        # print(body)
+        body = json.loads(body)
+    serviceid = str(uuid.uuid4()).replace("-","")
+    serviname = body['servicename']
+    datasource = str(uuid.uuid1())
+    sqlname = body['sqlname']
+    id = str(uuid.uuid4())
+    dbtype = body['dbtype']
+    if body['dbtype'] == 'pymysql':
+        host = body['host']
+        user = body['user']
+        password = body['password']
+        database = body['database']
+        port = body['port']
+        charset = body['charset']
+        create_datasource = DatasourceEntity(id= datasource,name=sqlname,scope=serviceid,dbtype=dbtype,host=host,user=user,password=password,database=database,port=port,charset=charset)
+        # db.add()
+        # print()
+        db.add(create_datasource)
+    sqltext = body['sqltext']
+    create_command = CommandEntity(id=id,sqltext=sqltext,datasource=datasource,scope=serviceid)
+    create_api = ApiServiceEntity(id=serviceid,name=serviname)
+    db.add(create_api)
+    db.add(create_command)
+    db.commit()
+    return serviceid
+
+
+@router.get('/v1/{service_code:path}')
+@router.post('/v1/{service_code:path}')
+async def v1(request: Request, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
+    service_code = request.path_params.get('service_code')
+    query_list = parse.parse_qsl(str(request.query_params))
+    print(query_list)
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    if len(body) > 0:
+        print('[body]', body)
+        body = json.loads(body)
+    # print("1111",body)
+
+    service_info = db.query(ApiServiceEntity).filter(ApiServiceEntity.id == service_code).first()
+    if service_info is None:
+        return {
+            'errcode': 1,
+            'errmsg': '服务不存在'
+    }
+    # print(service_info)
+    print('service_name', service_info.name)
+
+    # print('database_name', database_info.name, database_info.dsn)
+
+    command_info = db.query(CommandEntity).filter(CommandEntity.scope == service_code).first()
+    if command_info is None:
+        return {
+            'errcode': 1,
+            'errmsg': '查询命令没配置'
+    }
+
+
+    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == command_info.datasource).first()
+    if database_info is None:
+        return {
+            'errcode': 1,
+            'errmsg': '数据库没配置'
+    }
+
+
+
+    sql = command_info.sqltext
+    print('sql ==== << ', sql)
+    meta_data = {}
+    # 从query获取参数
+    for (key, val) in query_list:
+        if key != 'access_token' and key !='page' and key != 'limit' and key !='create_time' and key != 'record_name'and key != 'record_cid':
+            sql = sql.replace("{" + key + "}", val.replace('全部',''))
+        elif key =='create_time':
+            if val != '全部':
+                sql = sql.replace("{" + key + "}", ('DATE_FORMAT(create_time,\'%Y-%m-%d\') =\''+val+'\' and')) #DATE_FORMAT(create_time,'%Y-%m-%d') ='{create_time}'
+            else:
+                sql = sql.replace("{" + key + "}", '')
+        elif key =='record_name' or  key =='record_cid' :
+            if val != '全部':
+                sql = sql.replace("{" + key + "}", (key+'=\''+val+'\' and')) 
+            else:
+                sql = sql.replace("{" + key + "}", '')
+        # elif key =='record_cid':
+        #     if val != '全部':
+        #         sql = sql.replace("{" + key + "}", ('record_name =\''+val+'\' and')) 
+        #     else:
+        #         print(111)
+        #         sql = sql.replace("{" + key + "}", '')
+        elif key =='page' :
+            meta_data['page'] = int(val)
+        elif key == 'limit':
+            meta_data['limit'] = int(val)
+
+    # 从json body获取参数
+    if len(body) > 0:
+        print(body)
+        for (key, val) in body.items():
+            if key == 'page1' and 'limit1' in body:
+                if val==-1:
+                    pass
+                if val == '':
+                    val = '0'
+                page1 = int(val)
+                limit1=int(body['limit1'])
+                print(str((page1-1)*limit1))
+                sql = sql.replace('{page1}',str((page1-1)*limit1))
+            # elif key == 'limit1' and 'page1' in body:
+            #     page1 = int(body['page1'])
+            #     limit1 = int(val)
+            #     sql.replace('{page1}', str((page1-1)*limit1))
+            elif key !='page' and key != 'limit' and key !='create_time':
+                if isinstance(val, int):
+                    val = str(val)
+                if isinstance(val, float):
+                    val = str(val)
+                sql = sql.replace("{" + key + "}", val)
+            elif key =='create_time':
+                if val != '全部':
+                    sql = sql.replace("{" + key + "}", ('DATE_FORMAT(create_time,\'%Y-%m-%d\') =\''+val+'\' and')) #DATE_FORMAT(create_time,'%Y-%m-%d') ='{create_time}'
+                else:
+                    sql = sql.replace("{" + key + "}", '')
+            elif key =='page' :
+                if isinstance(val, str):
+                    meta_data['page'] = int(val)
+                elif isinstance(val, int):
+                    meta_data['page'] = val
+            elif key == 'limit':
+                if isinstance(val, str):
+                    meta_data['limit'] = int(val)
+                elif isinstance(val, int):
+                    meta_data['limit'] = val
+            elif key == 'taskNo' and val=='':
+                print("他们又输入空的taskNo啦!")
+                return [{"text":"兄弟,taskNo不能为空"}]
+    print('sql ==== >> ', sql)
+
+    data = []
+    if database_info.dbtype == 'psycopg2':
+        '''
+        print(1111)
+        conn = psycopg2.connect(database_info.dsn)
+        cur = conn.cursor()
+        cur.execute(sql)
+        rows = cur.fetchall()
+
+        # 字段名列表
+        colnames = [desc[0] for desc in cur.description]
+
+        item = {}
+        for row in rows:
+            for col in range(len(colnames)):
+                field_name = colnames[col]
+                item[field_name] = row[col]
+            data.append(item) 
+
+        conn.close()
+        '''
+            
+
+    elif database_info.dbtype == 'pymysql':
+        # print(database_info)
+        conn = pymysql.connect(host=database_info.host,
+                               user=database_info.user,
+                               password=database_info.password,
+                               database=database_info.database,
+                               port=database_info.port,
+                               charset=database_info.charset)
+        cur = conn.cursor()
+        cur.execute(sql)
+        rows = cur.fetchall()
+
+        # 字段名列表
+        colnames = [desc[0] for desc in cur.description]
+        # print(colnames)
+        # print(rows)
+        pages = 1 #总页数
+        current = 1 #第几页
+        total = len(rows)
+        size = len(rows)
+        # print(size)
+        if 'page' in meta_data and 'limit' in meta_data:
+            current = meta_data['page']
+            size = meta_data['limit']
+
+            if (current == 0 or size == 0) and total != 0:
+                 current = 1
+                 size = 5
+            pages = total//size
+            if total%size !=0:
+                pages+=1
+    
+        start_index = (current-1)*size
+        end_index = current*size
+        if pages <= current :
+            # current = pages
+            if total == size :
+                end_index = (current-1)*size+total
+            elif total%size == 0:
+                end_index = current*size
+            else:
+                end_index = (current-1)*size+total%size
+            start_index = (current-1)*size
+        if total ==0:
+            start_index = end_index =0
+        # print(start_index,end_index)
+        for row in range(start_index,end_index):
+            item = {}
+            for col in range(len(colnames)):
+                field_name = colnames[col]
+                item[field_name] = rows[row][col]
+            data.append(item) 
+            # print(item)
+
+        # for row in rows:
+        #     item = {}
+        #     for col in range(len(colnames)):
+        #         field_name = colnames[col]
+        #         item[field_name] = row[col]
+        #     data.append(item) 
+            # print(item)
+        # print(data)
+        conn.close()
+
+    return {
+        'code': 0,
+        'errcode': 0,
+        'errmsg': '查询成功',
+        'data': {"list":data,
+        'pages':pages, # 总页数
+        'currentPage':current, #当前页数
+        # 'current':current,
+        # 'total' : total,
+        'total' : total, # 总数据量
+        # 'size':size,
+        'pageSize':size #页码
+         }
+    }
+    # import time
+    # print(time)
+    # if 'ispages' in body:
+    #     try:
+    #         if body['ispages']=='1':
+    #             return {
+    #                 'page': [{"total": pages,
+    #                          "page": current}],
+    #                 "data":data
+    #             }
+    #             # data['pages'] =
+    #     except:
+    #         pass
+    #
+    # #background_tasks.add_task(post_service_method, service_code, body, db)
+    #
+    # return  data
+
+
+
+
+
+class CreateApiServiceFrom(BaseModel):
+        datasource:str
+        name:str
+        sqltext:str
+        serviceid:str=None
+
+@router.post('/v2/create_api_service')
+async def create_api_service_v2(
+    form_data: CreateApiServiceFrom,
+    request: Request,
+    db: Session = Depends(get_db)
+):
+    # print(1)
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    if len(body) > 0:
+        # print(body)
+        body = json.loads(body)
+
+
+    for param in ['datasource','name','sqltext']:
+        if param not in body or body[param]=='':
+            return f'{param}为必填参数,不能不存在'
+
+    # 数据库信息校验
+    datasource = form_data.datasource
+    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == datasource).first()
+    if database_info is None:
+        return JSONResponse(status_code=410, content={
+            'errcode': 410,
+            'errmsg': f'数据库-{datasource}-不存在'
+        })
+
+    # 接口信息生成及获取
+    if form_data.serviceid is None:
+        serviceid = str(uuid.uuid4()).replace("-","")
+    else:
+        serviceid = form_data.serviceid
+    serviname = form_data.name
+    id = str(uuid.uuid4())
+    sqltext = form_data.sqltext
+
+    # 接口执行校验
+    if database_info.dbtype == 'pymysql':
+        try:
+            conn = pymysql.connect(host=database_info.host,
+                                   user=database_info.user,
+                                   password=database_info.password,
+                                   database=database_info.database,
+                                   port=database_info.port,
+                                   charset=database_info.charset)
+            cur = conn.cursor()
+
+            # 测试返回
+            sql = sqltext+' limit 0;'
+            print(sql)
+            cur.execute(sql)
+            colnames = [desc[0] for desc in cur.description]
+            cur.close()
+            conn.close()
+
+        except Exception as e:
+            # 捕获其他所有类型的异常
+            return f"未知错误: {e}"
+    else:
+        return '暂不支持除pymysql以外的驱动'
+
+    create_command = CommandEntity(id=id,sqltext=sqltext,datasource=datasource,scope=serviceid)
+    create_api = ApiServiceEntity(id=serviceid,name=serviname,status=1)
+    db.add(create_api)
+    db.add(create_command)
+    db.commit()
+    return [serviceid,colnames]
+
+
+
+@router.post('/v2/update_api_service')
+async def update_api_service_v2(
+    request: Request,
+    db: Session = Depends(get_db)
+):
+    # print(1)
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    if len(body) > 0:
+        # print(body)
+        body = json.loads(body)
+
+    for param in ['id','datasource','scope','name','sqltext','status']:
+        if param not in body or body[param]=='':
+            return f'{param}为必填参数,不能不存在'
+
+    # 数据库信息校验
+    datasource = body['datasource']
+    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == datasource).first()
+    if database_info is None:
+        return JSONResponse(status_code=410, content={
+            'errcode': 410,
+            'errmsg': f'数据库-{datasource}-不存在'
+        })
+
+
+
+    serviceid = body['scope']
+    # 判断请求接口是否存在
+    service_info = db.query(ApiServiceEntity).filter(ApiServiceEntity.id == serviceid).first()
+    if service_info is None:
+       return JSONResponse(status_code=410, content={
+           'code': 410,
+           'msg': f'servicecode{serviceid}服务不存在'
+       })
+    command_info = db.query(CommandEntity).filter(CommandEntity.scope == serviceid).first()
+    if command_info is None:
+        return JSONResponse(status_code=410, content={
+            'code': 410,
+            'msg': f'servicecode{serviceid}服务不存在'
+        })
+
+    # 接口信息生成及获取
+    id = body['id']
+    serviname = body['name']
+    sqltext = body['sqltext']
+    status = body['status']
+
+    # 接口执行校验
+    if database_info.dbtype == 'pymysql':
+        try:
+            conn = pymysql.connect(host=database_info.host,
+                                   user=database_info.user,
+                                   password=database_info.password,
+                                   database=database_info.database,
+                                   port=database_info.port,
+                                   charset=database_info.charset)
+            cur = conn.cursor()
+
+            # 测试返回
+            sql = sqltext+' limit 0;'
+            print(sql)
+            cur.execute(sql)
+            colnames = [desc[0] for desc in cur.description]
+            cur.close()
+            conn.close()
+
+        except Exception as e:
+            # 捕获其他所有类型的异常
+            return f"未知错误: {e}"
+    else:
+        return '暂不支持除pymysql以外的驱动'
+
+    # create_command = CommandEntity(id=id,sqltext=sqltext,datasource=datasource,scope=serviceid)
+    # create_api = ApiServiceEntity(id=serviceid,name=serviname,status=1)
+    # db.add(create_api)
+    # db.add(create_command)
+    service_info.id=serviceid
+    service_info.name=serviname
+    service_info.status=status
+
+    command_info.id=id
+    command_info.sqltext=sqltext
+    command_info.datasource=datasource
+    command_info.scope=serviceid
+
+
+    db.commit()
+    return [serviceid,colnames]
+
+
+
+
+@router.post('/v2/create_datasource')
+async def create_create_datasource_v2(
+    request: Request,
+    db: Session = Depends(get_db)
+):
+    # print(1)
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    if len(body) > 0:
+        # print(body)
+        body = json.loads(body)
+
+    # 检查入参
+    for param in ['datasourcename','dbtype','host','user','password','database','port','charset']:
+        if param not in body or body[param]=='':
+            return f'{param}为必填参数,不能不存在'
+    datasource = str(uuid.uuid1())
+    datasourcename = body['datasourcename']
+    dbtype = body['dbtype']
+    if body['dbtype'] == 'pymysql':
+        host = body['host']
+        user = body['user']
+        password = body['password']
+        database = body['database']
+        port = body['port']
+        charset = body['charset']
+
+        # 数据库连通性校验
+        try:
+            conn = pymysql.connect(host=host,
+                                   user=user,
+                                   password=password,
+                                   database=database,
+                                   port=port,
+                                   charset=charset)
+            cur = conn.cursor()
+
+            # 测试返回
+            sql = ' select now()'
+            print(sql)
+            cur.execute(sql)
+            testresult = cur.fetchall()
+            cur.close()
+            conn.close()
+
+        except Exception as e:
+            # 捕获其他所有类型的异常
+            return f"未知错误: {e}"
+        create_datasource = DatasourceEntity(id=datasource, name=datasourcename, scope='', dbtype=dbtype, host=host,
+                                             user=user, password=password, database=database, port=port,
+                                             charset=charset)
+        db.add(create_datasource)
+    else:
+        return f'暂不支持{dbtype}'
+    db.commit()
+    return [datasource,testresult]
+
+
+@router.post('/v2/update_datasource')
+async def create_update_datasource_v2(
+    request: Request,
+    db: Session = Depends(get_db)
+):
+    # print(1)
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    if len(body) > 0:
+        # print(body)
+        body = json.loads(body)
+
+    # 检查入参
+    for param in ['id','datasourcename','dbtype','host','user','password','database','port','charset']:
+        if param not in body or body[param]=='':
+            return f'{param}为必填参数,不能不存在'
+
+    # 根据数据库id检测是否存在
+    datasource = body['id']
+    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == datasource).first()
+    if database_info is None:
+        return JSONResponse(status_code=410, content={
+            'errcode': 410,
+            'errmsg': f'数据库{datasource}不存在'
+        })
+    datasourcename = body['datasourcename']
+    dbtype = body['dbtype']
+    if body['dbtype'] == 'pymysql':
+        host = body['host']
+        user = body['user']
+        password = body['password']
+        database = body['database']
+        port = body['port']
+        charset = body['charset']
+
+        # 数据库连通性校验
+        try:
+            conn = pymysql.connect(host=host,
+                                   user=user,
+                                   password=password,
+                                   database=database,
+                                   port=port,
+                                   charset=charset)
+            cur = conn.cursor()
+
+            # 测试返回
+            sql = ' select now()'
+            print(sql)
+            cur.execute(sql)
+            testresult = cur.fetchall()
+            cur.close()
+            conn.close()
+
+        except Exception as e:
+            # 捕获其他所有类型的异常
+            return f"未知错误: {e}"
+        # create_datasource = DatasourceEntity(id=datasource, name=datasourcename, scope='', dbtype=dbtype, host=host,
+        #                                      user=user, password=password, database=database, port=port,
+        #                                      charset=charset)
+        # db.add(create_datasource)
+        database_info.name = datasourcename
+        database_info.dbtype = dbtype
+        database_info.host = host
+        database_info.user = user
+        database_info.password = password
+        database_info.database = database
+        database_info.port = port
+        database_info.charset = charset
+
+    else:
+        return f'暂不支持{dbtype}'
+    db.commit()
+    return [datasource,testresult]
+
+
+
+
+@router.get('/v2/{serviceid}')
+@router.post('/v2/{serviceid}')
+# @router.post('/v2/{servicecode}')
+async def v2(serviceid:str,request: Request, db: Session = Depends(get_db)):
+    # 获取请求头 servicecode
+    # if service_code is None:
+    # service_code = request.headers.get('servicecode')
+    # print(serviceid)
+    if serviceid is None:
+       return JSONResponse(status_code=410, content= {
+           'code': 410,
+           'msg': "请求头servicecode未传参"
+       })
+    # 判断请求接口是否存在
+    service_info = db.query(ApiServiceEntity).filter(ApiServiceEntity.id == serviceid).first()
+    if service_info is None:
+       return JSONResponse(status_code=410, content={
+           'code': 410,
+           'msg': 'servicecode服务不存在'
+       })
+    # print(service_code)
+
+
+    # 判断请求接口对应数据库是否存在
+    command_info = db.query(CommandEntity).filter(CommandEntity.scope == serviceid).first()
+    if command_info is None:
+        return JSONResponse(status_code=410, content={
+            'code': 410,
+            'msg': '查询命令没配置'
+        })
+
+    # 判断请求接口对应数据库是否存在
+    database_info = db.query(DatasourceEntity).filter(DatasourceEntity.id == command_info.datasource).first()
+    if database_info is None:
+        return JSONResponse(status_code=410, content={
+            'errcode': 410,
+            'errmsg': '数据库没配置'
+        })
+
+    # 获取接口对应的sql
+    sql = command_info.sqltext
+    print('sql ==== << ', sql)
+
+    # 从params和body获取参数
+    # query_list = parse.parse_qsl(str(request.query_params))
+    # print(query_list)
+    data = await request.body()
+    body = data.decode(encoding='utf-8')
+    size = 10
+    current = 1
+    if len(body) > 0:
+        body = json.loads(body)
+
+        # 分页器 页数和页码的设置
+
+        if "size" in body:
+            if isinstance(body['size'], str):
+                size = int(body['size'])
+            elif isinstance(body['size'], int):
+                size = body['size']
+            if size >100:
+                size = 100
+
+        if "current" in body:
+            if isinstance(body['current'], str):
+                current = int(body['current'])
+            elif isinstance(body['current'], int):
+                current = body['current']
+            if current<=0:
+                current=1
+        # 接口sql的参数替换
+        if 'query' in body:
+            for (key, val) in body['query'].items():
+                if isinstance(val, int):
+                    val = str(val)
+                if isinstance(val, float):
+                    val = str(val)
+                if contains_special_characters(val):
+                    return JSONResponse(status_code=411, content={
+                                'code': 411,
+                                'msg': f'参数{key}含特殊符号:;、&、$、#、\'、\\t、@、空格等'
+                            })
+                sql = sql.replace("{" + key + "}", val)
+
+
+    print('sql ==== >> ', sql)
+
+
+    data = []
+
+        # 数据库类型为mysql情况下
+    if database_info.dbtype == 'pymysql':
+        # 数据库连接
+
+        conn = pymysql.connect(host=database_info.host,
+                               user=database_info.user,
+                               password=database_info.password,
+                               database=database_info.database,
+                               port=database_info.port,
+                               charset=database_info.charset)
+        cur = conn.cursor()
+
+        # 查总数据量,分页数据处理
+        totalsql = f'select count(*) from ({sql})t'
+        print(totalsql)
+        cur.execute(totalsql)
+        total = cur.fetchone()[0]
+
+        pages,pagesmod = divmod(total, size)
+        print(total,pages,pagesmod)
+        if pagesmod!=0:
+            pages+=1
+        print(pages,pagesmod)
+        if total <size :
+            size = total
+
+
+        # 正式查询
+        sql = sql+f' limit {size*(current-1)}, {size};'
+        print(sql,size)
+        cur.execute(sql)
+        rows = cur.fetchall()
+        colnames = [desc[0] for desc in cur.description]
+        for row in range(len(rows)):
+            item = {}
+            for col in range(len(colnames)):
+                field_name = colnames[col]
+                item[field_name] = rows[row][col]
+            data.append(item)
+
+        # 数据库关闭
+        cur.close()
+        conn.close()
+
+
+
+    else:
+        return JSONResponse(status_code=410, content={
+        'code': 410,
+        'msg': '接口对应数据库暂不支持'
+    })
+
+    return {
+        'code': 0,
+        'msg': 'success',
+        'rows': data,
+         'pages': pages,  # 总页数
+         'currentPage': current,  # 当前页数
+         # 'current':current,
+         # 'total' : total,
+         'total': total,  # 总数据量
+         # 'size':size,
+         'pageSize': size  # 页码
+    }
+
+    # else:
+    #     return JSONResponse(status_code=410, content={
+    #         'code': 410,
+    #         'msg': 'body不能为空'
+    #     })

+ 2 - 2
routers/apit/__init__.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -26,7 +26,7 @@ ALGORITHM = "HS256"
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
 
-def valid_access_token(Authorization: str = Header(..., alias="Authorization"), db: Session = Depends(get_db)) -> str:
+def valid_access_token(Authorization: str = Header(..., alias="Authorization"), db: Session = Depends(get_db_share)) -> str:
     try:
         access_token = Authorization.removeprefix("Bearer ")
         payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])

+ 4 - 4
routers/apit/topinfo.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -21,7 +21,7 @@ router = APIRouter()
 @router.post('/insular/getVisitorInfo', description='提交访客信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -54,7 +54,7 @@ async def index(
 @router.post('/insular/getOtherCarInfo', description='其他车辆信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -87,7 +87,7 @@ async def index(
 @router.post('/emergencerecord/addOrUpdate', description='应急演练实施过程记录信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 5 - 5
routers/apiz/base/__init__.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -29,7 +29,7 @@ router.include_router(third.router, prefix="/third")
 @router.post('/staff/info', description='提交人员基础信息')
 async def post_company_industry(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -63,7 +63,7 @@ async def post_company_industry(
 @router.post('/duty/info', description='提交值班值守信息')
 async def post_duty_info(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -93,7 +93,7 @@ async def post_duty_info(
 @router.post('/video/info', description='提交摄像头信息')
 async def post_video_info(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -124,7 +124,7 @@ async def post_video_info(
 @router.post('/medium/menu', description='获取危险化学品目录信息表')
 async def get_medium_menu(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 2 - 2
routers/apiz/base/certificate.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -22,7 +22,7 @@ router = APIRouter()
 @router.post('/info', description='提交证照报告信息')
 async def post_certificate_info(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 8 - 8
routers/apiz/base/company.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -22,7 +22,7 @@ router = APIRouter()
 @router.post('/info', description='提交企业基本信息')
 async def post_company_info(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -52,7 +52,7 @@ async def post_company_info(
 @router.post('/industry', description='提交企业隶属化工行业关系表')
 async def post_company_industry(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -83,7 +83,7 @@ async def post_company_industry(
 @router.post('/keyIndustry', description='提交企业隶属重点行业关系表')
 async def post_company_industry(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -114,7 +114,7 @@ async def post_company_industry(
 @router.post('/chemical', description='提交危险化学品信息')
 async def post_company_chemical(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -145,7 +145,7 @@ async def post_company_chemical(
 @router.post('/crafts', description='提交企业重点监管工艺信息')
 async def post_company_crafts(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -177,7 +177,7 @@ async def post_company_crafts(
 @router.post('/accident', description='提交企业事故事件信息')
 async def post_company_accident(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -210,7 +210,7 @@ async def post_company_accident(
 @router.post('/supervise', description='提交企业“三同时”监管')
 async def post_company_supervise(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 17 - 17
routers/apiz/base/device.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -22,7 +22,7 @@ router = APIRouter()
 @router.post('/info', description='提交设备基本信息')
 async def post_device_info(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -53,7 +53,7 @@ async def post_device_info(
 @router.post('/special', description='提交特种设备登记信息')
 async def post_company_industry(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -84,7 +84,7 @@ async def post_company_industry(
 @router.post('/safetyValve', description='提交安全阀信息')
 async def post_company_industry(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -115,7 +115,7 @@ async def post_company_industry(
 @router.post('/burst', description='提交爆破片信息')
 async def post_company_industry(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -147,7 +147,7 @@ async def post_company_industry(
 @router.post('/instrument', description='提交安全仪表连锁信息')
 async def post_company_industry(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -179,7 +179,7 @@ async def post_company_industry(
 @router.post('/detect', description='提交监测设备信息')
 async def post_device_detect(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -212,7 +212,7 @@ async def post_device_detect(
 @router.post('/tank', description='提交储罐基础信息')
 async def post_device_tank(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -243,7 +243,7 @@ async def post_device_tank(
 @router.post('/system', description='提交装置基础信息')
 async def post_device_system(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -275,7 +275,7 @@ async def post_device_system(
 @router.post('/gas', description='提交气体泄漏点基础信息')
 async def post_device_gas(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -307,7 +307,7 @@ async def post_device_gas(
 @router.post('/store', description='提交仓库基础信息')
 async def post_device_store(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -338,7 +338,7 @@ async def post_device_store(
 @router.post('/dock', description='提交仓库基础信息')
 async def post_device_dock(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -370,7 +370,7 @@ async def post_device_dock(
 @router.post('/stop', description='提交设备停用记录表')
 async def post_device_stop(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -403,7 +403,7 @@ async def post_device_stop(
 @router.post('/medium', description='提交设备介质基础信息表')
 async def post_device_medium(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -434,7 +434,7 @@ async def post_device_medium(
 @router.post('/mediumGroup', description='提交设备介质组分信息表')
 async def post_device_mediumGroup(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -467,7 +467,7 @@ async def post_device_mediumGroup(
 @router.post('/systemMaintenance', description='提交装置大检修备案记录信息')
 async def post_systemMaintenance(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -498,7 +498,7 @@ async def post_systemMaintenance(
 @router.post('/systemStartStop', description='提交装置开停车管理信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 8 - 8
routers/apiz/base/third.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -23,7 +23,7 @@ router = APIRouter()
 @router.post('/company', description='提交第三方单位信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -54,7 +54,7 @@ async def index(
 @router.post('/staff', description='提交第三方人员基础信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -84,7 +84,7 @@ async def index(
 @router.post('/staffCertificate', description='提交第三方人员证书信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -114,7 +114,7 @@ async def index(
 @router.post('/companyCertificate', description='提交第三方单位资质信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -144,7 +144,7 @@ async def index(
 @router.post('/staffTrain', description='提交第三方单位安全教育培训信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -174,7 +174,7 @@ async def index(
 @router.post('/violation', description='提交第三方单位违规记录信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -205,7 +205,7 @@ async def index(
 @router.post('/serve', description='提交第三方单位服务记录信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 2 - 2
routers/apiz/common.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -28,7 +28,7 @@ async def uploadfile(
     from_type: str = Body(...),
     file: UploadFile = File(...),
     client_id = Depends(valid_access_token),
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db_share)
 ):
     try:
         file_name = file.filename

+ 5 - 5
routers/apiz/danger/__init__.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -27,7 +27,7 @@ router.include_router(monitor.router, prefix="/monitor")
 @router.post('/process/info', description='提交生产过程基础信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -58,7 +58,7 @@ async def index(
 @router.post('/manager/promise', description='提交安全承诺详情数据')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -88,7 +88,7 @@ async def index(
 @router.post('/warn/event', description='提交预警推送数据表')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -117,7 +117,7 @@ async def index(
 @router.post('/warn/getEvent', description='获取预警推送数据表')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 5 - 5
routers/apiz/danger/monitor.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -22,7 +22,7 @@ router = APIRouter()
 @router.post('/info', description='提交监测指标信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -53,7 +53,7 @@ async def index(
 @router.post('/data', description='提交监测指标实时数据')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -85,7 +85,7 @@ async def index(
 @router.post('/alarm', description='提交报警数据报文')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -118,7 +118,7 @@ async def index(
 @router.post('/video', description='提交视频智能监控数据')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 6 - 6
routers/apiz/danger/source.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -24,7 +24,7 @@ router = APIRouter()
 @router.post('/evaluation', description='提交评价/评估报告')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -55,7 +55,7 @@ async def index(
 @router.post('/info', description='提交企业重大危险源信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -88,7 +88,7 @@ async def index(
 @router.post('/problem', description='提交隐患信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -120,7 +120,7 @@ async def index(
 @router.post('/takes', description='提交危险源包责任人信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -151,7 +151,7 @@ async def index(
 @router.post('/warning', description='提交安全风险评估与管控信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 8 - 8
routers/apiz/inspection.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -22,7 +22,7 @@ router = APIRouter()
 @router.post('/plan/info', description='提交巡检计划信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -54,7 +54,7 @@ async def index(
 @router.post('/plan/point', description='提交巡检节点信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -86,7 +86,7 @@ async def index(
 @router.post('/plan/rule', description='提交检查标准信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -118,7 +118,7 @@ async def index(
 @router.post('/plan/obj', description='提交巡检对象信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -152,7 +152,7 @@ async def index(
 @router.post('/task/info', description='提交巡检任务信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -185,7 +185,7 @@ async def index(
 @router.post('/task/record', description='提交巡检记录信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -218,7 +218,7 @@ async def index(
 @router.post('/task/statistics', description='提交巡检统计结果信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 6 - 6
routers/apiz/location.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -21,7 +21,7 @@ router = APIRouter()
 @router.post('/zone/info', description='提交区域边界信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -53,7 +53,7 @@ async def index(
 @router.post('/person/info', description='提交人员定位实时数据')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -83,7 +83,7 @@ async def index(
 @router.post('/person/area', description='提交人员聚集数据')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -114,7 +114,7 @@ async def index(
 @router.post('/person/alarm', description='提交人员报警数据')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -145,7 +145,7 @@ async def index(
 @router.post('/zone/alarm', description='提交区域报警数据')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:

+ 4 - 4
routers/apiz/oauth.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -28,7 +28,7 @@ async def login(
     request: Request, 
     client_id: str = Form(..., description=''),
     client_secret: str = Form(..., description=''),
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db_share)
 ):
     hash_passwd = pwd_context.hash(client_secret)
     return {
@@ -44,7 +44,7 @@ async def login(
     client_secret: str = Form(..., description=''),
     grant_type: str = Form(..., description=''),
     scope: str = Form(..., description=''),
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db_share)
 ):
     app = authenticate_app(db, client_id, client_secret)
     if not app:
@@ -92,7 +92,7 @@ def create_access_token(*, data: dict, expires_delta: timedelta = None):
     return encoded_jwt
 
 
-def valid_access_token(Authorization: str = Header(..., alias="Authorization"), db: Session = Depends(get_db)) -> str:
+def valid_access_token(Authorization: str = Header(..., alias="Authorization"), db: Session = Depends(get_db_share)) -> str:
     try:
         access_token = Authorization.removeprefix("Bearer ")
         payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])

+ 5 - 5
routers/apiz/worker.py

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 from fastapi import APIRouter, Request, Depends, Form, Body, File, UploadFile
-from database import get_db
+from database import get_db_share
 from utils.StripTagsHTMLParser import *
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
@@ -22,7 +22,7 @@ router = APIRouter()
 @router.post('/activity/info', description='提交特殊作业活动信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -53,7 +53,7 @@ async def index(
 @router.post('/ticket/info', description='提交特殊作业票证信息表')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -83,7 +83,7 @@ async def index(
 @router.post('/jsa/analysis', description='提交气体分析信息')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try:
@@ -114,7 +114,7 @@ async def index(
 @router.post('/ticket/video', description='提交作业票和视频关联表')
 async def index(
     request: Request, 
-    db: Session = Depends(get_db),
+    db: Session = Depends(get_db_share),
     data: dict = Depends(remove_xss_json)
 ):
     try: