__init__.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. from fastapi import APIRouter, Request, Depends, HTTPException, Query,Response
  2. from sqlalchemy.exc import IntegrityError
  3. from common.security import valid_access_token
  4. from fastapi.responses import JSONResponse,StreamingResponse
  5. from common.db import db_czrz
  6. from common.auth_user import *
  7. from database import get_db
  8. from sqlalchemy.orm import Session
  9. from models import *
  10. from utils.ry_system_util import *
  11. import json
  12. import random
  13. from sqlalchemy import create_engine, select, or_
  14. from typing import Optional
  15. import uuid
  16. import traceback
# Router that every knowledge-base endpoint in this module registers on;
# presumably included into the application elsewhere (not visible here).
router = APIRouter()
  18. @router.post('/create')
  19. async def create_knowledge(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  20. body = await request.json()
  21. required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary', 'fileNames']
  22. if not all(field in body for field in required_fields):
  23. missing_fields = ", ".join([field for field in required_fields if field not in body])
  24. return Response(content=f"Missing required fields: {missing_fields}", status_code=400)
  25. reportId = f'ZJBG{random.randint(1000000000, 9999999999)}'
  26. base_code = 'base' + f'{random.randint(1000000000, 9999999999)}'
  27. uuid1 = uuid.uuid1()
  28. knowledge = KnowledgeBase(
  29. reportId=reportId,
  30. **{
  31. field: body[field]
  32. for field in required_fields
  33. if field != 'fileNames'
  34. },
  35. base_code=base_code,
  36. del_flag = 0,
  37. updateTime=body['publishDate'],
  38. notificationType="总结报告" # 硬编码的值,如果有其他情况需要处理,请修改
  39. )
  40. db.add(knowledge)
  41. knowledge_files = [
  42. KnowledgeFile(
  43. file_identifier=f'file{random.randint(1000000000, 9999999999)}',
  44. file_path=f'/data/upload/mergefile/{fileName["url"]}',
  45. storage_file_name=fileName["url"],
  46. file_name=fileName["name"], # 使用 fileName["name"] 作为文件名
  47. is_deleted=0,
  48. updateTime=body['publishDate'],
  49. createTime=body['publishDate'],
  50. knowledge_base_code=base_code
  51. )
  52. for fileName in body['fileNames'] # body['fileNames'] 现在是一个包含对象的数组,每个对象都有 'name' 和 'url' 属性
  53. ]
  54. db.add_all(knowledge_files)
  55. try:
  56. db.commit()
  57. knowledge_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == reportId).first()
  58. files_entry = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).all()
  59. return {
  60. "code": 200,
  61. "msg": "总结报告创建成功",
  62. "status": "success",
  63. "data": [{
  64. "reportId":knowledge_entry.reportId,
  65. "reportName":knowledge_entry.reportName,
  66. "eventType":knowledge_entry.eventType,
  67. "publishingUnit":knowledge_entry.publishingUnit,
  68. "subject":knowledge_entry.subject,
  69. "publishDate":knowledge_entry.publishDate,
  70. "summary":knowledge_entry.summary,
  71. "files": [
  72. {
  73. "fileIdentifier": file.file_identifier,
  74. "fileName": file.file_name,
  75. "url": file.storage_file_name,
  76. "status":"success"
  77. }
  78. for file in files_entry
  79. ]
  80. }]
  81. }
  82. except IntegrityError as e:
  83. db.rollback()
  84. return Response(content=f"Database error: {str(e)}", status_code=409)
  85. except Exception as e:
  86. db.rollback()
  87. return Response(content=f"Internal server error: {str(e)}", status_code=500)
  88. @router.get('/select')
  89. async def select_knowledge(
  90. db: Session = Depends(get_db),
  91. sortBy: str = Query(..., description="排序字段"),
  92. sortOrder: str = Query(..., description="排序顺序"),
  93. pageNum: int = Query(1, gt=0, description="页码"),
  94. pageSize: int = Query(10, gt=0, le=100, description="每页大小"),
  95. eventType: str = Query(None, description="事件类型"),
  96. publishDateRange: str = Query(None, description="发布日期范围"),
  97. query: str = Query(None, description="查询关键字", user_id=Depends(valid_access_token))
  98. ):
  99. data_query = db.query(KnowledgeBase)
  100. data_query = data_query.filter(KnowledgeBase.del_flag != '2')
  101. if eventType:
  102. data_query = data_query.filter(KnowledgeBase.eventType == eventType)
  103. if publishDateRange:
  104. start_date, end_date = publishDateRange.split('-')
  105. data_query = data_query.filter(KnowledgeBase.publishDate.between(start_date, end_date))
  106. if query:
  107. search_fields = [getattr(KnowledgeBase, field) for field in ('reportName', 'publishingUnit', 'reportId') if hasattr(KnowledgeBase, field)]
  108. search_conditions = [field.like(f'%{query}%') for field in search_fields]
  109. data_query = data_query.filter(or_(*search_conditions))
  110. if hasattr(KnowledgeBase, sortBy):
  111. sort_attr = getattr(KnowledgeBase, sortBy)
  112. data_query = data_query.order_by(sort_attr.asc() if sortOrder == 'asc' else sort_attr.desc())
  113. total_count = data_query.count()
  114. offset = (pageNum - 1) * pageSize
  115. data_query = data_query.offset(offset).limit(pageSize)
  116. fields = ['reportId', 'reportName', 'eventType', 'publishDate', 'publishingUnit', 'summary', 'subject', 'notificationType', 'base_code']
  117. entities = [getattr(KnowledgeBase, field) for field in fields if hasattr(KnowledgeBase, field)]
  118. data = data_query.with_entities(*entities).offset(offset).limit(pageSize).all()
  119. result_items = []
  120. for item in data:
  121. item_dict = {field: getattr(item, field) for field in fields}
  122. print(item_dict)
  123. base_code = item_dict['base_code']
  124. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).filter(KnowledgeFile.is_deleted != '2').all()
  125. item_dict['files'] = [
  126. {
  127. "fileIdentifier": file.file_identifier,
  128. "fileName": file.file_name,
  129. "url": file.storage_file_name,
  130. "status":"success"
  131. }
  132. for file in kf_entries
  133. ]
  134. result_items.append(item_dict)
  135. result = {
  136. "code": 200,
  137. 'msg': '查询成功',
  138. 'pages': (total_count + pageSize - 1) // pageSize,
  139. 'total': total_count,
  140. "currentPage": pageNum,
  141. "pageSize": pageSize,
  142. 'data': result_items
  143. }
  144. return result
  145. @router.get('/export')
  146. async def select_knowledge(
  147. request: Request,
  148. db: Session = Depends(get_db),
  149. sortBy: str = Query(..., description="排序字段"),
  150. sortOrder: str = Query(..., description="排序顺序"),
  151. pageNum: int = Query(1, gt=0, description="页码"),
  152. pageSize: int = Query(10, gt=0, le=100, description="每页大小"),
  153. eventType: str = Query(None, description="事件类型"),
  154. publishDateRange: str = Query(None, description="发布日期范围"),
  155. auth_user: AuthUser = Depends(find_auth_user),
  156. query: str = Query(None, description="查询关键字", user_id=Depends(valid_access_token))
  157. ):
  158. try:
  159. data_query = db.query(KnowledgeBase)
  160. data_query = data_query.filter(KnowledgeBase.del_flag != '2')
  161. if eventType:
  162. data_query = data_query.filter(KnowledgeBase.eventType == eventType)
  163. if publishDateRange:
  164. start_date, end_date = publishDateRange.split('-')
  165. data_query = data_query.filter(KnowledgeBase.publishDate.between(start_date, end_date))
  166. if query:
  167. search_fields = [getattr(KnowledgeBase, field) for field in ('reportName', 'publishingUnit', 'reportId') if
  168. hasattr(KnowledgeBase, field)]
  169. search_conditions = [field.like(f'%{query}%') for field in search_fields]
  170. data_query = data_query.filter(or_(*search_conditions))
  171. if hasattr(KnowledgeBase, sortBy):
  172. sort_attr = getattr(KnowledgeBase, sortBy)
  173. data_query = data_query.order_by(sort_attr.asc() if sortOrder == 'asc' else sort_attr.desc())
  174. total_count = data_query.count()
  175. offset = (pageNum - 1) * pageSize
  176. # data_query = data_query.offset(offset).limit(pageSize)
  177. fields = ['reportId', 'reportName', 'eventType', 'publishDate', 'publishingUnit', 'summary', 'subject',
  178. 'notificationType', 'base_code']
  179. entities = [getattr(KnowledgeBase, field) for field in fields if hasattr(KnowledgeBase, field)]
  180. data = data_query.with_entities(*entities).all()
  181. # mm_event_type_list = dict_type_get_dict_data_info(db,'mm_event_type')
  182. mm_event_type = {}
  183. for i in dict_type_get_dict_data_info(db,'mm_event_type'):
  184. mm_event_type[i.dict_value]=i.dict_label
  185. result_items = []
  186. for item in data:
  187. item_dict = {field: getattr(item, field) for field in fields}
  188. # print(item_dict)
  189. # base_code = item_dict['base_code']
  190. # kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == base_code).filter(
  191. # KnowledgeFile.is_deleted != '2').all()
  192. # item_dict['files'] = [
  193. # {
  194. # "fileIdentifier": file.file_identifier,
  195. # "fileName": file.file_name,
  196. # "url": file.storage_file_name,
  197. # "status": "success"
  198. # }
  199. # for file in kf_entries
  200. # ]
  201. item_dict = {"报告编号":item_dict["reportId"],
  202. "报告名称":item_dict["reportName"],
  203. "主题词":item_dict["subject"],
  204. "事件类型":mm_event_type[item_dict["eventType"]],
  205. "摘要":item_dict["summary"],
  206. "来源单位":item_dict["publishingUnit"],
  207. "发布日期":item_dict["publishDate"],
  208. "知识类型":item_dict["notificationType"]}
  209. result_items.append(item_dict)
  210. import pandas as pd
  211. from io import BytesIO
  212. # 将查询结果转换为 DataFrame
  213. df = pd.DataFrame(result_items)
  214. # 将 DataFrame 导出为 Excel 文件
  215. output = BytesIO()
  216. with pd.ExcelWriter(output, engine='openpyxl') as writer:
  217. df.to_excel(writer, index=False)
  218. # 设置响应头
  219. output.seek(0)
  220. from urllib.parse import quote
  221. encoded_filename = f'知识管理{datetime.now().strftime("%Y%m%d%H%mi%s")}.xlsx'
  222. encoded_filename = quote(encoded_filename, encoding='utf-8')
  223. headers = {
  224. 'Content-Disposition': f'attachment; filename*=UTF-8\'\'{encoded_filename}',
  225. 'Content-Type': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
  226. }
  227. db_czrz.log(db, auth_user, "知识库管理", f"知识管理导出数据成功", request.client.host)
  228. # 返回文件流
  229. return StreamingResponse(output, headers=headers)
  230. # result = {
  231. # "code": 200,
  232. # 'msg': '查询成功',
  233. # 'pages': (total_count + pageSize - 1) // pageSize,
  234. # 'total': total_count,
  235. # "currentPage": pageNum,
  236. # "pageSize": pageSize,
  237. # 'data': result_items
  238. # }
  239. #
  240. # return result
  241. except Exception as e:
  242. traceback.print_exc()
  243. # 处理异常
  244. db.rollback()
  245. raise HTTPException(status_code=500, detail=str(e))
  246. @router.get('/detail')
  247. async def get_knowledge_detail(db: Session = Depends(get_db), reportID: Optional[str] = Query(None, description="报告ID"), user_id=Depends(valid_access_token)):
  248. if not reportID:
  249. raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")
  250. # kb_entry = db.query(KnowledgeBase)
  251. # kb_entry = kb_entry.filter(KnowledgeBase.del_flag != '2')
  252. # 先过滤出 del_flag 不等于 '2' 的记录,然后再根据 report_id 筛选
  253. kb_entry = (db.query(KnowledgeBase)
  254. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  255. .filter(KnowledgeBase.reportId == reportID)) # 确保 reportID 是在过滤后的查询中使用的
  256. kb_entry = kb_entry.first()
  257. if not kb_entry:
  258. raise HTTPException(status_code=404, detail="报告不存在")
  259. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code)\
  260. .filter(KnowledgeFile.is_deleted != '2').all()
  261. files = [
  262. {"content": kf.file_name, "url": kf.storage_file_name,"file_identifier":kf.file_identifier}
  263. for kf in kf_entries
  264. ]
  265. '''
  266. files = [
  267. {
  268. "fileIdentifier": file.file_identifier,
  269. "fileName": file.file_name,
  270. "url": file.storage_file_name,
  271. "status":"success"
  272. }
  273. for file in kf_entries
  274. ]
  275. '''
  276. result = {
  277. "code": 200,
  278. "msg": "查询成功",
  279. "data": [{
  280. "report_id": kb_entry.reportId,
  281. "reportName": kb_entry.reportName,
  282. "subject": kb_entry.subject,
  283. "eventType": kb_entry.eventType,
  284. "publishDate": kb_entry.publishDate,
  285. "publishingUnit": kb_entry.publishingUnit,
  286. "summary": kb_entry.summary,
  287. "notificationType": kb_entry.notificationType,
  288. "file": files
  289. }]
  290. }
  291. return result
  292. @router.post ('/delete')
  293. async def delete_knowledge(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  294. # 从请求的 JSON 数据中获取 reportID
  295. body = await request.json()
  296. report_id_to_use = body.get('reportID')
  297. if not report_id_to_use:
  298. return Response(content="Missing required parameter 'reportID'", status_code=400)
  299. kb_entry = (db.query(KnowledgeBase)
  300. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  301. .filter(KnowledgeBase.reportId == report_id_to_use)) # 确保 reportID 是在过滤后的查询中使用的
  302. kb_entry = kb_entry.first()
  303. if not kb_entry:
  304. raise HTTPException(status_code=404, detail="报告不存在")
  305. # 将找到的记录的 is_deleted 改为 2
  306. kb_entry.is_deleted = 2
  307. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  308. for kf_entry in kf_entries:
  309. kf_entry.is_deleted = 2
  310. if hasattr(kb_entry, 'del_flag'):
  311. kb_entry.del_flag = 2
  312. try:
  313. db.commit()
  314. return {
  315. "code": 200,
  316. "msg": "操作成功",
  317. "data": {
  318. "report_id": kb_entry.reportId
  319. }
  320. }
  321. except Exception as e:
  322. db.rollback()
  323. return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)
  324. @router.delete('/delete/list')
  325. async def delete_knowledge_list(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  326. # 从请求的 JSON 数据中获取 reportID
  327. body = await request.json()
  328. report_id_to_use = body.get('reportID')
  329. if not report_id_to_use:
  330. return Response(content="Missing required parameter 'reportID'", status_code=400)
  331. # kb_entry = (db.query(KnowledgeBase)
  332. # .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  333. # .filter(KnowledgeBase.reportId.in_report_(id_to_use))) # 确保 reportID 是在过滤后的查询中使用的
  334. #
  335. # kb_entrys = kb_entry.all()
  336. query = db.query(KnowledgeBase)
  337. query = query.filter(KnowledgeBase.del_flag != '2')
  338. query = query.filter(KnowledgeBase.reportId.in_(report_id_to_use))
  339. kb_entrys = query.all()
  340. if not kb_entrys:
  341. raise HTTPException(status_code=404, detail="报告不存在")
  342. for kb_entry in kb_entrys:
  343. # 将找到的记录的 is_deleted 改为 2
  344. kb_entry.is_deleted = 2
  345. kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
  346. for kf_entry in kf_entries:
  347. kf_entry.is_deleted = 2
  348. if hasattr(kb_entry, 'del_flag'):
  349. kb_entry.del_flag = 2
  350. try:
  351. db.commit()
  352. return {
  353. "code": 200,
  354. "msg": "操作成功",
  355. "data": None
  356. }
  357. except Exception as e:
  358. db.rollback()
  359. return Response(content="An error occurred while deleting the record: " + str(e), status_code=500)
  360. def delete_file_fun(knowledge_base_code,db: Session):
  361. file_query = db.query(KnowledgeFile)
  362. file_query = file_query.filter(KnowledgeFile.is_deleted != '2')
  363. file_query = file_query.filter(KnowledgeFile.knowledge_base_code == knowledge_base_code)
  364. # file_query = file_query.filter(KnowledgeFile.foreign_key == foreign_key)
  365. files = file_query.all()
  366. for file in files:
  367. file.is_deleted='2'
  368. db.commit()
  369. @router.post('/update')
  370. async def update_knowledge(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  371. body = await request.json()
  372. report_id_to_use = body.get('reportId')
  373. if not report_id_to_use:
  374. return Response(content="Missing required parameter 'reportId'", status_code=400)
  375. kb_entry = (db.query(KnowledgeBase)
  376. .filter(KnowledgeBase.del_flag != '2') # 确保这里过滤掉 del_flag 等于 '2' 的记录
  377. .filter(KnowledgeBase.reportId == report_id_to_use)) # 确保 reportID 是在过滤后的查询中使用的
  378. kb_entry = kb_entry.first()
  379. if not kb_entry:
  380. raise HTTPException(status_code=404, detail="报告不存在")
  381. kb_entry.reportName = body.get('reportName', kb_entry.reportName)
  382. kb_entry.subject = body.get('subject', kb_entry.subject)
  383. kb_entry.eventType = body.get('eventType', kb_entry.eventType)
  384. kb_entry.publishingUnit = body.get('publishingUnit', kb_entry.publishingUnit)
  385. kb_entry.publishDate = body.get('publishDate', kb_entry.publishDate)
  386. kb_entry.summary = body.get('summary', kb_entry.summary)
  387. kb_entry.updateTime = datetime.strptime(body.get('publishDate', kb_entry.updateTime), '%Y-%m-%d %H:%M:%S')
  388. base_code = kb_entry.base_code
  389. if len(body.get('fileNames')) > 0:
  390. if kb_entry.base_code:
  391. delete_file_fun(kb_entry.base_code,db=db)
  392. knowledge_files = [
  393. KnowledgeFile(
  394. file_identifier=f'file{random.randint(1000000000, 9999999999)}',
  395. file_path=f'/data/upload/mergefile/{fileName["url"]}',
  396. storage_file_name=fileName["url"],
  397. file_name=fileName["name"], # 使用 fileName["name"] 作为文件名
  398. is_deleted=0,
  399. updateTime=body['publishDate'],
  400. knowledge_base_code=base_code
  401. )
  402. for fileName in body['fileNames'] # body['fileNames'] 现在是一个包含对象的数组,每个对象都有 'name' 和 'url' 属性
  403. ]
  404. db.add_all(knowledge_files)
  405. try:
  406. db.commit()
  407. return {
  408. "code": 200,
  409. "msg": "修改成功",
  410. "data": {
  411. "report_id": report_id_to_use,
  412. "updateTime": kb_entry.updateTime.isoformat()
  413. }
  414. }
  415. except IntegrityError as e:
  416. db.rollback()
  417. return Response(content=f"Database error: {str(e)}", status_code=409)
  418. except Exception as e:
  419. db.rollback()
  420. return Response(content=str(e), status_code=500)
  421. @router.put('/update/list')
  422. async def update_knowledge_list(request: Request, db: Session = Depends(get_db), user_id=Depends(valid_access_token)):
  423. body = await request.json()
  424. update_items = body.get('updateItems')
  425. if not update_items or not isinstance(update_items, list):
  426. return Response(content="Missing required parameter 'updateItems'", status_code=400)
  427. try:
  428. for item in update_items:
  429. report_id_to_use = item.get('reportId')
  430. if not report_id_to_use:
  431. continue # 如果没有提供 reportId,则跳过当前项
  432. # 根据 reportId 查找 KnowledgeBase 记录
  433. kb_entry = db.query(KnowledgeBase).filter(
  434. KnowledgeBase.reportId == report_id_to_use
  435. ).first()
  436. if not kb_entry:
  437. continue # 如果没有找到记录,则跳过当前项
  438. # 更新 KnowledgeBase 记录
  439. for field, value in item.items():
  440. if field != 'reportId' and hasattr(kb_entry, field):
  441. setattr(kb_entry, field, value)
  442. kb_entry.updateTime = datetime.fromisoformat(item.get('publishDate'))
  443. # 提交事务
  444. db.commit()
  445. return {
  446. "code": 200,
  447. "msg": "批量更新成功",
  448. "data": {
  449. "updated_count": len(update_items)
  450. }
  451. }
  452. except IntegrityError as e:
  453. db.rollback()
  454. return Response(content=f"Database error: {str(e)}", status_code=409)
  455. except Exception as e:
  456. db.rollback()
  457. return Response(content=str(e), status_code=500)