123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- from fastapi import APIRouter, Request, Depends, HTTPException, Query
- from sqlalchemy.exc import IntegrityError
- from fastapi.responses import JSONResponse
- from database import get_db
- from sqlalchemy.orm import Session
- from models import *
- import json
- import random
- from sqlalchemy import create_engine, select, or_
- from typing import Optional
# Router that collects the summary-report ("knowledge base") endpoints defined
# below (/create, /select, /detail); presumably included by the application
# with a URL prefix elsewhere — confirm in the app setup.
router = APIRouter()
@router.post('/create')
async def create_knowledge(request: Request, db: Session = Depends(get_db)):
    """Create a summary report (KnowledgeBase row) plus one KnowledgeFile row per file.

    The request body must be a JSON object containing every key listed in
    ``required_fields``; ``fileNames`` must be a list of strings. Every other
    required key is copied verbatim onto the new ``KnowledgeBase`` row.

    Returns:
        On success: ``{"code": 200, "msg": ..., "status": "success",
        "data": [reportId]}``.
        On failure: a plain-text ``Response`` with status 400 (body is not a
        JSON object, or fields are missing/invalid), 409 (``IntegrityError``
        on commit) or 500 (any other commit error).
    """
    # `Response` is not among this module's explicit imports (it could only
    # arrive via `from models import *`); import it here so the error paths
    # below cannot themselves fail with NameError.
    from fastapi import Response

    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError subclass ValueError
        return Response(content="Request body must be valid JSON", status_code=400)
    if not isinstance(body, dict):
        return Response(content="Request body must be a JSON object", status_code=400)

    required_fields = ['reportName', 'subject', 'eventType', 'publishingUnit', 'publishDate', 'summary', 'fileNames']
    missing = [field for field in required_fields if field not in body]
    if missing:
        missing_fields = ", ".join(missing)
        return Response(content=f"Missing required fields: {missing_fields}", status_code=400)

    file_names = body['fileNames']
    # A bare string would otherwise be iterated character by character below,
    # producing one file record per character.
    if not isinstance(file_names, list) or not all(isinstance(name, str) for name in file_names):
        return Response(content="fileNames must be a list of strings", status_code=400)

    # NOTE(review): these IDs are random, not guaranteed unique; a collision
    # surfaces as an IntegrityError (409) only if the columns carry a unique
    # constraint — confirm in the model definitions.
    reportId = f'ZJBG{random.randint(1000000000, 9999999999)}'
    base_code = 'base' + f'{random.randint(1000000000, 9999999999)}'
    knowledge = KnowledgeBase(
        reportId=reportId,
        **{field: body[field] for field in required_fields if field != 'fileNames'},
        base_code=base_code,
        del_flag=0,
        updateTime=body['publishDate'],
        notificationType="总结报告",  # hard-coded value; change this if other notification types must be handled
    )
    db.add(knowledge)

    # NOTE(review): fileName comes straight from the client and is only stored
    # here; validate it (e.g. reject path separators) before any code opens
    # file_path.
    knowledge_files = [
        KnowledgeFile(
            file_identifier=f'file{random.randint(1000000000, 9999999999)}',
            file_path=f'/data/upload/mergefile/{fileName}',  # assumes fileName is a bare file name
            file_name=fileName,
            is_deleted=0,
            updateTime=body['publishDate'],
            createTime=body['publishDate'],
            knowledge_base_code=base_code,
        )
        for fileName in file_names
    ]
    db.add_all(knowledge_files)

    # NOTE(review): this is an `async def` endpoint using a synchronous Session,
    # so the commit blocks the event loop while it runs.
    try:
        db.commit()
    except IntegrityError as e:
        db.rollback()
        return Response(content=f"Database error: {str(e)}", status_code=409)
    except Exception as e:
        db.rollback()
        return Response(content=f"Internal server error: {str(e)}", status_code=500)
    return {
        "code": 200,
        "msg": "总结报告创建成功",
        "status": "success",
        "data": [reportId],
    }
def _split_publish_date_range(date_range):
    """Split a ``"<start>-<end>"`` filter value into ``(start, end)``.

    Both halves must contain the same number of hyphens. That means
    hyphen-free values such as ``"20230101-20231231"`` and hyphenated dates
    such as ``"2023-01-01-2023-12-31"`` are both accepted. Any other shape
    raises ``HTTPException(400)`` instead of an unpacking ``ValueError``
    (which FastAPI would turn into a 500).
    """
    parts = date_range.split('-')
    half = len(parts) // 2
    start_date, end_date = '-'.join(parts[:half]), '-'.join(parts[half:])
    if len(parts) % 2 or not start_date or not end_date:
        raise HTTPException(status_code=400, detail="Invalid publishDateRange, expected '<start>-<end>'")
    return start_date, end_date


def _sort_column(sort_by):
    """Return the mapped ``KnowledgeBase`` column attribute named *sort_by*, or ``None``.

    Only ORM column attributes are accepted. This keeps arbitrary class
    attributes (methods, ``metadata``, dunders, ...) away from ``order_by``,
    where calling ``.asc()``/``.desc()`` on them would raise.
    """
    from sqlalchemy import inspect as sa_inspect

    if sort_by not in sa_inspect(KnowledgeBase).column_attrs.keys():
        return None
    return getattr(KnowledgeBase, sort_by)


@router.get('/select')
async def select_knowledge(
    db: Session = Depends(get_db),
    sortBy: str = Query(..., description="排序字段"),
    sortOrder: str = Query(..., description="排序顺序"),
    pageNum: int = Query(1, gt=0, description="页码"),
    pageSize: int = Query(10, gt=0, le=100, description="每页大小"),
    eventType: str = Query(None, description="事件类型"),
    publishDateRange: str = Query(None, description="发布日期范围"),
    query: str = Query(None, description="查询关键字"),
):
    """Return one page of summary reports that are not flagged ``del_flag == '2'``.

    Optional filters:
        eventType: exact match on ``KnowledgeBase.eventType``.
        publishDateRange: ``"<start>-<end>"``, applied with ``BETWEEN`` on
            ``publishDate``.
        query: substring match (SQL ``LIKE``) on reportName, publishingUnit
            or reportId. Note that ``%`` and ``_`` in the input still act as
            wildcards.

    Sorting: ``sortBy`` must name a mapped column; unknown names leave the
    result unsorted. ``sortOrder`` equal to ``"asc"`` (any case) sorts
    ascending; any other value sorts descending.

    Returns:
        A dict with ``total`` (the count of all matching rows) and ``pages``,
        plus ``currentPage``, ``pageSize`` and ``data`` (the rows of the
        requested page).
    """
    data_query = db.query(KnowledgeBase).filter(KnowledgeBase.del_flag != '2')
    if eventType:
        data_query = data_query.filter(KnowledgeBase.eventType == eventType)
    if publishDateRange:
        start_date, end_date = _split_publish_date_range(publishDateRange)
        data_query = data_query.filter(KnowledgeBase.publishDate.between(start_date, end_date))
    if query:
        search_fields = [getattr(KnowledgeBase, field) for field in ('reportName', 'publishingUnit', 'reportId') if hasattr(KnowledgeBase, field)]
        search_conditions = [field.like(f'%{query}%') for field in search_fields]
        if search_conditions:
            data_query = data_query.filter(or_(*search_conditions))

    # Count the full filtered set *before* pagination. Counting a query that
    # already carries OFFSET/LIMIT would only count the rows on this page.
    total_count = data_query.count()

    sort_attr = _sort_column(sortBy)
    if sort_attr is not None:
        data_query = data_query.order_by(sort_attr.asc() if sortOrder.lower() == 'asc' else sort_attr.desc())

    fields = ['reportId', 'reportName', 'eventType', 'publishDate', 'publishingUnit', 'summary', 'subject', 'notificationType']
    # Read back exactly the columns that were selected.
    present_fields = [field for field in fields if hasattr(KnowledgeBase, field)]
    entities = [getattr(KnowledgeBase, field) for field in present_fields]
    offset = (pageNum - 1) * pageSize
    rows = data_query.with_entities(*entities).offset(offset).limit(pageSize).all()
    result_items = [{field: getattr(row, field) for field in present_fields} for row in rows]

    return {
        "code": 200,
        'msg': '查询成功',
        'pages': (total_count + pageSize - 1) // pageSize,
        'total': total_count,
        "currentPage": pageNum,
        "pageSize": pageSize,
        'data': result_items,
    }
@router.get('/detail')
async def get_knowledge_detail(db: Session = Depends(get_db), reportID: Optional[str] = Query(None, description="报告ID")):
    """Return one summary report and its attached files, looked up by ``reportID``.

    Raises:
        HTTPException(400): ``reportID`` is missing or empty.
        HTTPException(404): no ``KnowledgeBase`` row has that ``reportId``.

    Returns:
        ``{"code": 200, "msg": ..., "data": [report]}``. ``report["file"]``
        lists ``{"content": file_name, "url": download_url}`` for every
        ``KnowledgeFile`` whose ``knowledge_base_code`` equals the report's
        ``base_code``.
    """
    from urllib.parse import quote

    if not reportID:
        raise HTTPException(status_code=400, detail="Missing required parameter 'reportID'")
    kb_entry = db.query(KnowledgeBase).filter(KnowledgeBase.reportId == reportID).first()
    if not kb_entry:
        raise HTTPException(status_code=404, detail="No knowledge base found for the given report ID")

    # NOTE(review): unlike /select, this does not exclude reports with
    # del_flag == '2' or files with is_deleted set — confirm that is intended.
    kf_entries = db.query(KnowledgeFile).filter(KnowledgeFile.knowledge_base_code == kb_entry.base_code).all()
    # File names are client-supplied (see /create) and may contain non-ASCII
    # text, spaces, '#', '?', etc., so percent-encode them before building the
    # URL. str() keeps the original f-string rendering for non-str values.
    # NOTE(review): host and port are hard-coded; consider deriving them from
    # configuration or the incoming request.
    files = [
        {"content": kf.file_name, "url": f'http://127.0.0.1:9988/api/file/download/{quote(str(kf.file_name))}'}
        for kf in kf_entries
    ]
    return {
        "code": 200,
        "msg": "查询成功",
        "data": [{
            "report_id": kb_entry.reportId,
            "reportName": kb_entry.reportName,
            "subject": kb_entry.subject,
            "eventType": kb_entry.eventType,
            "publishDate": kb_entry.publishDate,
            "publishingUnit": kb_entry.publishingUnit,
            "summary": kb_entry.summary,
            "notificationType": kb_entry.notificationType,
            "file": files,
        }],
    }
|