xuguoyang 6 months ago
parent
commit
8aa7a6744f
1 changed files with 50 additions and 40 deletions
  1. 50 40
      routers/api/dataFilling/__init__.py

+ 50 - 40
routers/api/dataFilling/__init__.py

@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException, Depends, APIRouter,Query,Body
 from sqlalchemy.engine.reflection import Inspector
-
+from common.security import valid_access_token
 from pydantic import BaseModel,Extra
 from datetime import datetime
 from typing import List, Optional,Any,Dict
@@ -26,7 +26,7 @@ class ReportCreate(BaseModel):
     issued_status: str
     period_type: str
     creator_name: str
-    creator_id: int
+    # creator_id: int
     creator_phone:str
     # num_reporters:int
     field_names: List[str]  # 用户只传递字段名称
@@ -63,7 +63,8 @@ class TableStructure(BaseModel):
 @router.get("/report_structure/{report_id}")
 async def get_report_structure(
     report_id: str,
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
+    creator_id = Depends(valid_access_token)
 ):
     # 查询 ReportManagement 表以获取 data_table_name
     report = db.query(ReportManagement).filter(ReportManagement.report_id == report_id).first()
@@ -129,7 +130,7 @@ async def get_report_structure(
             "period_type": report.period_type,
             "creator_name": report.creator_name,
             "num_reporters": distinct_users,
-            "creator_id": report.creator_id,
+            "creator_id": creator_id,
             "created_at": report.created_at,
             "updated_at": report.updated_at,
             "num_reported":num_reported,
@@ -175,7 +176,9 @@ def create_dynamic_table(table_name: str, field_names: List[str], db: Session):
 
 # 新建填报和创建新表的接口
 @router.post("/report/")
-def create_report_and_table(report: ReportCreate, db: Session = Depends(get_db)):
+def create_report_and_table(report: ReportCreate, db: Session = Depends(get_db),
+                            creator_id = Depends(valid_access_token)
+                            ):
     try:
         # 获取当前时间并格式化为 YYYYMMDDHHMMSS
         current_time_str = datetime.now().strftime("%Y%m%d%H%M%S")
@@ -200,7 +203,7 @@ def create_report_and_table(report: ReportCreate, db: Session = Depends(get_db))
             collection_status=0,#未收取
             period_type=report.period_type,
             creator_name=report.creator_name,
-            creator_id=report.creator_id,
+            creator_id=creator_id,
             creator_phone=report.creator_phone,
             num_reporters = len(report.user_ids)
         )
@@ -225,7 +228,7 @@ def create_report_and_table(report: ReportCreate, db: Session = Depends(get_db))
 
 
 class ReportQuery(BaseModel):
-    creator_id: str  # 创建者ID,必须提供
+    # creator_id: str  # 创建者ID,必须提供
     table_name: Optional[str] = None
     status: Optional[List[int]] = None
     start_time: Optional[datetime] = None
@@ -238,13 +241,12 @@ class ReportQuery(BaseModel):
 @router.get("/select")
 async def select_report(
         db: Session = Depends(get_db),
-        query: ReportQuery = Depends()
+        query: ReportQuery = Depends(),
+        creator_id = Depends(valid_access_token)
 ):
-    # 检查 creator_id 是否提供
-    if not query.creator_id:
-        raise HTTPException(status_code=400, detail="创建者ID是必填项")
 
-    data_query = db.query(ReportManagement).filter(ReportManagement.creator_id == query.creator_id)
+
+    data_query = db.query(ReportManagement).filter(ReportManagement.creator_id == creator_id)
 
     # 过滤条件
     if query.table_name:
@@ -281,7 +283,7 @@ async def select_report(
             "issued_status": item.issued_status,
             "period_type": item.period_type,
             "creator_name": item.creator_name,
-            "creator_id": item.creator_id,
+            "creator_id": creator_id,
             "created_at": item.created_at,
             "creator_phone": item.creator_phone,
             "updated_at": item.updated_at,
@@ -318,14 +320,15 @@ class ReportUpdate(BaseModel):
     period_type: str = None
     end_time: str = None
     comments: dict = None  # 字典,键为字段名,值为新的备注
-    creator_id: str = None
+    # creator_id: str = None
 
 #修改
 @router.put("/report/{report_id}/")
 async def update_report(
     report_id: str,
     update_data: ReportUpdate,
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
+    creator_id = Depends(valid_access_token)
 ):
     # 查询要修改的记录
     report = db.query(ReportManagement).filter(ReportManagement.report_id == report_id).first()
@@ -333,7 +336,7 @@ async def update_report(
         raise HTTPException(status_code=404, detail="Report not found")
 
         # 验证请求者ID
-    if report.creator_id != update_data.creator_id:
+    if report.creator_id != creator_id:
         raise HTTPException(status_code=403, detail="没有权限更新此报告")
 
     # 更新字段
@@ -371,7 +374,7 @@ async def update_report(
 
 
 class TaskQuery(BaseModel):
-    user_id: str
+    # user_id: str
     submission_status: Optional[List[int]] = None
     table_name: Optional[str] = None
 
@@ -379,17 +382,18 @@ class TaskQuery(BaseModel):
 @router.get("/my_filling")
 async def get_user_tasks(
     db: Session = Depends(get_db),
-    query: TaskQuery = Body(...)
+    query: TaskQuery = Body(...),
+    user_id = Depends(valid_access_token)
 ):
     # 检查用户ID是否提供
-    if not query.user_id:
+    if not user_id:
         raise HTTPException(status_code=400, detail="用户ID是必填项")
 
     # 查询用户的所有任务信息
     user_tasks = db.query(ReportManagement, FormSubmission).join(
         FormSubmission, ReportManagement.report_id == FormSubmission.report_id
     ).filter(
-        FormSubmission.user_id == query.user_id
+        FormSubmission.user_id == user_id
     )
 
     # 如果提供了填报结果列表,则过滤结果
@@ -406,7 +410,7 @@ async def get_user_tasks(
     result_items = []
     for report, submission in tasks:
         result_item = {
-            "user_id":query.user_id,
+            "user_id":user_id,
             "table_name": report.table_name,
             "report_id": report.report_id,
             "submission_status": submission.submission_status,
@@ -427,8 +431,9 @@ async def get_user_tasks(
 @router.post("/report_fields")
 async def get_report_fields(
     db: Session = Depends(get_db),
-    user_id: str = Query(None, description="用户ID"),
-    report_id: str = Query(None, description="填报ID")
+    # user_id: str = Query(None, description="用户ID"),
+    report_id: str = Query(None, description="填报ID"),
+    user_id = Depends(valid_access_token)
 ):
     # 检查用户ID和填报ID是否提供
     if not user_id or not report_id:
@@ -483,7 +488,7 @@ class DataItem(BaseModel):
     pass  # 用于动态接收键值对
 
 class SubmitData(BaseModel):
-    user_id: int
+    # user_id: int
     report_id: str
     data: List[Dict[str, str]]  # 数据列表,每个元素是一个字典,包含字段名和值
 
@@ -494,10 +499,11 @@ class SubmitData(BaseModel):
 @router.post("/submit_data")
 async def submit_data(
     db: Session = Depends(get_db),
-    submit_data: SubmitData = Body(...)
+    submit_data: SubmitData = Body(...),
+    user_id = Depends(valid_access_token)
 ):
     # 检查用户ID和填报ID是否提供
-    if not submit_data.user_id or not submit_data.report_id:
+    if not user_id or not submit_data.report_id:
         raise HTTPException(status_code=400, detail="用户ID和填报ID是必填项")
 
     # 获取对应填报ID的数据表名称
@@ -512,7 +518,7 @@ async def submit_data(
     # 检查用户是否有权限填报
     submission = db.query(FormSubmission).filter(
         FormSubmission.report_id == submit_data.report_id,
-        FormSubmission.user_id == str(submit_data.user_id)  # 确保user_id是字符串类型
+        FormSubmission.user_id == str(user_id)  # 确保user_id是字符串类型
     ).first()
     if not submission:
         raise HTTPException(status_code=403, detail="用户没有填报权限")
@@ -522,7 +528,7 @@ async def submit_data(
         # 构造插入SQL语句
         columns = ', '.join(list(item.keys()) + ['create_id', 'user_id', 'collect_status'])
         values = ', '.join(
-            [f":{k}" for k in item.keys()] + [f"'{report.creator_id}'", f"'{submit_data.user_id}'", '1'])
+            [f":{k}" for k in item.keys()] + [f"'{report.creator_id}'", f"'{user_id}'", '1'])
         sql = f"INSERT INTO {data_table_name} ({columns}) VALUES ({values})"
         print(sql)
         # 执行插入操作
@@ -543,17 +549,18 @@ async def submit_data(
 
 
 class SubmissionQuery(BaseModel):
-    user_id: int  # 用户ID,必须是整数
+    # user_id: int  # 用户ID,必须是整数
     report_id: str  # 填报ID,必须是字符串
 
 
 @router.post("/submission_status")
 async def get_submission_status(
     db: Session = Depends(get_db),
-        query: SubmissionQuery = Body(...)
+        query: SubmissionQuery = Body(...),
+    user_id = Depends(valid_access_token)
 ):
     # 检查用户ID和填报ID是否提供
-    if not query.user_id or not query.report_id:
+    if not user_id or not query.report_id:
         raise HTTPException(status_code=400, detail="用户ID和填报ID是必填项")
 
     # 获取对应填报ID的数据表名称
@@ -568,7 +575,7 @@ async def get_submission_status(
     # 获取填报情况
     submission = db.query(FormSubmission).filter(
         FormSubmission.report_id == query.report_id,
-        FormSubmission.user_id == str(query.user_id)  # 确保user_id是字符串类型
+        FormSubmission.user_id == str(user_id)  # 确保user_id是字符串类型
     ).first()
     if not submission:
         raise HTTPException(status_code=404, detail="未找到对应的填报情况")
@@ -596,7 +603,7 @@ async def get_submission_status(
     query_sql = text(f"""
                 SELECT * FROM {data_table_name} WHERE user_id = :user_id
             """)
-    result = db.execute(query_sql, {"user_id": query.user_id})
+    result = db.execute(query_sql, {"user_id": user_id})
     rows = result.fetchall()
 
     # 添加字段名和字段注释作为第一行
@@ -636,9 +643,10 @@ def has_matching_column_comments(
 
 @router.post("/reports_by_creator/")
 async def get_reports_by_creator(
-        creator_id: str,  # 精确匹配的必选参数
+        # creator_id: str,  # 精确匹配的必选参数
         field_comment: Optional[str] = Query(None, description="Optional comment of the field to match"),
-        db: Session = Depends(get_db)
+        db: Session = Depends(get_db),
+        creator_id = Depends(valid_access_token)
 ):
     # 获取数据库Inspector
     inspector: Inspector = inspect(db.bind)
@@ -676,10 +684,11 @@ async def get_reports_by_creator(
 
 @router.put("/update_collection_status/")
 async def update_collection_status(
-    creator_id: str,
+    # creator_id: str,
     report_id: str,
     new_status: int = Query(..., description="New collection status, must be 0, 1, or 2"),
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
+    creator_id = Depends(valid_access_token)
 ):
     # 检查 new_status 是否为允许的值之一
     if new_status not in (0, 1, 2):
@@ -711,17 +720,18 @@ async def update_collection_status(
 
 
 class ReportQuery(BaseModel):
-    creator_id: str  # 创建人ID,必须是字符串
+    # creator_id: str  # 创建人ID,必须是字符串
     report_id: str  # 填报ID,必须是字符串
 
 @router.get("/dataArchiveDetails/")
 async def get_records_by_creator_and_report(
     query: ReportQuery = Depends(),
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
+    creator_id = Depends(valid_access_token)
 ):
     # 查询 ReportManagement 表以获取对应记录
     report = db.query(ReportManagement).filter(
-        ReportManagement.creator_id == query.creator_id,
+        ReportManagement.creator_id == creator_id,
         ReportManagement.report_id == query.report_id
     ).first()