__init__.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from fastapi import APIRouter, Request, Depends, HTTPException, Query
  2. from sqlalchemy.exc import IntegrityError
  3. from fastapi.responses import JSONResponse
  4. from database import get_db
  5. from sqlalchemy.orm import Session
  6. from models import *
  7. import json
  8. import random
  9. from sqlalchemy import create_engine, select, or_
  10. from typing import Optional
  11. router = APIRouter()
  12. @router.post('/create')
  13. async def create_knowledge(request: Request, db: Session = Depends(get_db)):
  14. body = await request.json()
  15. required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary', 'fileNames']
  16. if not all(field in body for field in required_fields):
  17. missing_fields = ", ".join([field for field in required_fields if field not in body])
  18. return Response(content=f"Missing required fields: {missing_fields}", status_code=400)
  19. reportId = f'ZJBG{random.randint(1000000000, 9999999999)}'
  20. base_code = 'base' + f'{random.randint(1000000000, 9999999999)}'
  21. knowledge = KnowledgeBase(
  22. reportId=reportId,
  23. **{
  24. field: body[field]
  25. for field in required_fields
  26. if field != 'fileNames'
  27. },
  28. base_code=base_code,
  29. del_flag = 0,
  30. updateTime=body['publishDate'],
  31. notificationType="总结报告" # 硬编码的值,如果有其他情况需要处理,请修改
  32. )
  33. db.add(knowledge)
  34. knowledge_files = [
  35. KnowledgeFile(
  36. file_identifier=f'file{random.randint(1000000000, 9999999999)}',
  37. file_path=f'/data/upload/mergefile/{fileName}', # 假设fileName是文件名
  38. file_name=fileName,
  39. is_deleted=0,
  40. updateTime=body['publishDate'],
  41. createTime=body['publishDate'],
  42. knowledge_base_code=base_code
  43. )
  44. for fileName in body['fileNames']
  45. ]
  46. db.add_all(knowledge_files)
  47. try:
  48. db.commit()
  49. return {
  50. "code": 200,
  51. "msg": "总结报告创建成功",
  52. "status": "success",
  53. "data": [reportId]
  54. }
  55. except IntegrityError as e:
  56. db.rollback()
  57. return Response(content=f"Database error: {str(e)}", status_code=409)
  58. except Exception as e:
  59. db.rollback()
  60. return Response(content=f"Internal server error: {str(e)}", status_code=500)
  61. @router.get('/select')
  62. async def select_knowledge(
  63. db: Session = Depends(get_db),
  64. sortBy: str = Query(..., description="排序字段"),
  65. sortOrder: str = Query(..., description="排序顺序"),
  66. pageNum: int = Query(1, gt=0, description="页码"),
  67. pageSize: int = Query(10, gt=0, le=100, description="每页大小"),
  68. eventType: str = Query(None, description="事件类型"),
  69. publishDateRange: str = Query(None, description="发布日期范围"),
  70. query: str = Query(None, description="查询关键字")
  71. ):
  72. data_query = db.query(KnowledgeBase)
  73. data_query = data_query.filter(KnowledgeBase.del_flag != '2')
  74. if eventType:
  75. data_query = data_query.filter(KnowledgeBase.eventType == eventType)
  76. if publishDateRange:
  77. start_date, end_date = publishDateRange.split('-')
  78. data_query = data_query.filter(KnowledgeBase.publishDate.between(start_date, end_date))
  79. if query:
  80. search_fields = [getattr(KnowledgeBase, field) for field in ('reportName', 'publishingUnit', 'reportId') if hasattr(KnowledgeBase, field)]
  81. search_conditions = [field.like(f'%{query}%') for field in search_fields]
  82. data_query = data_query.filter(or_(*search_conditions))
  83. if hasattr(KnowledgeBase, sortBy):
  84. sort_attr = getattr(KnowledgeBase, sortBy)
  85. data_query = data_query.order_by(sort_attr.asc() if sortOrder == 'asc' else sort_attr.desc())
  86. offset = (pageNum - 1) * pageSize
  87. data_query = data_query.offset(offset).limit(pageSize)
  88. fields = ['reportId', 'reportName', 'eventType', 'publishDate', 'publishingUnit', 'summary', 'subject', 'notificationType']
  89. entities = [getattr(KnowledgeBase, field) for field in fields if hasattr(KnowledgeBase, field)]
  90. data = data_query.with_entities(*entities).offset(offset).limit(pageSize).all()
  91. result_items = []
  92. for item in data:
  93. item_dict = {field: getattr(item, field) for field in fields}
  94. result_items.append(item_dict)
  95. total_count = data_query.count()
  96. result = {
  97. "code": 200,
  98. 'msg': '查询成功',
  99. 'pages': (total_count + pageSize - 1) // pageSize,
  100. 'total': total_count,
  101. "currentPage": pageNum,
  102. "pageSize": pageSize,
  103. 'data': result_items
  104. }
  105. return result
  106. @router.get('/detail')
  107. async def get_knowledge_detail(db: Session = Depends(get_db), reportID: Optional[str] = Query(None, description="报告ID")):
  108. report_id_to_use = reportID
  109. if not report_id_to_use:
  110. raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")
  111. kb_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == report_id_to_use).first()
  112. if not kb_entry:
  113. raise HTTPException(status_code=404, detail="No knowledge base found for the given report ID")
  114. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  115. files = [
  116. {"content": kf.file_name, "url": f'http://127.0.0.1:9988/api/file/download/{kf.file_name}'}
  117. for kf in kf_entries
  118. ]
  119. result = {
  120. "code": 200,
  121. "msg": "查询成功",
  122. "data": [{
  123. "report_id": kb_entry.reportId,
  124. "reportName": kb_entry.reportName,
  125. "subject": kb_entry.subject,
  126. "eventType": kb_entry.eventType,
  127. "publishDate": kb_entry.publishDate,
  128. "publishingUnit": kb_entry.publishingUnit,
  129. "summary": kb_entry.summary,
  130. "notificationType": kb_entry.notificationType,
  131. "file": files
  132. }]
  133. }
  134. return result