__init__.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. from fastapi import APIRouter, Request, Depends, HTTPException
  2. from sqlalchemy.exc import IntegrityError
  3. from fastapi.responses import JSONResponse
  4. from database import get_db
  5. from sqlalchemy.orm import Session
  6. from models import *
  7. import json
  8. import random
  9. from sqlalchemy import create_engine, select
  10. from typing import Optional
# Router on which the knowledge-base (summary report) endpoints in this module
# are registered via the @router.get / @router.post decorators.
router = APIRouter()
  12. # @router.post('/create')
  13. # async def create_knowledge(request:Request,db:Session = Depends(get_db)):
  14. # data = await request.body()
  15. # body = data.decode(encoding='utf-8')
  16. # if len(body) > 0:
  17. # body = json.loads(body)
  18. # print(body)
  19. # random_10_digit_number = random.randint(1000000000, 9999999999)
  20. # # file_identifier = 'f'
  21. # reportId = 'ZJBG'+str(random_10_digit_number)
  22. # # file_identifier =
  23. # reportName = body["reportName"]
  24. # subject = body["subject"]
  25. # eventType = body["eventType"]
  26. # publishingUnit = body["publishingUnit"]
  27. # publishDate = body["publishDate"]
  28. # summary = body["summary"]
  29. #
  30. # notificationType = body["notificationType"]
  31. #
  32. # base_code = 'base'+str(random.randint(1000000000, 9999999999))
  33. #
  34. # fileNames = body["fileName"]
  35. # filePath = '/data/upload/mergefile/'
  36. #
  37. #
  38. #
  39. # konwledge = KnowledgeBase(
  40. # reportId=reportId,
  41. #
  42. # reportName=reportName,
  43. # subject=subject,
  44. # eventType=eventType,
  45. # publishingUnit=publishingUnit,
  46. # publishDate=publishDate,
  47. # summary = summary,
  48. # notificationType = notificationType,
  49. #
  50. # base_code = base_code
  51. # )
  52. # db.add(konwledge)
  53. #
  54. # for fileName in fileNames:
  55. # file_identifier='file'+str(random.randint(1000000000, 9999999999))
  56. # knowledge_file = KnowledgeFile(
  57. # file_identifier=file_identifier,
  58. # file_path=filePath,
  59. # file_name = fileName,
  60. # is_deleted = 0,
  61. # knowledge_base_code = base_code
  62. # )
  63. # db.add(knowledge_file)
  64. #
  65. # db.commit()
  66. # return {
  67. # "code":0,
  68. # "data":{
  69. # "reportId": reportId,
  70. # "status": "success",
  71. # "message": "总结报告创建成功"
  72. # }
  73. # }
  74. @router.post('/create')
  75. async def create_knowledge(request: Request, db: Session = Depends(get_db)):
  76. try:
  77. data = await request.body()
  78. body = json.loads(data.decode(encoding='utf-8'))
  79. # 验证必需的字段
  80. required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary',
  81. 'notificationType', 'fileNames']
  82. missing_fields = [field for field in required_fields if field not in body]
  83. print('missing_fields',missing_fields)
  84. if missing_fields:
  85. raise HTTPException(status_code=401, detail=f"Missing required fields: {', '.join(missing_fields)}")
  86. # 生成随机的报告ID和基础知识代码
  87. random_10_digit_number = random.randint(1000000000, 9999999999)
  88. reportId = 'ZJBG' + str(random_10_digit_number)
  89. base_code = 'base' + str(random.randint(1000000000, 9999999999))
  90. # 从请求体中提取其他数据
  91. reportName = body["reportName"]
  92. subject = body["subject"]
  93. eventType = body["eventType"]
  94. publishingUnit = body["publishingUnit"]
  95. publishDate = body["publishDate"]
  96. summary = body["summary"]
  97. notificationType = body["notificationType"]
  98. fileNames = body["fileNames"] # 注意:这里假设它是列表
  99. # 创建 KnowledgeBase 实例
  100. konwledge = KnowledgeBase(
  101. reportId=reportId,
  102. reportName=reportName,
  103. subject=subject,
  104. eventType=eventType,
  105. publishingUnit=publishingUnit,
  106. publishDate=publishDate,
  107. summary=summary,
  108. notificationType=notificationType,
  109. base_code=base_code
  110. )
  111. db.add(konwledge)
  112. # 创建 KnowledgeFile 实例
  113. filePath = '/data/upload/mergefile/'
  114. for fileName in fileNames:
  115. file_identifier = 'file' + str(random.randint(1000000000, 9999999999))
  116. knowledge_file = KnowledgeFile(
  117. file_identifier=file_identifier,
  118. file_path=filePath + fileName, # 如果fileName是完整的路径,则可能不需要再次添加filePath
  119. file_name=fileName,
  120. is_deleted=0,
  121. knowledge_base_code=base_code
  122. )
  123. db.add(knowledge_file)
  124. db.commit()
  125. return {
  126. "code": 200,
  127. "msg": "总结报告创建成功",
  128. "status": "success",
  129. "data":
  130. [reportId]
  131. }
  132. # return {
  133. # "code": 200,
  134. # "msg": "操作成功",
  135. # "data": {
  136. # "reportId": reportId,
  137. # "status": "success",
  138. # "message": "总结报告创建成功"
  139. # }
  140. # }
  141. except json.JSONDecodeError:
  142. raise HTTPException(status_code=400, detail="Invalid JSON data")
  143. except IntegrityError as e:
  144. db.rollback()
  145. raise HTTPException(status_code=409, detail=f"Database error: {str(e)}")
  146. except Exception as e:
  147. db.rollback()
  148. print(e)
  149. raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
  150. @router.post('/select')
  151. @router.get('/select')
  152. async def select_knowledge(request: Request, db: Session = Depends(get_db), pageNum: Optional[int] = None,
  153. pageSize: Optional[int] = None):
  154. # page_from_json = None
  155. try:
  156. # # 尝试从请求体中解析 JSON 数据
  157. json_data = await request.json()
  158. # 初始化分页参数
  159. page_from_json = json_data.get('pageNum') if json_data else None
  160. size_from_json = json_data.get('pageSize') if json_data else None
  161. except:
  162. page_from_json = None
  163. # 如果查询参数和请求体都存在,优先选择查询参数
  164. if pageNum is not None and pageSize is not None:
  165. # 使用查询参数
  166. page_to_use = pageNum
  167. size_to_use = pageSize
  168. elif page_from_json is not None and size_from_json is not None:
  169. # 使用请求体中的参数
  170. page_to_use = page_from_json
  171. size_to_use = size_from_json
  172. else:
  173. # 如果只有一个存在,使用存在的那个,否则抛出异常
  174. if pageNum is not None:
  175. page_to_use = pageNum
  176. elif page_from_json is not None:
  177. page_to_use = page_from_json
  178. else:
  179. raise HTTPException(status_code=400, detail="Page parameter is required")
  180. if pageSize is not None:
  181. size_to_use = pageSize
  182. elif size_from_json is not None:
  183. size_to_use = size_from_json
  184. else:
  185. raise HTTPException(status_code=400, detail="Size parameter is required")
  186. # 验证分页参数
  187. if size_to_use > 100:
  188. size_to_use = 100
  189. # 计算 offset
  190. offset = (page_to_use - 1) * size_to_use
  191. # 使用 ORM 查询并应用分页
  192. data = db.query(KnowledgeBase).offset(offset).limit(size_to_use).all()
  193. # 计算总条数
  194. total_count = db.query(func.count(KnowledgeBase.reportId)).scalar()
  195. # 计算总页数
  196. total_pages = (total_count // size_to_use) + (1 if total_count % size_to_use else 0)
  197. # 返回查询结果
  198. # result = {
  199. # "code": 200,
  200. # 'msg': '查询成功',
  201. # 'data': {
  202. # 'pages': total_pages,
  203. # 'total': total_count,
  204. # "currentPage": page_to_use,
  205. # "pageSize": size_to_use,
  206. # "rows": data
  207. # }
  208. # }
  209. result = {
  210. "code": 200,
  211. 'msg': '查询成功',
  212. 'pages': total_pages,
  213. 'total': total_count,
  214. "currentPage": page_to_use,
  215. "pageSize": size_to_use,
  216. 'data': data
  217. }
  218. return result
  219. @router.post('/detail')
  220. @router.get('/detail')
  221. async def get_knowledge_detail(request: Request, db: Session = Depends(get_db), reportID: Optional[str] = None):
  222. # 尝试从请求体中解析 JSON 数据
  223. report_id_body=None
  224. try:
  225. data = await request.json()
  226. report_id_body = data.get('reportID') if data else None
  227. except:
  228. print('报错')
  229. pass
  230. # 确定要使用的 report_id
  231. report_id_to_use = reportID or report_id_body
  232. print(reportID)
  233. print(report_id_body)
  234. # 如果没有提供 report_id,则抛出异常
  235. if not report_id_to_use:
  236. raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")
  237. # 查询 KnowledgeBase
  238. kb_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == report_id_to_use).first()
  239. if not kb_entry:
  240. raise HTTPException(status_code=404, detail="No knowledge base found for the given report ID")
  241. # 查询相关的 KnowledgeFile
  242. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  243. # 准备返回的数据
  244. result = {
  245. "code": 200,
  246. "msg": "查询成功",
  247. "data": [{
  248. "report_id": kb_entry.reportId,
  249. "reportName": kb_entry.reportName,
  250. "subject": kb_entry.subject,
  251. "eventType": kb_entry.eventType,
  252. "publishDate": kb_entry.publishDate,
  253. "publishingUnit": kb_entry.publishingUnit,
  254. "summary": kb_entry.summary,
  255. "notificationType": kb_entry.notificationType,
  256. "file": [
  257. {"content": kf.file_name, "url": f'http://127.0.0.1:9988/api/file/download/{kf.file_name}'}
  258. for kf in kf_entries
  259. ]
  260. }]
  261. }
  262. return result