__init__.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. from fastapi import APIRouter, Request, Depends, HTTPException
  2. from sqlalchemy.exc import IntegrityError
  3. from fastapi.responses import JSONResponse
  4. from database import get_db
  5. from sqlalchemy.orm import Session
  6. from models import *
  7. import json
  8. import random
  9. from sqlalchemy import create_engine, select
  10. router = APIRouter()
  11. # @router.post('/create')
  12. # async def create_knowledge(request:Request,db:Session = Depends(get_db)):
  13. # data = await request.body()
  14. # body = data.decode(encoding='utf-8')
  15. # if len(body) > 0:
  16. # body = json.loads(body)
  17. # print(body)
  18. # random_10_digit_number = random.randint(1000000000, 9999999999)
  19. # # file_identifier = 'f'
  20. # reportId = 'ZJBG'+str(random_10_digit_number)
  21. # # file_identifier =
  22. # reportName = body["reportName"]
  23. # subject = body["subject"]
  24. # eventType = body["eventType"]
  25. # publishingUnit = body["publishingUnit"]
  26. # publishDate = body["publishDate"]
  27. # summary = body["summary"]
  28. #
  29. # notificationType = body["notificationType"]
  30. #
  31. # base_code = 'base'+str(random.randint(1000000000, 9999999999))
  32. #
  33. # fileNames = body["fileName"]
  34. # filePath = '/data/upload/mergefile/'
  35. #
  36. #
  37. #
  38. # konwledge = KnowledgeBase(
  39. # reportId=reportId,
  40. #
  41. # reportName=reportName,
  42. # subject=subject,
  43. # eventType=eventType,
  44. # publishingUnit=publishingUnit,
  45. # publishDate=publishDate,
  46. # summary = summary,
  47. # notificationType = notificationType,
  48. #
  49. # base_code = base_code
  50. # )
  51. # db.add(konwledge)
  52. #
  53. # for fileName in fileNames:
  54. # file_identifier='file'+str(random.randint(1000000000, 9999999999))
  55. # knowledge_file = KnowledgeFile(
  56. # file_identifier=file_identifier,
  57. # file_path=filePath,
  58. # file_name = fileName,
  59. # is_deleted = 0,
  60. # knowledge_base_code = base_code
  61. # )
  62. # db.add(knowledge_file)
  63. #
  64. # db.commit()
  65. # return {
  66. # "code":0,
  67. # "data":{
  68. # "reportId": reportId,
  69. # "status": "success",
  70. # "message": "总结报告创建成功"
  71. # }
  72. # }
  73. @router.post('/create')
  74. async def create_knowledge(request: Request, db: Session = Depends(get_db)):
  75. try:
  76. data = await request.body()
  77. body = json.loads(data.decode(encoding='utf-8'))
  78. # 验证必需的字段
  79. required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary',
  80. 'notificationType', 'fileNames']
  81. missing_fields = [field for field in required_fields if field not in body]
  82. print('missing_fields',missing_fields)
  83. if missing_fields:
  84. raise HTTPException(status_code=401, detail=f"Missing required fields: {', '.join(missing_fields)}")
  85. # 生成随机的报告ID和基础知识代码
  86. random_10_digit_number = random.randint(1000000000, 9999999999)
  87. reportId = 'ZJBG' + str(random_10_digit_number)
  88. base_code = 'base' + str(random.randint(1000000000, 9999999999))
  89. # 从请求体中提取其他数据
  90. reportName = body["reportName"]
  91. subject = body["subject"]
  92. eventType = body["eventType"]
  93. publishingUnit = body["publishingUnit"]
  94. publishDate = body["publishDate"]
  95. summary = body["summary"]
  96. notificationType = body["notificationType"]
  97. fileNames = body["fileNames"] # 注意:这里假设它是列表
  98. # 创建 KnowledgeBase 实例
  99. konwledge = KnowledgeBase(
  100. reportId=reportId,
  101. reportName=reportName,
  102. subject=subject,
  103. eventType=eventType,
  104. publishingUnit=publishingUnit,
  105. publishDate=publishDate,
  106. summary=summary,
  107. notificationType=notificationType,
  108. base_code=base_code
  109. )
  110. db.add(konwledge)
  111. # 创建 KnowledgeFile 实例
  112. filePath = '/data/upload/mergefile/'
  113. for fileName in fileNames:
  114. file_identifier = 'file' + str(random.randint(1000000000, 9999999999))
  115. knowledge_file = KnowledgeFile(
  116. file_identifier=file_identifier,
  117. file_path=filePath + fileName, # 如果fileName是完整的路径,则可能不需要再次添加filePath
  118. file_name=fileName,
  119. is_deleted=0,
  120. knowledge_base_code=base_code
  121. )
  122. db.add(knowledge_file)
  123. db.commit()
  124. return {
  125. "code": 0,
  126. "data": {
  127. "reportId": reportId,
  128. "status": "success",
  129. "message": "总结报告创建成功"
  130. }
  131. }
  132. except json.JSONDecodeError:
  133. raise HTTPException(status_code=400, detail="Invalid JSON data")
  134. except IntegrityError as e:
  135. db.rollback()
  136. raise HTTPException(status_code=409, detail=f"Database error: {str(e)}")
  137. except Exception as e:
  138. db.rollback()
  139. print(e)
  140. raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
  141. @router.post('/select')
  142. @router.get('/select')
  143. async def select_knowledge(request: Request, db: Session = Depends(get_db)):
  144. # 尝试从请求体中解析 JSON 数据
  145. data = await request.json() # 注意:这里直接使用 request.json() 而不是 request.body()
  146. try:
  147. page = int(data.get('page', 1)) # 如果没有提供 page,则默认为 1
  148. size = int(data.get('size', 10)) # 如果没有提供 size,则默认为 10
  149. if size >100:
  150. size = 100
  151. except ValueError as e:
  152. # 如果转换失败,则抛出 HTTPException
  153. raise HTTPException(status_code=400, detail=f"Invalid pagination parameters: {e}")
  154. # 计算 offset
  155. offset = (page - 1) * size
  156. # 使用 ORM 查询并应用分页
  157. data = db.query(KnowledgeBase).offset(offset).limit(size).all()
  158. # 打印结果(可选,用于调试)
  159. # for i in data:
  160. # print(i)
  161. print(f"Returned {len(data)} results from page {page} with size {size}")
  162. # 计算总条数(注意:这可能会很慢,特别是对于大型数据集)
  163. total_count = db.query(func.count(KnowledgeBase.reportId)).scalar()
  164. # 计算总页数
  165. total_pages = (total_count // size) + (1 if total_count % size else 0)
  166. # 返回查询结果
  167. result = {
  168. "code": 0,
  169. 'msg': 'success',
  170. 'data': {
  171. 'pages': total_pages,
  172. 'total': total_count,
  173. "currentPage":page,
  174. "pageSize":size,
  175. "list": data
  176. }
  177. }
  178. return result
  179. @router.post('/detail')
  180. @router.get('/detail')
  181. async def get_knowledge_detail(request: Request, db: Session = Depends(get_db)):
  182. # 尝试从请求体中解析 JSON 数据
  183. data = await request.json()
  184. report_id = data.get('reportID')
  185. if not report_id:
  186. raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")
  187. # 查询 KnowledgeBase
  188. kb_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == report_id).first()
  189. if not kb_entry:
  190. raise HTTPException(status_code=404, detail="No knowledge base found for the given report ID")
  191. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  192. # 准备返回的数据
  193. result = {
  194. "code": 0,
  195. "msg": "success",
  196. "data": {
  197. "report_id": kb_entry.reportId,
  198. "reportName": kb_entry.reportName,
  199. "subject": kb_entry.subject,
  200. "eventType": kb_entry.eventType,
  201. "publishDate": kb_entry.publishDate,
  202. "publishingUnit": kb_entry.publishingUnit,
  203. "summary": kb_entry.summary,
  204. "notificationType": kb_entry.notificationType,
  205. # "knowledge_base_code": kb_entry.base_code,
  206. "file": [
  207. {
  208. "content": kf.file_name,
  209. "url": 'http://127.0.0.1:9988/api/file/download/'+kf.file_name
  210. } # 根据需要调整返回的字段
  211. for kf in kf_entries
  212. ]
  213. }
  214. }
  215. return result