import json
import random
import uuid
from collections import defaultdict
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import InstrumentedAttribute

from database import get_db
from models import *  # noqa: F401,F403 -- expected to provide KnowledgeBase and KnowledgeFile

# Imported after the star import on purpose: the handlers below call
# `datetime.strptime` / `datetime.fromisoformat` on the *class*, and the file
# previously never imported it, relying on whatever `models` happened to export.
from datetime import datetime  # noqa: E402

router = APIRouter()

# Directory prefix recorded in KnowledgeFile.file_path for uploaded files.
_UPLOAD_DIR = '/data/upload/mergefile'

# Timestamp format accepted for `publishDate` by the /update endpoint.
_UPDATE_TIME_FORMAT = '%Y-%m-%d %H:%M:%S'

# Fields a client may not overwrite through /update/list: the report's identity
# and the code that links it to its KnowledgeFile rows.
_PROTECTED_UPDATE_FIELDS = frozenset({'reportId', 'base_code'})


def _random_code(prefix: str) -> str:
    """Return ``prefix`` followed by a random 10-digit number (not guaranteed unique)."""
    return f'{prefix}{random.randint(1000000000, 9999999999)}'


def _file_to_dict(file) -> dict:
    """Serialize a KnowledgeFile row into the file object returned by /create and /select."""
    return {
        "fileIdentifier": file.file_identifier,
        "fileName": file.file_name,
        "url": file.storage_file_name,
        "status": "success",
    }


def _validate_file_names(file_names) -> Optional[str]:
    """Return an error message if ``file_names`` is malformed, otherwise ``None``.

    ``file_names`` must be a list of objects, each carrying a ``name`` (display
    name) and a ``url`` (stored file name) key.
    """
    if not isinstance(file_names, list):
        return "'fileNames' must be a list"
    for entry in file_names:
        if not isinstance(entry, dict) or 'name' not in entry or 'url' not in entry:
            return "Each item in 'fileNames' must be an object with 'name' and 'url'"
    return None


def _build_knowledge_files(file_names, base_code, timestamp):
    """Create (unsaved) KnowledgeFile rows for ``file_names`` linked to ``base_code``.

    ``timestamp`` is stored as both ``createTime`` and ``updateTime``.
    """
    return [
        KnowledgeFile(
            file_identifier=_random_code('file'),
            file_path=f'{_UPLOAD_DIR}/{entry["url"]}',
            storage_file_name=entry["url"],
            file_name=entry["name"],  # display name supplied by the client
            is_deleted=0,
            updateTime=timestamp,
            createTime=timestamp,
            knowledge_base_code=base_code,
        )
        for entry in file_names
    ]


def _soft_delete_reports(kb_entries, db: Session) -> None:
    """Flag the given reports and all of their files as deleted (value 2). Does not commit."""
    for kb_entry in kb_entries:
        # NOTE(review): kept from the original code; KnowledgeBase may not map an
        # `is_deleted` column (its soft-delete flag is `del_flag`) -- confirm.
        kb_entry.is_deleted = 2
        kb_entry.del_flag = 2
    # Reports without a base_code own no files; comparing against NULL would
    # otherwise match every orphaned file row.
    base_codes = [kb_entry.base_code for kb_entry in kb_entries if kb_entry.base_code is not None]
    if not base_codes:
        return
    files = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code.in_(base_codes)).all()
    for kf_entry in files:
        kf_entry.is_deleted = 2


def _parse_date_range(value: str):
    """Split a ``START-END`` date range into ``(start, end)`` strings.

    Accepts hyphen-free dates (e.g. ``20240101-20241231``) and ISO dates joined
    by a hyphen (``2024-01-01-2024-12-31``); any other shape raises HTTP 400.
    """
    parts = value.split('-')
    if len(parts) == 2:
        return parts[0], parts[1]
    if len(parts) == 6:
        return '-'.join(parts[:3]), '-'.join(parts[3:])
    raise HTTPException(status_code=400, detail="Invalid 'publishDateRange', expected 'START-END'")


@router.post('/create')
async def create_knowledge(request: Request, db: Session = Depends(get_db)):
    """Create a summary report together with its attached files.

    The JSON body must contain every field in ``required_fields``; ``fileNames``
    is a list of ``{"name": ..., "url": ...}`` objects. Returns 400 for missing
    or malformed input, 409 on integrity errors and 500 on other failures.
    """
    body = await request.json()
    required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary', 'fileNames']
    missing = [field for field in required_fields if field not in body]
    if missing:
        missing_fields = ", ".join(missing)
        return Response(content=f"Missing required fields: {missing_fields}", status_code=400)
    file_error = _validate_file_names(body['fileNames'])
    if file_error:
        return Response(content=file_error, status_code=400)

    report_id = _random_code('ZJBG')
    base_code = _random_code('base')
    knowledge = KnowledgeBase(
        reportId=report_id,
        **{field: body[field] for field in required_fields if field != 'fileNames'},
        base_code=base_code,
        del_flag=0,
        updateTime=body['publishDate'],
        # Hard-coded report type ("summary report"); change this if other types must be supported.
        notificationType="总结报告",
    )
    db.add(knowledge)
    db.add_all(_build_knowledge_files(body['fileNames'], base_code, body['publishDate']))

    try:
        db.commit()
        knowledge_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == report_id).first()
        files_entry = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).all()
        return {
            "code": 200,
            "msg": "总结报告创建成功",
            "status": "success",
            "data": [{
                "reportId": knowledge_entry.reportId,
                "reportName": knowledge_entry.reportName,
                "eventType": knowledge_entry.eventType,
                "publishingUnit": knowledge_entry.publishingUnit,
                "subject": knowledge_entry.subject,
                "publishDate": knowledge_entry.publishDate,
                "summary": knowledge_entry.summary,
                "files": [_file_to_dict(file) for file in files_entry],
            }],
        }
    except IntegrityError as e:
        db.rollback()
        return Response(content=f"Database error: {str(e)}", status_code=409)
    except Exception as e:
        db.rollback()
        return Response(content=f"Internal server error: {str(e)}", status_code=500)


@router.get('/select')
async def select_knowledge(
    db: Session = Depends(get_db),
    sortBy: str = Query(..., description="排序字段"),
    sortOrder: str = Query(..., description="排序顺序"),
    pageNum: int = Query(1, gt=0, description="页码"),
    pageSize: int = Query(10, gt=0, le=100, description="每页大小"),
    eventType: str = Query(None, description="事件类型"),
    publishDateRange: str = Query(None, description="发布日期范围"),
    query: str = Query(None, description="查询关键字"),
):
    """Return one page of non-deleted reports, each with its non-deleted files.

    Optional filters: exact ``eventType``, ``publishDateRange`` (``START-END``),
    and a ``query`` keyword matched with LIKE against name, unit and report id.
    ``sortBy`` is ignored unless it names a sortable model attribute; any
    ``sortOrder`` other than ``'asc'`` sorts descending.
    """
    data_query = db.query(KnowledgeBase).filter(KnowledgeBase.del_flag != '2')

    if eventType:
        data_query = data_query.filter(KnowledgeBase.eventType == eventType)
    if publishDateRange:
        start_date, end_date = _parse_date_range(publishDateRange)
        data_query = data_query.filter(KnowledgeBase.publishDate.between(start_date, end_date))
    if query:
        search_fields = [getattr(KnowledgeBase, field)
                         for field in ('reportName', 'publishingUnit', 'reportId')
                         if hasattr(KnowledgeBase, field)]
        data_query = data_query.filter(or_(*[field.like(f'%{query}%') for field in search_fields]))

    # Only column-like attributes expose .asc()/.desc(); anything else is ignored
    # instead of failing with an AttributeError.
    sort_attr = getattr(KnowledgeBase, sortBy, None)
    if sort_attr is not None and hasattr(sort_attr, 'asc'):
        data_query = data_query.order_by(sort_attr.asc() if sortOrder == 'asc' else sort_attr.desc())

    total_count = data_query.count()
    offset = (pageNum - 1) * pageSize

    fields = [field for field in ('reportId', 'reportName', 'eventType', 'publishDate', 'publishingUnit',
                                  'summary', 'subject', 'notificationType', 'base_code')
              if hasattr(KnowledgeBase, field)]
    entities = [getattr(KnowledgeBase, field) for field in fields]
    rows = data_query.with_entities(*entities).offset(offset).limit(pageSize).all()
    result_items = [{field: getattr(row, field) for field in fields} for row in rows]

    # Load the files for the whole page in one query instead of one per report.
    files_by_code = defaultdict(list)
    base_codes = list({item.get('base_code') for item in result_items if item.get('base_code') is not None})
    if base_codes:
        kf_entries = (db.query(KnowledgeFile)
                      .filter(KnowledgeFile.knowledge_base_code.in_(base_codes))
                      .filter(KnowledgeFile.is_deleted != '2')
                      .all())
        for kf in kf_entries:
            files_by_code[kf.knowledge_base_code].append(_file_to_dict(kf))
    for item in result_items:
        item['files'] = list(files_by_code.get(item.get('base_code'), []))

    return {
        "code": 200,
        'msg': '查询成功',
        'pages': (total_count + pageSize - 1) // pageSize,
        'total': total_count,
        "currentPage": pageNum,
        "pageSize": pageSize,
        'data': result_items,
    }


@router.get('/detail')
async def get_knowledge_detail(db: Session = Depends(get_db),
                               reportID: Optional[str] = Query(None, description="报告ID")):
    """Return one non-deleted report and its non-deleted files.

    Raises HTTP 400 when ``reportID`` is missing and 404 when no live report matches.
    """
    if not reportID:
        raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")

    # Exclude soft-deleted reports (del_flag == '2') before matching the id.
    kb_entry = (db.query(KnowledgeBase)
                .filter(KnowledgeBase.del_flag != '2')
                .filter(KnowledgeBase.reportId == reportID)
                .first())
    if not kb_entry:
        raise HTTPException(status_code=404, detail="报告不存在")

    kf_entries = (db.query(KnowledgeFile)
                  .filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code)
                  .filter(KnowledgeFile.is_deleted != '2')
                  .all())
    files = [
        {"content": kf.file_name, "url": kf.storage_file_name, "file_identifier": kf.file_identifier}
        for kf in kf_entries
    ]
    return {
        "code": 200,
        "msg": "查询成功",
        "data": [{
            "report_id": kb_entry.reportId,
            "reportName": kb_entry.reportName,
            "subject": kb_entry.subject,
            "eventType": kb_entry.eventType,
            "publishDate": kb_entry.publishDate,
            "publishingUnit": kb_entry.publishingUnit,
            "summary": kb_entry.summary,
            "notificationType": kb_entry.notificationType,
            "file": files,
        }],
    }


@router.post('/delete')
async def delete_knowledge(request: Request, db: Session = Depends(get_db)):
    """Soft-delete one report (JSON body ``{"reportID": ...}``) and all of its files."""
    body = await request.json()
    report_id_to_use = body.get('reportID')
    if not report_id_to_use:
        return Response(content="Missing required parameter 'reportID'", status_code=400)

    kb_entry = (db.query(KnowledgeBase)
                .filter(KnowledgeBase.del_flag != '2')
                .filter(KnowledgeBase.reportId == report_id_to_use)
                .first())
    if not kb_entry:
        raise HTTPException(status_code=404, detail="报告不存在")

    _soft_delete_reports([kb_entry], db)
    try:
        db.commit()
        return {
            "code": 200,
            "msg": "操作成功",
            "data": {"report_id": kb_entry.reportId},
        }
    except Exception as e:
        db.rollback()
        return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)


@router.delete('/delete/list')
async def delete_knowledge_list(request: Request, db: Session = Depends(get_db)):
    """Soft-delete several reports and their files.

    The JSON body's ``reportID`` is a list of report ids; a single id is also
    accepted. Returns 404 when none of the ids match a live report.
    """
    body = await request.json()
    report_id_to_use = body.get('reportID')
    if not report_id_to_use:
        return Response(content="Missing required parameter 'reportID'", status_code=400)
    # `in_()` needs a list; a bare string would not be treated as one id.
    report_ids = report_id_to_use if isinstance(report_id_to_use, list) else [report_id_to_use]

    kb_entries = (db.query(KnowledgeBase)
                  .filter(KnowledgeBase.del_flag != '2')
                  .filter(KnowledgeBase.reportId.in_(report_ids))
                  .all())
    if not kb_entries:
        raise HTTPException(status_code=404, detail="报告不存在")

    _soft_delete_reports(kb_entries, db)
    try:
        db.commit()
        return {"code": 200, "msg": "操作成功", "data": None}
    except Exception as e:
        db.rollback()
        return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)


def delete_file_fun(knowledge_base_code, db: Session, commit: bool = True):
    """Soft-delete (``is_deleted = '2'``) every live file linked to ``knowledge_base_code``.

    Commits by default, as before. Pass ``commit=False`` to leave the change in
    the caller's transaction so it is committed or rolled back with the rest.
    """
    files = (db.query(KnowledgeFile)
             .filter(KnowledgeFile.is_deleted != '2')
             .filter(KnowledgeFile.knowledge_base_code == knowledge_base_code)
             .all())
    for file in files:
        file.is_deleted = '2'
    if commit:
        db.commit()


@router.post('/update')
async def update_knowledge(request: Request, db: Session = Depends(get_db)):
    """Update a live report's fields and, optionally, replace its files.

    Only keys present in the body are changed. ``publishDate`` (format
    ``%Y-%m-%d %H:%M:%S``) also becomes the report's ``updateTime``. When no
    ``publishDate`` is sent, the existing ``updateTime`` is kept. A non-empty
    ``fileNames`` list soft-deletes the current files and attaches the new ones
    in the same transaction.
    """
    body = await request.json()
    report_id_to_use = body.get('reportId')
    if not report_id_to_use:
        return Response(content="Missing required parameter 'reportId'", status_code=400)

    kb_entry = (db.query(KnowledgeBase)
                .filter(KnowledgeBase.del_flag != '2')
                .filter(KnowledgeBase.reportId == report_id_to_use)
                .first())
    if not kb_entry:
        raise HTTPException(status_code=404, detail="报告不存在")

    file_names = body.get('fileNames') or []
    file_error = _validate_file_names(file_names)
    if file_error:
        return Response(content=file_error, status_code=400)

    raw_time = body.get('publishDate') or kb_entry.updateTime
    if isinstance(raw_time, datetime):
        update_time = raw_time
    else:
        try:
            update_time = datetime.strptime(raw_time, _UPDATE_TIME_FORMAT)
        except (TypeError, ValueError):
            return Response(content="Invalid 'publishDate', expected 'YYYY-MM-DD HH:MM:SS'", status_code=400)

    for field in ('reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary'):
        if field in body:
            setattr(kb_entry, field, body[field])
    kb_entry.updateTime = update_time

    if file_names:
        base_code = kb_entry.base_code
        if base_code:
            # Keep the file replacement inside this request's transaction.
            delete_file_fun(base_code, db=db, commit=False)
        db.add_all(_build_knowledge_files(file_names, base_code, update_time))

    try:
        db.commit()
        return {
            "code": 200,
            "msg": "修改成功",
            "data": {
                "report_id": report_id_to_use,
                # Use the local value: re-reading the expired ORM attribute after
                # commit could fail and turn a committed update into a 500.
                "updateTime": update_time.isoformat(),
            },
        }
    except IntegrityError as e:
        db.rollback()
        return Response(content=f"Database error: {str(e)}", status_code=409)
    except Exception as e:
        db.rollback()
        return Response(content=str(e), status_code=500)


@router.put('/update/list')
async def update_knowledge_list(request: Request, db: Session = Depends(get_db)):
    """Apply a batch of field updates in one transaction.

    The JSON body's ``updateItems`` is a list of objects, each identified by
    ``reportId``. Items without an id, or whose id matches no report, are
    skipped. Only mapped model attributes outside ``_PROTECTED_UPDATE_FIELDS``
    are written. ``updated_count`` is the number of reports actually updated.
    """
    body = await request.json()
    update_items = body.get('updateItems')
    if not update_items or not isinstance(update_items, list):
        return Response(content="Missing required parameter 'updateItems'", status_code=400)

    try:
        updated_count = 0
        for item in update_items:
            if not isinstance(item, dict):
                continue
            report_id_to_use = item.get('reportId')
            if not report_id_to_use:
                continue  # no reportId supplied: skip this item

            # NOTE(review): unlike the other endpoints this lookup does not exclude
            # soft-deleted reports (del_flag == '2'); confirm whether that is intended.
            kb_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == report_id_to_use).first()
            if not kb_entry:
                continue  # no matching report: skip this item

            for field, value in item.items():
                # Only mapped attributes may be set; this blocks writes to ORM
                # internals and to the report's identity/linkage fields.
                if field in _PROTECTED_UPDATE_FIELDS:
                    continue
                if not isinstance(getattr(KnowledgeBase, field, None), InstrumentedAttribute):
                    continue
                setattr(kb_entry, field, value)

            publish_date = item.get('publishDate')
            if publish_date:
                try:
                    kb_entry.updateTime = datetime.fromisoformat(publish_date)
                except (TypeError, ValueError):
                    db.rollback()
                    return Response(content=f"Invalid 'publishDate' for reportId {report_id_to_use}",
                                    status_code=400)
            updated_count += 1

        db.commit()
        return {
            "code": 200,
            "msg": "批量更新成功",
            "data": {"updated_count": updated_count},
        }
    except IntegrityError as e:
        db.rollback()
        return Response(content=f"Database error: {str(e)}", status_code=409)
    except Exception as e:
        db.rollback()
        return Response(content=str(e), status_code=500)