__init__.py 16 KB


  1. from fastapi import APIRouter, Request, Depends, HTTPException, Query,Response
  2. from sqlalchemy.exc import IntegrityError
  3. from common.security import valid_access_token
  4. from fastapi.responses import JSONResponse
  5. from database import get_db
  6. from sqlalchemy.orm import Session
  7. from models import *
  8. import json
  9. import random
  10. from sqlalchemy import create_engine, select, or_
  11. from typing import Optional
  12. import uuid
# Router on which every endpoint in this module is registered through the
# @router.get/post/put/delete decorators; the application must include it for
# these routes to be served (presumably under a URL prefix — TODO confirm).
router = APIRouter()
  14. @router.post('/create')
  15. async def create_knowledge(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  16. body = await request.json()
  17. required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary', 'fileNames']
  18. if not all(field in body for field in required_fields):
  19. missing_fields = ", ".join([field for field in required_fields if field not in body])
  20. return Response(content=f"Missing required fields: {missing_fields}", status_code=400)
  21. reportId = f'ZJBG{random.randint(1000000000, 9999999999)}'
  22. base_code = 'base' + f'{random.randint(1000000000, 9999999999)}'
  23. uuid1 = uuid.uuid1()
  24. knowledge = KnowledgeBase(
  25. reportId=reportId,
  26. **{
  27. field: body[field]
  28. for field in required_fields
  29. if field != 'fileNames'
  30. },
  31. base_code=base_code,
  32. del_flag = 0,
  33. updateTime=body['publishDate'],
  34. notificationType="总结报告" # 硬编码的值,如果有其他情况需要处理,请修改
  35. )
  36. db.add(knowledge)
  37. knowledge_files = [
  38. KnowledgeFile(
  39. file_identifier=f'file{random.randint(1000000000, 9999999999)}',
  40. file_path=f'/data/upload/mergefile/{fileName["url"]}',
  41. storage_file_name=fileName["url"],
  42. file_name=fileName["name"], # 使用 fileName["name"] 作为文件名
  43. is_deleted=0,
  44. updateTime=body['publishDate'],
  45. createTime=body['publishDate'],
  46. knowledge_base_code=base_code
  47. )
  48. for fileName in body['fileNames'] # body['fileNames'] 现在是一个包含对象的数组,每个对象都有 'name' 和 'url' 属性
  49. ]
  50. db.add_all(knowledge_files)
  51. try:
  52. db.commit()
  53. knowledge_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == reportId).first()
  54. files_entry = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).all()
  55. return {
  56. "code": 200,
  57. "msg": "总结报告创建成功",
  58. "status": "success",
  59. "data": [{
  60. "reportId":knowledge_entry.reportId,
  61. "reportName":knowledge_entry.reportName,
  62. "eventType":knowledge_entry.eventType,
  63. "publishingUnit":knowledge_entry.publishingUnit,
  64. "subject":knowledge_entry.subject,
  65. "publishDate":knowledge_entry.publishDate,
  66. "summary":knowledge_entry.summary,
  67. "files": [
  68. {
  69. "fileIdentifier": file.file_identifier,
  70. "fileName": file.file_name,
  71. "url": file.storage_file_name,
  72. "status":"success"
  73. }
  74. for file in files_entry
  75. ]
  76. }]
  77. }
  78. except IntegrityError as e:
  79. db.rollback()
  80. return Response(content=f"Database error: {str(e)}", status_code=409)
  81. except Exception as e:
  82. db.rollback()
  83. return Response(content=f"Internal server error: {str(e)}", status_code=500)
  84. @router.get('/select')
  85. async def select_knowledge(
  86. db: Session = Depends(get_db),
  87. sortBy: str = Query(..., description="排序字段"),
  88. sortOrder: str = Query(..., description="排序顺序"),
  89. pageNum: int = Query(1, gt=0, description="页码"),
  90. pageSize: int = Query(10, gt=0, le=100, description="每页大小"),
  91. eventType: str = Query(None, description="事件类型"),
  92. publishDateRange: str = Query(None, description="发布日期范围"),
  93. query: str = Query(None, description="查询关键字", user_id=Depends(valid_access_token))
  94. ):
  95. data_query = db.query(KnowledgeBase)
  96. data_query = data_query.filter(KnowledgeBase.del_flag != '2')
  97. if eventType:
  98. data_query = data_query.filter(KnowledgeBase.eventType == eventType)
  99. if publishDateRange:
  100. start_date, end_date = publishDateRange.split('-')
  101. data_query = data_query.filter(KnowledgeBase.publishDate.between(start_date, end_date))
  102. if query:
  103. search_fields = [getattr(KnowledgeBase, field) for field in ('reportName', 'publishingUnit', 'reportId') if hasattr(KnowledgeBase, field)]
  104. search_conditions = [field.like(f'%{query}%') for field in search_fields]
  105. data_query = data_query.filter(or_(*search_conditions))
  106. if hasattr(KnowledgeBase, sortBy):
  107. sort_attr = getattr(KnowledgeBase, sortBy)
  108. data_query = data_query.order_by(sort_attr.asc() if sortOrder == 'asc' else sort_attr.desc())
  109. total_count = data_query.count()
  110. offset = (pageNum - 1) * pageSize
  111. data_query = data_query.offset(offset).limit(pageSize)
  112. fields = ['reportId', 'reportName', 'eventType', 'publishDate', 'publishingUnit', 'summary', 'subject', 'notificationType', 'base_code']
  113. entities = [getattr(KnowledgeBase, field) for field in fields if hasattr(KnowledgeBase, field)]
  114. data = data_query.with_entities(*entities).offset(offset).limit(pageSize).all()
  115. result_items = []
  116. for item in data:
  117. item_dict = {field: getattr(item, field) for field in fields}
  118. print(item_dict)
  119. base_code = item_dict['base_code']
  120. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).filter(KnowledgeFile.is_deleted != '2').all()
  121. item_dict['files'] = [
  122. {
  123. "fileIdentifier": file.file_identifier,
  124. "fileName": file.file_name,
  125. "url": file.storage_file_name,
  126. "status":"success"
  127. }
  128. for file in kf_entries
  129. ]
  130. result_items.append(item_dict)
  131. result = {
  132. "code": 200,
  133. 'msg': '查询成功',
  134. 'pages': (total_count + pageSize - 1) // pageSize,
  135. 'total': total_count,
  136. "currentPage": pageNum,
  137. "pageSize": pageSize,
  138. 'data': result_items
  139. }
  140. return result
  141. @router.get('/detail')
  142. async def get_knowledge_detail(db: Session = Depends(get_db), reportID: Optional[str] = Query(None, description="报告ID"), user_id=Depends(valid_access_token)):
  143. if not reportID:
  144. raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")
  145. # kb_entry = db.query(KnowledgeBase)
  146. # kb_entry = kb_entry.filter(KnowledgeBase.del_flag != '2')
  147. # 先过滤出 del_flag 不等于 '2' 的记录,然后再根据 report_id 筛选
  148. kb_entry = (db.query(KnowledgeBase)
  149. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  150. .filter(KnowledgeBase.reportId == reportID)) # 确保 reportID 是在过滤后的查询中使用的
  151. kb_entry = kb_entry.first()
  152. if not kb_entry:
  153. raise HTTPException(status_code=404, detail="报告不存在")
  154. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code)\
  155. .filter(KnowledgeFile.is_deleted != '2').all()
  156. files = [
  157. {"content": kf.file_name, "url": kf.storage_file_name,"file_identifier":kf.file_identifier}
  158. for kf in kf_entries
  159. ]
  160. '''
  161. files = [
  162. {
  163. "fileIdentifier": file.file_identifier,
  164. "fileName": file.file_name,
  165. "url": file.storage_file_name,
  166. "status":"success"
  167. }
  168. for file in kf_entries
  169. ]
  170. '''
  171. result = {
  172. "code": 200,
  173. "msg": "查询成功",
  174. "data": [{
  175. "report_id": kb_entry.reportId,
  176. "reportName": kb_entry.reportName,
  177. "subject": kb_entry.subject,
  178. "eventType": kb_entry.eventType,
  179. "publishDate": kb_entry.publishDate,
  180. "publishingUnit": kb_entry.publishingUnit,
  181. "summary": kb_entry.summary,
  182. "notificationType": kb_entry.notificationType,
  183. "file": files
  184. }]
  185. }
  186. return result
  187. @router.post ('/delete')
  188. async def delete_knowledge(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  189. # 从请求的 JSON 数据中获取 reportID
  190. body = await request.json()
  191. report_id_to_use = body.get('reportID')
  192. if not report_id_to_use:
  193. return Response(content="Missing required parameter 'reportID'", status_code=400)
  194. kb_entry = (db.query(KnowledgeBase)
  195. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  196. .filter(KnowledgeBase.reportId == report_id_to_use)) # 确保 reportID 是在过滤后的查询中使用的
  197. kb_entry = kb_entry.first()
  198. if not kb_entry:
  199. raise HTTPException(status_code=404, detail="报告不存在")
  200. # 将找到的记录的 is_deleted 改为 2
  201. kb_entry.is_deleted = 2
  202. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  203. for kf_entry in kf_entries:
  204. kf_entry.is_deleted = 2
  205. if hasattr(kb_entry, 'del_flag'):
  206. kb_entry.del_flag = 2
  207. try:
  208. db.commit()
  209. return {
  210. "code": 200,
  211. "msg": "操作成功",
  212. "data": {
  213. "report_id": kb_entry.reportId
  214. }
  215. }
  216. except Exception as e:
  217. db.rollback()
  218. return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)
  219. @router.delete('/delete/list')
  220. async def delete_knowledge_list(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  221. # 从请求的 JSON 数据中获取 reportID
  222. body = await request.json()
  223. report_id_to_use = body.get('reportID')
  224. if not report_id_to_use:
  225. return Response(content="Missing required parameter 'reportID'", status_code=400)
  226. # kb_entry = (db.query(KnowledgeBase)
  227. # .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  228. # .filter(KnowledgeBase.reportId.in_report_(id_to_use))) # 确保 reportID 是在过滤后的查询中使用的
  229. #
  230. # kb_entrys = kb_entry.all()
  231. query = db.query(KnowledgeBase)
  232. query = query.filter(KnowledgeBase.del_flag != '2')
  233. query = query.filter(KnowledgeBase.reportId.in_(report_id_to_use))
  234. kb_entrys = query.all()
  235. if not kb_entrys:
  236. raise HTTPException(status_code=404, detail="报告不存在")
  237. for kb_entry in kb_entrys:
  238. # 将找到的记录的 is_deleted 改为 2
  239. kb_entry.is_deleted = 2
  240. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  241. for kf_entry in kf_entries:
  242. kf_entry.is_deleted = 2
  243. if hasattr(kb_entry, 'del_flag'):
  244. kb_entry.del_flag = 2
  245. try:
  246. db.commit()
  247. return {
  248. "code": 200,
  249. "msg": "操作成功",
  250. "data": None
  251. }
  252. except Exception as e:
  253. db.rollback()
  254. return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)
  255. def delete_file_fun(knowledge_base_code,db: Session):
  256. file_query = db.query(KnowledgeFile)
  257. file_query = file_query.filter(KnowledgeFile.is_deleted != '2')
  258. file_query = file_query.filter(KnowledgeFile.knowledge_base_code == knowledge_base_code)
  259. # file_query = file_query.filter(KnowledgeFile.foreign_key == foreign_key)
  260. files = file_query.all()
  261. for file in files:
  262. file.is_deleted='2'
  263. db.commit()
  264. @router.post('/update')
  265. async def update_knowledge(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  266. body = await request.json()
  267. report_id_to_use = body.get('reportId')
  268. if not report_id_to_use:
  269. return Response(content="Missing required parameter 'reportId'", status_code=400)
  270. kb_entry = (db.query(KnowledgeBase)
  271. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  272. .filter(KnowledgeBase.reportId == report_id_to_use)) # 确保 reportID 是在过滤后的查询中使用的
  273. kb_entry = kb_entry.first()
  274. if not kb_entry:
  275. raise HTTPException(status_code=404, detail="报告不存在")
  276. kb_entry.reportName = body.get('reportName', kb_entry.reportName)
  277. kb_entry.subject = body.get('subject', kb_entry.subject)
  278. kb_entry.eventType = body.get('eventType', kb_entry.eventType)
  279. kb_entry.publishingUnit = body.get('publishingUnit', kb_entry.publishingUnit)
  280. kb_entry.publishDate = body.get('publishDate', kb_entry.publishDate)
  281. kb_entry.summary = body.get('summary', kb_entry.summary)
  282. kb_entry.updateTime = datetime.strptime(body.get('publishDate', kb_entry.updateTime), '%Y-%m-%d %H:%M:%S')
  283. base_code = kb_entry.base_code
  284. if len(body.get('fileNames')) > 0:
  285. if kb_entry.base_code:
  286. delete_file_fun(kb_entry.base_code,db=db)
  287. knowledge_files = [
  288. KnowledgeFile(
  289. file_identifier=f'file{random.randint(1000000000, 9999999999)}',
  290. file_path=f'/data/upload/mergefile/{fileName["url"]}',
  291. storage_file_name=fileName["url"],
  292. file_name=fileName["name"], # 使用 fileName["name"] 作为文件名
  293. is_deleted=0,
  294. updateTime=body['publishDate'],
  295. knowledge_base_code=base_code
  296. )
  297. for fileName in body['fileNames'] # body['fileNames'] 现在是一个包含对象的数组,每个对象都有 'name' 和 'url' 属性
  298. ]
  299. db.add_all(knowledge_files)
  300. try:
  301. db.commit()
  302. return {
  303. "code": 200,
  304. "msg": "修改成功",
  305. "data": {
  306. "report_id": report_id_to_use,
  307. "updateTime": kb_entry.updateTime.isoformat()
  308. }
  309. }
  310. except IntegrityError as e:
  311. db.rollback()
  312. return Response(content=f"Database error: {str(e)}", status_code=409)
  313. except Exception as e:
  314. db.rollback()
  315. return Response(content=str(e), status_code=500)
  316. @router.put('/update/list')
  317. async def update_knowledge_list(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  318. body = await request.json()
  319. update_items = body.get('updateItems')
  320. if not update_items or not isinstance(update_items, list):
  321. return Response(content="Missing required parameter 'updateItems'", status_code=400)
  322. try:
  323. for item in update_items:
  324. report_id_to_use = item.get('reportId')
  325. if not report_id_to_use:
  326. continue # 如果没有提供 reportId,则跳过当前项
  327. # 根据 reportId 查找 KnowledgeBase 记录
  328. kb_entry = db.query(KnowledgeBase).filter(
  329. KnowledgeBase.reportId == report_id_to_use
  330. ).first()
  331. if not kb_entry:
  332. continue # 如果没有找到记录,则跳过当前项
  333. # 更新 KnowledgeBase 记录
  334. for field, value in item.items():
  335. if field != 'reportId' and hasattr(kb_entry, field):
  336. setattr(kb_entry, field, value)
  337. kb_entry.updateTime = datetime.fromisoformat(item.get('publishDate'))
  338. # 提交事务
  339. db.commit()
  340. return {
  341. "code": 200,
  342. "msg": "批量更新成功",
  343. "data": {
  344. "updated_count": len(update_items)
  345. }
  346. }
  347. except IntegrityError as e:
  348. db.rollback()
  349. return Response(content=f"Database error: {str(e)}", status_code=409)
  350. except Exception as e:
  351. db.rollback()
  352. return Response(content=str(e), status_code=500)