__init__.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. from fastapi import APIRouter, Request, Depends, HTTPException, Query,Response
  2. from sqlalchemy.exc import IntegrityError
  3. from fastapi.responses import JSONResponse
  4. from database import get_db
  5. from sqlalchemy.orm import Session
  6. from models import *
  7. import json
  8. import random
  9. from sqlalchemy import create_engine, select, or_
  10. from typing import Optional
  11. import uuid
# Router that collects the summary-report endpoints defined in this module.
router = APIRouter()
  13. @router.post('/create')
  14. async def create_knowledge(request: Request, db: Session = Depends(get_db)):
  15. body = await request.json()
  16. required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary', 'fileNames']
  17. if not all(field in body for field in required_fields):
  18. missing_fields = ", ".join([field for field in required_fields if field not in body])
  19. return Response(content=f"Missing required fields: {missing_fields}", status_code=400)
  20. reportId = f'ZJBG{random.randint(1000000000, 9999999999)}'
  21. base_code = 'base' + f'{random.randint(1000000000, 9999999999)}'
  22. uuid1 = uuid.uuid1()
  23. knowledge = KnowledgeBase(
  24. reportId=reportId,
  25. **{
  26. field: body[field]
  27. for field in required_fields
  28. if field != 'fileNames'
  29. },
  30. base_code=base_code,
  31. del_flag = 0,
  32. updateTime=body['publishDate'],
  33. notificationType="总结报告" # 硬编码的值,如果有其他情况需要处理,请修改
  34. )
  35. db.add(knowledge)
  36. knowledge_files = [
  37. KnowledgeFile(
  38. file_identifier=f'file{random.randint(1000000000, 9999999999)}',
  39. file_path=f'/data/upload/mergefile/{fileName["url"]}',
  40. storage_file_name=fileName["url"],
  41. file_name=fileName["name"], # 使用 fileName["name"] 作为文件名
  42. is_deleted=0,
  43. updateTime=body['publishDate'],
  44. createTime=body['publishDate'],
  45. knowledge_base_code=base_code
  46. )
  47. for fileName in body['fileNames'] # body['fileNames'] 现在是一个包含对象的数组,每个对象都有 'name' 和 'url' 属性
  48. ]
  49. db.add_all(knowledge_files)
  50. try:
  51. db.commit()
  52. knowledge_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == reportId).first()
  53. files_entry = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).all()
  54. return {
  55. "code": 200,
  56. "msg": "总结报告创建成功",
  57. "status": "success",
  58. "data": [{
  59. "reportId":knowledge_entry.reportId,
  60. "reportName":knowledge_entry.reportName,
  61. "eventType":knowledge_entry.eventType,
  62. "publishingUnit":knowledge_entry.publishingUnit,
  63. "subject":knowledge_entry.subject,
  64. "publishDate":knowledge_entry.publishDate,
  65. "summary":knowledge_entry.summary,
  66. "files": [
  67. {
  68. "fileIdentifier": file.file_identifier,
  69. "fileName": file.file_name,
  70. "url": file.storage_file_name,
  71. "status":"success"
  72. }
  73. for file in files_entry
  74. ]
  75. }]
  76. }
  77. except IntegrityError as e:
  78. db.rollback()
  79. return Response(content=f"Database error: {str(e)}", status_code=409)
  80. except Exception as e:
  81. db.rollback()
  82. return Response(content=f"Internal server error: {str(e)}", status_code=500)
  83. @router.get('/select')
  84. async def select_knowledge(
  85. db: Session = Depends(get_db),
  86. sortBy: str = Query(..., description="排序字段"),
  87. sortOrder: str = Query(..., description="排序顺序"),
  88. pageNum: int = Query(1, gt=0, description="页码"),
  89. pageSize: int = Query(10, gt=0, le=100, description="每页大小"),
  90. eventType: str = Query(None, description="事件类型"),
  91. publishDateRange: str = Query(None, description="发布日期范围"),
  92. query: str = Query(None, description="查询关键字")
  93. ):
  94. data_query = db.query(KnowledgeBase)
  95. data_query = data_query.filter(KnowledgeBase.del_flag != '2')
  96. if eventType:
  97. data_query = data_query.filter(KnowledgeBase.eventType == eventType)
  98. if publishDateRange:
  99. start_date, end_date = publishDateRange.split('-')
  100. data_query = data_query.filter(KnowledgeBase.publishDate.between(start_date, end_date))
  101. if query:
  102. search_fields = [getattr(KnowledgeBase, field) for field in ('reportName', 'publishingUnit', 'reportId') if hasattr(KnowledgeBase, field)]
  103. search_conditions = [field.like(f'%{query}%') for field in search_fields]
  104. data_query = data_query.filter(or_(*search_conditions))
  105. if hasattr(KnowledgeBase, sortBy):
  106. sort_attr = getattr(KnowledgeBase, sortBy)
  107. data_query = data_query.order_by(sort_attr.asc() if sortOrder == 'asc' else sort_attr.desc())
  108. offset = (pageNum - 1) * pageSize
  109. data_query = data_query.offset(offset).limit(pageSize)
  110. fields = ['reportId', 'reportName', 'eventType', 'publishDate', 'publishingUnit', 'summary', 'subject', 'notificationType', 'base_code']
  111. entities = [getattr(KnowledgeBase, field) for field in fields if hasattr(KnowledgeBase, field)]
  112. data = data_query.with_entities(*entities).offset(offset).limit(pageSize).all()
  113. result_items = []
  114. for item in data:
  115. item_dict = {field: getattr(item, field) for field in fields}
  116. print(item_dict)
  117. base_code = item_dict['base_code']
  118. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).filter(KnowledgeFile.is_deleted != '2').all()
  119. item_dict['files'] = [
  120. {
  121. "fileIdentifier": file.file_identifier,
  122. "fileName": file.file_name,
  123. "url": file.storage_file_name,
  124. "status":"success"
  125. }
  126. for file in kf_entries
  127. ]
  128. result_items.append(item_dict)
  129. total_count = data_query.count()
  130. result = {
  131. "code": 200,
  132. 'msg': '查询成功',
  133. 'pages': (total_count + pageSize - 1) // pageSize,
  134. 'total': total_count,
  135. "currentPage": pageNum,
  136. "pageSize": pageSize,
  137. 'data': result_items
  138. }
  139. return result
  140. @router.get('/detail')
  141. async def get_knowledge_detail(db: Session = Depends(get_db), reportID: Optional[str] = Query(None, description="报告ID")):
  142. if not reportID:
  143. raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")
  144. # kb_entry = db.query(KnowledgeBase)
  145. # kb_entry = kb_entry.filter(KnowledgeBase.del_flag != '2')
  146. # 先过滤出 del_flag 不等于 '2' 的记录,然后再根据 report_id 筛选
  147. kb_entry = (db.query(KnowledgeBase)
  148. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  149. .filter(KnowledgeBase.reportId == reportID)) # 确保 reportID 是在过滤后的查询中使用的
  150. kb_entry = kb_entry.first()
  151. if not kb_entry:
  152. raise HTTPException(status_code=404, detail="报告不存在")
  153. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code)\
  154. .filter(KnowledgeFile.is_deleted != '2').all()
  155. files = [
  156. {"content": kf.file_name, "url": kf.storage_file_name,"file_identifier":kf.file_identifier}
  157. for kf in kf_entries
  158. ]
  159. '''
  160. files = [
  161. {
  162. "fileIdentifier": file.file_identifier,
  163. "fileName": file.file_name,
  164. "url": file.storage_file_name,
  165. "status":"success"
  166. }
  167. for file in kf_entries
  168. ]
  169. '''
  170. result = {
  171. "code": 200,
  172. "msg": "查询成功",
  173. "data": [{
  174. "report_id": kb_entry.reportId,
  175. "reportName": kb_entry.reportName,
  176. "subject": kb_entry.subject,
  177. "eventType": kb_entry.eventType,
  178. "publishDate": kb_entry.publishDate,
  179. "publishingUnit": kb_entry.publishingUnit,
  180. "summary": kb_entry.summary,
  181. "notificationType": kb_entry.notificationType,
  182. "file": files
  183. }]
  184. }
  185. return result
  186. @router.post ('/delete')
  187. async def delete_knowledge(request: Request, db: Session = Depends(get_db)):
  188. # 从请求的 JSON 数据中获取 reportID
  189. body = await request.json()
  190. report_id_to_use = body.get('reportID')
  191. if not report_id_to_use:
  192. return Response(content="Missing required parameter 'reportID'", status_code=400)
  193. kb_entry = (db.query(KnowledgeBase)
  194. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  195. .filter(KnowledgeBase.reportId == report_id_to_use)) # 确保 reportID 是在过滤后的查询中使用的
  196. kb_entry = kb_entry.first()
  197. if not kb_entry:
  198. raise HTTPException(status_code=404, detail="报告不存在")
  199. # 将找到的记录的 is_deleted 改为 2
  200. kb_entry.is_deleted = 2
  201. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  202. for kf_entry in kf_entries:
  203. kf_entry.is_deleted = 2
  204. if hasattr(kb_entry, 'del_flag'):
  205. kb_entry.del_flag = 2
  206. try:
  207. db.commit()
  208. return {
  209. "code": 200,
  210. "msg": "操作成功",
  211. "data": {
  212. "report_id": kb_entry.reportId
  213. }
  214. }
  215. except Exception as e:
  216. db.rollback()
  217. return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)
  218. @router.delete('/delete/list')
  219. async def delete_knowledge_list(request: Request, db: Session = Depends(get_db)):
  220. # 从请求的 JSON 数据中获取 reportID
  221. body = await request.json()
  222. report_id_to_use = body.get('reportID')
  223. if not report_id_to_use:
  224. return Response(content="Missing required parameter 'reportID'", status_code=400)
  225. # kb_entry = (db.query(KnowledgeBase)
  226. # .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  227. # .filter(KnowledgeBase.reportId.in_report_(id_to_use))) # 确保 reportID 是在过滤后的查询中使用的
  228. #
  229. # kb_entrys = kb_entry.all()
  230. query = db.query(KnowledgeBase)
  231. query = query.filter(KnowledgeBase.del_flag != '2')
  232. query = query.filter(KnowledgeBase.reportId.in_(report_id_to_use))
  233. kb_entrys = query.all()
  234. if not kb_entrys:
  235. raise HTTPException(status_code=404, detail="报告不存在")
  236. for kb_entry in kb_entrys:
  237. # 将找到的记录的 is_deleted 改为 2
  238. kb_entry.is_deleted = 2
  239. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  240. for kf_entry in kf_entries:
  241. kf_entry.is_deleted = 2
  242. if hasattr(kb_entry, 'del_flag'):
  243. kb_entry.del_flag = 2
  244. try:
  245. db.commit()
  246. return {
  247. "code": 200,
  248. "msg": "操作成功",
  249. "data": None
  250. }
  251. except Exception as e:
  252. db.rollback()
  253. return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)
  254. def delete_file_fun(knowledge_base_code,db: Session):
  255. file_query = db.query(KnowledgeFile)
  256. file_query = file_query.filter(KnowledgeFile.is_deleted != '2')
  257. file_query = file_query.filter(KnowledgeFile.knowledge_base_code == knowledge_base_code)
  258. # file_query = file_query.filter(KnowledgeFile.foreign_key == foreign_key)
  259. files = file_query.all()
  260. for file in files:
  261. file.is_deleted='2'
  262. db.commit()
  263. @router.post('/update')
  264. async def update_knowledge(request: Request, db: Session = Depends(get_db)):
  265. body = await request.json()
  266. report_id_to_use = body.get('reportId')
  267. if not report_id_to_use:
  268. return Response(content="Missing required parameter 'reportId'", status_code=400)
  269. kb_entry = (db.query(KnowledgeBase)
  270. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  271. .filter(KnowledgeBase.reportId == report_id_to_use)) # 确保 reportID 是在过滤后的查询中使用的
  272. kb_entry = kb_entry.first()
  273. if not kb_entry:
  274. raise HTTPException(status_code=404, detail="报告不存在")
  275. kb_entry.reportName = body.get('reportName', kb_entry.reportName)
  276. kb_entry.subject = body.get('subject', kb_entry.subject)
  277. kb_entry.eventType = body.get('eventType', kb_entry.eventType)
  278. kb_entry.publishingUnit = body.get('publishingUnit', kb_entry.publishingUnit)
  279. kb_entry.publishDate = body.get('publishDate', kb_entry.publishDate)
  280. kb_entry.summary = body.get('summary', kb_entry.summary)
  281. kb_entry.updateTime = datetime.strptime(body.get('publishDate', kb_entry.updateTime), '%Y-%m-%d %H:%M:%S')
  282. base_code = kb_entry.base_code
  283. if len(body.get('fileNames')) > 0:
  284. if kb_entry.base_code:
  285. delete_file_fun(kb_entry.base_code,db=db)
  286. knowledge_files = [
  287. KnowledgeFile(
  288. file_identifier=f'file{random.randint(1000000000, 9999999999)}',
  289. file_path=f'/data/upload/mergefile/{fileName["url"]}',
  290. storage_file_name=fileName["url"],
  291. file_name=fileName["name"], # 使用 fileName["name"] 作为文件名
  292. is_deleted=0,
  293. updateTime=body['publishDate'],
  294. knowledge_base_code=base_code
  295. )
  296. for fileName in body['fileNames'] # body['fileNames'] 现在是一个包含对象的数组,每个对象都有 'name' 和 'url' 属性
  297. ]
  298. db.add_all(knowledge_files)
  299. try:
  300. db.commit()
  301. return {
  302. "code": 200,
  303. "msg": "修改成功",
  304. "data": {
  305. "report_id": report_id_to_use,
  306. "updateTime": kb_entry.updateTime.isoformat()
  307. }
  308. }
  309. except IntegrityError as e:
  310. db.rollback()
  311. return Response(content=f"Database error: {str(e)}", status_code=409)
  312. except Exception as e:
  313. db.rollback()
  314. return Response(content=str(e), status_code=500)
  315. @router.put('/update/list')
  316. async def update_knowledge_list(request: Request, db: Session = Depends(get_db)):
  317. body = await request.json()
  318. update_items = body.get('updateItems')
  319. if not update_items or not isinstance(update_items, list):
  320. return Response(content="Missing required parameter 'updateItems'", status_code=400)
  321. try:
  322. for item in update_items:
  323. report_id_to_use = item.get('reportId')
  324. if not report_id_to_use:
  325. continue # 如果没有提供 reportId,则跳过当前项
  326. # 根据 reportId 查找 KnowledgeBase 记录
  327. kb_entry = db.query(KnowledgeBase).filter(
  328. KnowledgeBase.reportId == report_id_to_use
  329. ).first()
  330. if not kb_entry:
  331. continue # 如果没有找到记录,则跳过当前项
  332. # 更新 KnowledgeBase 记录
  333. for field, value in item.items():
  334. if field != 'reportId' and hasattr(kb_entry, field):
  335. setattr(kb_entry, field, value)
  336. kb_entry.updateTime = datetime.fromisoformat(item.get('publishDate'))
  337. # 提交事务
  338. db.commit()
  339. return {
  340. "code": 200,
  341. "msg": "批量更新成功",
  342. "data": {
  343. "updated_count": len(update_items)
  344. }
  345. }
  346. except IntegrityError as e:
  347. db.rollback()
  348. return Response(content=f"Database error: {str(e)}", status_code=409)
  349. except Exception as e:
  350. db.rollback()
  351. return Response(content=str(e), status_code=500)