# __init__.py

import json
import logging
import random
from typing import Optional
from urllib.parse import quote

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from database import get_db
from models import *
  # API router for the summary-report endpoints of this module (/create, /select, /detail).
router = APIRouter()
  # ---- /create: create a summary report and its file records ----
def _random_code(prefix: str) -> str:
    """Return ``prefix`` followed by a random 10-digit number, e.g. ``'ZJBG1234567890'``.

    NOTE(review): codes are random, not guaranteed unique; a clash is only
    caught if the target column has a unique constraint — confirm in models.
    """
    return prefix + str(random.randint(1000000000, 9999999999))


@router.post('/create')
async def create_knowledge(request: Request, db: Session = Depends(get_db)):
    """Create a summary report (KnowledgeBase row) plus one KnowledgeFile row per file name.

    The request body must be a JSON object containing ``reportName``,
    ``subject``, ``eventType``, ``publishingUnit``, ``publishDate``,
    ``summary`` and ``fileNames`` (a list of file-name strings).
    ``notificationType`` is always stored as "总结报告" (summary report), and
    ``publishDate`` is also used as the update/create time of the new rows.

    Returns:
        ``{"code": 200, "msg": ..., "status": "success", "data": [reportId]}``.

    Raises:
        HTTPException(400): body is not UTF-8 JSON, is not a JSON object,
            lacks a required field, or ``fileNames`` is not a list of strings.
        HTTPException(409): the database rejected the insert (IntegrityError).
        HTTPException(500): any other unexpected failure.
    """
    logger = logging.getLogger(__name__)
    try:
        raw = await request.body()
        body = json.loads(raw.decode(encoding='utf-8'))
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="Invalid JSON data")

        # 'notificationType' is intentionally not required: it is fixed below.
        required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit',
                           'publishDate', 'summary', 'fileNames']
        missing_fields = [field for field in required_fields if field not in body]
        if missing_fields:
            logger.debug('missing_fields %s', missing_fields)
            # 400 rather than 401: the payload is incomplete, not unauthenticated.
            raise HTTPException(status_code=400,
                                detail=f"Missing required fields: {', '.join(missing_fields)}")

        file_names = body["fileNames"]
        # A bare string would otherwise be iterated character by character.
        if not isinstance(file_names, list) or not all(isinstance(name, str) for name in file_names):
            raise HTTPException(status_code=400, detail="fileNames must be a list of strings")

        report_id = _random_code('ZJBG')
        base_code = _random_code('base')
        update_time = body["publishDate"]

        db.add(KnowledgeBase(
            reportId=report_id,
            reportName=body["reportName"],
            subject=body["subject"],
            eventType=body["eventType"],
            publishingUnit=body["publishingUnit"],
            publishDate=body["publishDate"],
            summary=body["summary"],
            notificationType="总结报告",
            base_code=base_code,
            updateTime=update_time,
        ))

        file_path = '/data/upload/mergefile/'
        for file_name in file_names:
            db.add(KnowledgeFile(
                file_identifier=_random_code('file'),
                # NOTE(review): file_name is client-supplied and concatenated
                # unchecked; values containing "../" are stored as-is — confirm
                # the download side sanitizes paths.
                file_path=file_path + file_name,
                file_name=file_name,
                is_deleted=0,
                updateTime=update_time,
                createTime=update_time,
                knowledge_base_code=base_code,
            ))
        db.commit()
        return {
            "code": 200,
            "msg": "总结报告创建成功",
            "status": "success",
            "data": [report_id],
        }
    except HTTPException:
        # Let validation errors reach the client with their own status code
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail="Invalid JSON data") from e
    except IntegrityError as e:
        db.rollback()
        raise HTTPException(status_code=409, detail=f"Database error: {str(e)}") from e
    except Exception as e:
        db.rollback()
        logger.exception('create_knowledge failed')
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
  # ---- /select: paginated listing of summary reports ----
def _positive_int(value, name: str) -> int:
    """Coerce a paging value (query int or JSON int/str) to an int >= 1, else raise HTTP 400."""
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(status_code=400,
                            detail=f"{name} parameter must be a positive integer") from None
    if number < 1:
        raise HTTPException(status_code=400,
                            detail=f"{name} parameter must be a positive integer")
    return number


@router.post('/select')
@router.get('/select')
async def select_knowledge(request: Request, db: Session = Depends(get_db), pageNum: Optional[int] = None,
                           pageSize: Optional[int] = None):
    """Return one page of KnowledgeBase rows together with paging metadata.

    ``pageNum`` / ``pageSize`` may come from the query string or from a JSON
    object body with the same keys. Each one is resolved independently, and
    the query-string value wins when both are given. ``pageSize`` is capped at 100.

    Returns:
        ``{"code", "msg", "pages", "total", "currentPage", "pageSize", "data"}``.

    Raises:
        HTTPException(400): a paging parameter is missing, not an integer, or < 1.
    """
    page_from_json = None
    size_from_json = None
    try:
        json_data = await request.json()
    except ValueError:
        # Empty or non-JSON body (typical for GET): rely on the query string.
        json_data = None
    if isinstance(json_data, dict):
        page_from_json = json_data.get('pageNum')
        size_from_json = json_data.get('pageSize')

    page_raw = pageNum if pageNum is not None else page_from_json
    size_raw = pageSize if pageSize is not None else size_from_json
    if page_raw is None:
        raise HTTPException(status_code=400, detail="Page parameter is required")
    if size_raw is None:
        raise HTTPException(status_code=400, detail="Size parameter is required")

    page_to_use = _positive_int(page_raw, 'Page')
    size_to_use = min(_positive_int(size_raw, 'Size'), 100)

    offset = (page_to_use - 1) * size_to_use
    # NOTE(review): there is no ORDER BY, so row order across pages is up to the
    # database; consider ordering by a stable key for consistent pagination.
    data = db.query(KnowledgeBase).offset(offset).limit(size_to_use).all()
    total_count = db.query(func.count(KnowledgeBase.reportId)).scalar()
    total_pages = (total_count + size_to_use - 1) // size_to_use  # ceiling division

    return {
        "code": 200,
        'msg': '查询成功',
        'pages': total_pages,
        'total': total_count,
        "currentPage": page_to_use,
        "pageSize": size_to_use,
        'data': data,
    }
  # ---- /detail: one summary report with its file download links ----
@router.post('/detail')
@router.get('/detail')
async def get_knowledge_detail(request: Request, db: Session = Depends(get_db), reportID: Optional[str] = None):
    """Return a single summary report and download links for its files.

    ``reportID`` may be supplied as a query parameter or as a key of a JSON
    object body. A non-empty query value takes precedence.

    Returns:
        ``{"code": 200, "msg": ..., "data": [report_dict]}``, where
        ``report_dict["file"]`` lists ``{"content": file_name, "url": ...}``.

    Raises:
        HTTPException(400): no report ID was supplied.
        HTTPException(404): no KnowledgeBase row has that reportId.
    """
    logger = logging.getLogger(__name__)
    report_id_body = None
    try:
        data = await request.json()
    except ValueError:
        # Empty or non-JSON body (typical for GET): use the query parameter only.
        logger.debug('detail request has no JSON body')
        data = None
    if isinstance(data, dict):
        report_id_body = data.get('reportID')

    report_id_to_use = reportID or report_id_body
    logger.debug('reportID query=%r body=%r', reportID, report_id_body)
    if not report_id_to_use:
        raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")

    kb_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == report_id_to_use).first()
    if not kb_entry:
        raise HTTPException(status_code=404, detail="No knowledge base found for the given report ID")

    # NOTE(review): this does not filter on KnowledgeFile.is_deleted — confirm
    # whether soft-deleted files should be hidden here.
    kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()

    # NOTE(review): host and port are hard-coded; links only work on the server itself.
    download_base = 'http://127.0.0.1:9988/api/file/download/'
    files = [
        # Percent-encode the name so spaces, '#', '?', '%' and non-ASCII
        # characters still produce a valid URL path segment.
        {"content": kf.file_name, "url": download_base + quote(str(kf.file_name))}
        for kf in kf_entries
    ]

    return {
        "code": 200,
        "msg": "查询成功",
        "data": [{
            "report_id": kb_entry.reportId,
            "reportName": kb_entry.reportName,
            "subject": kb_entry.subject,
            "eventType": kb_entry.eventType,
            "publishDate": kb_entry.publishDate,
            "publishingUnit": kb_entry.publishingUnit,
            "summary": kb_entry.summary,
            "notificationType": kb_entry.notificationType,
            "file": files,
        }],
    }