Kaynağa Gözat

241110-1代码。

baoyubo 6 ay önce
ebeveyn
işleme
890e614323
2 değiştirilmiş dosya ile 320 ekleme ve 59 silme
  1. 32 1
      models/pattern_base.py
  2. 288 58
      routers/api/pattern/__init__.py

+ 32 - 1
models/pattern_base.py

@@ -7,7 +7,7 @@ from datetime import datetime
 class TpPatternList(Base):
     __tablename__ = 'tp_pattern_list'
 
-    id = Column(Integer, autoincrement=True, primary_key=True)
+    id = Column(String(255),  primary_key=True)
     pattern_name = Column(String(255), nullable=True, comment='图案名称')
     content = Column(JSON, nullable=True, comment='图案json')
     del_flag = Column(String(1), default='0', comment='删除标志(0代表存在 2代表删除)')
@@ -15,5 +15,36 @@ class TpPatternList(Base):
     update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='数据更新时间')
     create_dept = Column(BigInteger, default=None, comment='创建部门')
     create_by = Column(BigInteger, default=None, comment='创建者')
+    class Config:
+        orm_mode = True
+
+class TpPatternWSList(Base):
+    __tablename__ = 'tp_pattern_ws_list'
+
+    id = Column(String(255), primary_key=True)
+    pattern_id = Column(String(255), nullable=True, comment='图案id')
+    content = Column(JSON, nullable=True, comment='图案json')
+    del_flag = Column(String(1), default='0', comment='删除标志(0代表存在 2代表删除)')
+    create_time = Column(DateTime, default=datetime.now, comment='数据创建时间')
+    update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='数据更新时间')
+    create_dept = Column(BigInteger, default=None, comment='创建部门')
+    create_by = Column(BigInteger, default=None, comment='创建者')
+    update_by = Column(BigInteger, default=None, comment='更新者')
+    class Config:
+        orm_mode = True
+
+class TpPatternWSUserList(Base):
+    __tablename__ = 'tp_pattern_ws_user_list'
+
+    id = Column(String(255), primary_key=True)
+    pattern_id = Column(String(255), nullable=True, comment='图案id')
+    pattern_name = Column(String(255), nullable=True, comment='图案名称')
+    user_id = Column(String(255), nullable=True, comment='用户id')
+    del_flag = Column(String(1), default='0', comment='删除标志(0代表存在 2代表删除)')
+    create_time = Column(DateTime, default=datetime.now, comment='数据创建时间')
+    update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='数据更新时间')
+    create_dept = Column(BigInteger, default=None, comment='创建部门')
+    create_by = Column(BigInteger, default=None, comment='创建者')
+    update_by = Column(BigInteger, default=None, comment='更新者')
     class Config:
         orm_mode = True

+ 288 - 58
routers/api/pattern/__init__.py

@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-from fastapi import APIRouter, Request, Depends, Query, HTTPException, status
+from fastapi import APIRouter, Request, Depends, Query, HTTPException, status,WebSocket,WebSocketDisconnect
 from common.security import valid_access_token
+from fastapi.responses import JSONResponse
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import func
 from common.auth_user import *
@@ -25,14 +26,19 @@ async def create_pattern(
     body = Depends(remove_xss_json),
     db: Session = Depends(get_db)
 ):
-    new_pattern = TpPatternList(
-        pattern_name=body['pattern_name'],
-        content=body['content'],
-        create_dept = user_id
-    )
-    db.add(new_pattern)
-    db.commit()
-    return {"code": 200, "msg": "创建成功", "data": None}
+    try:
+        new_pattern = TpPatternList(
+            id = new_guid(),
+            pattern_name=body['pattern_name'],
+            content=body['content'],
+            create_dept = user_id
+        )
+        db.add(new_pattern)
+        db.commit()
+        return {"code": 200, "msg": "创建成功", "data": None}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @router.put("/update/{pattern_id}")
 async def update_pattern(
@@ -41,33 +47,40 @@ async def update_pattern(
     body=Depends(remove_xss_json),
     db: Session = Depends(get_db)
 ):
-    query = db.query(TpPatternList)
-    query = query.filter(TpPatternList.id == pattern_id)
-    query = query.filter(TpPatternList.del_flag != '2')
-    update_pattern = query.first()
-    if not update_pattern:
-        raise HTTPException(status_code=404, detail="图案不存在")
-
-    update_pattern.pattern_name = body['pattern_name']
-    update_pattern.content = body['content']
-    update_pattern.create_dept = user_id
-    db.commit()
-    return {"code": 200, "msg": "更新成功"}
+    try:
+        query = db.query(TpPatternList)
+        query = query.filter(TpPatternList.id == pattern_id)
+        query = query.filter(TpPatternList.del_flag != '2')
+        update_pattern = query.first()
+        if not update_pattern:
+            raise HTTPException(status_code=404, detail="图案不存在")
+
+        update_pattern.pattern_name = body['pattern_name']
+        update_pattern.content = body['content']
+        update_pattern.create_dept = user_id
+        db.commit()
+        return {"code": 200, "msg": "更新成功"}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @router.get("/info/{pattern_id}")
 async def get_pattern_info(
     pattern_id: int,
     db: Session = Depends(get_db)
 ):
-    query = db.query(TpPatternList)
-    query = query.filter(TpPatternList.id == pattern_id)
-    query = query.filter(TpPatternList.del_flag != '2')
-    pattern = query.first()
-    # pattern = db.query(TpPatternList).filter(TpPatternList.id == pattern_id).first()
-    if not pattern:
-        raise HTTPException(status_code=404, detail="图案不存在")
-    return {"code": 200, "msg": "获取成功", "data": {"pattern_name": pattern.pattern_name, "content": pattern.content}}
-
+    try:
+        query = db.query(TpPatternList)
+        query = query.filter(TpPatternList.id == pattern_id)
+        query = query.filter(TpPatternList.del_flag != '2')
+        pattern = query.first()
+        # pattern = db.query(TpPatternList).filter(TpPatternList.id == pattern_id).first()
+        if not pattern:
+            raise HTTPException(status_code=404, detail="图案不存在")
+        return {"code": 200, "msg": "获取成功", "data": {"pattern_name": pattern.pattern_name, "content": pattern.content}}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 @router.get("/list")
 async def get_pattern_list(
     pattern_name: str = Query(None, description='预案名称'),
@@ -75,39 +88,256 @@ async def get_pattern_list(
     pageSize: int = Query(5, gt=0, description='每页条目数量'),
     db: Session = Depends(get_db)
 ):
+    try:
+        query = db.query(TpPatternList)
+        query = query.filter(TpPatternList.del_flag != '2')
+        if pattern_name:
+            query = query.filter(TpPatternList.pattern_name.like(f'%{pattern_name}%'))
+        total_items = query.count()
 
-    query = db.query(TpPatternList)
-    query = query.filter(TpPatternList.del_flag != '2')
-    if pattern_name:
-        query = query.filter(TpPatternList.pattern_name.like(f'%{pattern_name}%'))
-    total_items = query.count()
-
-    # 排序
+        # 排序
 
-    query = query.order_by(TpPatternList.create_time.desc())
-    # 执行分页查询
-    patterns = query.offset((page - 1) * pageSize).limit(pageSize).all()
-    return {"code": 200, "msg": "查询成功", "data": [{"id": p.id, "pattern_name": p.pattern_name, "content": p.content} for p in patterns],
-            "total": total_items,
-            "page": page,
-            "pageSize": pageSize,
-            "totalPages": (total_items + pageSize - 1) // pageSize
-            }
+        query = query.order_by(TpPatternList.create_time.desc())
+        # 执行分页查询
+        patterns = query.offset((page - 1) * pageSize).limit(pageSize).all()
+        return {"code": 200, "msg": "查询成功", "data": [{"id": p.id, "pattern_name": p.pattern_name, "content": p.content} for p in patterns],
+                "total": total_items,
+                "page": page,
+                "pageSize": pageSize,
+                "totalPages": (total_items + pageSize - 1) // pageSize
+                }
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @router.delete("/delete/{pattern_id}")
 async def delete_pattern(
     pattern_id: int,
     db: Session = Depends(get_db)
 ):
-    # 检查图案是否存在
-    query = db.query(TpPatternList)
-    query = query.filter(TpPatternList.id == pattern_id)
-    query = query.filter(TpPatternList.del_flag != '2')
-    pattern = query.first()
-    if not pattern:
-        raise HTTPException(status_code=404, detail="图案不存在")
-
-    # 执行删除操作
-    pattern.del_flag='2'
-    db.commit()
-    return {"code": 200, "msg": "删除成功"}
+    try:
+        # 检查图案是否存在
+        query = db.query(TpPatternList)
+        query = query.filter(TpPatternList.id == pattern_id)
+        query = query.filter(TpPatternList.del_flag != '2')
+        pattern = query.first()
+        if not pattern:
+            raise HTTPException(status_code=404, detail="图案不存在")
+
+        # 执行删除操作
+        pattern.del_flag='2'
+        db.commit()
+        return {"code": 200, "msg": "删除成功"}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+def pattern_id_get_tp_pattern_ws_info(pattern_id:str,db: Session):
+    query = db.query(TpPatternWSList)
+    query = query.filter(TpPatternWSList.del_flag != '2')
+    query = query.filter(TpPatternWSList.pattern_id == pattern_id)
+    query.order_by(TpPatternWSList.create_time.desc())
+    return query.first()
+
+def pattern_id_get_tp_pattern_ws_user_list(pattern_id:str,db: Session):
+    query = db.query(TpPatternWSUserList)
+    query = query.filter(TpPatternWSUserList.del_flag != '2')
+    query = query.filter(TpPatternWSUserList.pattern_id == pattern_id)
+    query.order_by(TpPatternWSUserList.create_time.desc())
+    return query.all()
+
+def user_id_get_tp_pattern_ws_user_list(user_id:str,db: Session):
+    query = db.query(TpPatternWSUserList)
+    query = query.filter(TpPatternWSUserList.del_flag != '2')
+    query = query.filter(TpPatternWSUserList.user_id == user_id)
+    query.order_by(TpPatternWSUserList.create_time.desc())
+    return query.all()
+
+def user_id_and_pattern_id_get_tp_pattern_ws_user_info(user_id:str,pattern_id:str,db: Session):
+    query = db.query(TpPatternWSUserList)
+    query = query.filter(TpPatternWSUserList.del_flag != '2')
+    query = query.filter(TpPatternWSUserList.user_id == user_id)
+    query = query.filter(TpPatternWSUserList.pattern_id == pattern_id)
+    query.order_by(TpPatternWSUserList.create_time.desc())
+    return query.all()
+
+class ConnectionManager:
+    def __init__(self):
+        self.active_connections = {}  #: List[WebSocket]
+
+    async def connect(self, websocket: WebSocket,pattern_id:str,db: Session):
+        await websocket.accept()
+        if pattern_id not in self.active_connections:
+            data = pattern_id_get_tp_pattern_ws_info(pattern_id,db)
+            if data:
+                await websocket.send_text(data.content)
+            self.active_connections[pattern_id] = [websocket]
+        else:
+            self.active_connections[pattern_id].append(websocket)
+
+    def disconnect(self, websocket: WebSocket,pattern_id:str):
+        self.active_connections[pattern_id].remove(websocket)
+        if not self.active_connections[pattern_id]:
+            del self.active_connections[pattern_id]
+    async def broadcast(self, message: str,pattern_id:str,user_id,db: Session):
+        new_pattern = TpPatternWSList(
+            id = new_guid(),
+            pattern_id=pattern_id,
+            content=pattern_id,
+            create_dept = user_id
+        )
+        db.add(new_pattern)
+        db.commit()
+        for connection in self.active_connections[pattern_id]:
+            await connection.send_text(message)
+
+
+manager = ConnectionManager()
+
+@router.websocket("/{pattern_id}/ws")
+async def websocket_endpoint(pattern_id:str ,websocket: WebSocket,user_id=Depends(valid_access_token),db: Session = Depends(get_db)):
+    user_list = [i.user_id for i in pattern_id_get_tp_pattern_ws_user_list(pattern_id, db)]
+    if user_id not in user_list:
+        return JSONResponse(status_code=404, content={
+            'code': 404,
+            'msg': '抱歉,您无权限,请联系系统管理员'
+        })
+    await manager.connect(websocket,pattern_id,db)
+    try:
+        while True:
+            data = await websocket.receive_text()
+
+            await manager.broadcast(data,pattern_id,user_id,db)  # 广播消息给所有连接
+    except WebSocketDisconnect:
+        manager.disconnect(websocket,pattern_id)
+
+@router.post("/ws/create")
+async def create_pattern(
+    user_id=Depends(valid_access_token),
+    body = Depends(remove_xss_json),
+    db: Session = Depends(get_db)
+):
+    try:
+        user_id_list = body['user_id_list']
+        if user_id not in user_id_list:
+            user_id_list.append(user_id)
+        for user in user_id_list:
+            new_pattern_ws = TpPatternWSUserList(
+                id=new_guid(),
+                pattern_id=body['pattern_id'],
+                user_id=user,
+                create_dept=user_id
+            )
+            db.add(new_pattern_ws)
+        new_pattern_ws = TpPatternWSList(
+            id = new_guid(),
+            pattern_id=body['pattern_id'],
+            content=body['content'],
+            create_dept = user_id
+        )
+        db.add(new_pattern_ws)
+        db.commit()
+        return {"code": 200, "msg": "创建成功", "data": None}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+@router.put("/ws/rollback")
+async def rollback_pattern(
+    user_id=Depends(valid_access_token),
+    body = Depends(remove_xss_json),
+    db: Session = Depends(get_db)
+):
+    try:
+        user_list = [i.user_id for i in pattern_id_get_tp_pattern_ws_user_list(body['pattern_id'],db)]
+        if user_id not in user_list:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,您无权限,请联系系统管理员'
+                            })
+        data = pattern_id_get_tp_pattern_ws_info(body['pattern_id'],db)
+        data.del_flag='2'
+        data.update_by = user_id
+        db.commit()
+        return {"code": 200, "msg": "回滚成功", "data": None}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+
+@router.get("/ws/list")
+async def get_pattern_list(
+    pattern_name: str = Query(None, description='预案名称'),
+    page: int = Query(1, gt=0, description='页码'),
+    pageSize: int = Query(5, gt=0, description='每页条目数量'),
+    user_id=Depends(valid_access_token),
+    db: Session = Depends(get_db)
+):
+    try:
+        query = db.query(TpPatternWSUserList)
+        query = query.filter(TpPatternWSUserList.del_flag != '2')
+        query = query.filter(TpPatternWSUserList.user_id==user_id)
+        if pattern_name:
+            query = query.filter(TpPatternWSUserList.pattern_name.like(f'%{pattern_name}%'))
+        total_items = query.count()
+        # 排序
+        query = query.order_by(TpPatternWSUserList.create_time.desc())
+        # 执行分页查询
+        patterns = query.offset((page - 1) * pageSize).limit(pageSize).all()
+        return {"code": 200, "msg": "查询成功", "data": [{"id": p.id,"pattern_id":p.pattern_id, "pattern_name": p.pattern_name} for p in patterns],
+                "total": total_items,
+                "page": page,
+                "pageSize": pageSize,
+                "totalPages": (total_items + pageSize - 1) // pageSize
+                }
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+@router.put("/ws/delete_user")
+async def rollback_pattern(
+    user_id=Depends(valid_access_token),
+    body = Depends(remove_xss_json),
+    db: Session = Depends(get_db)
+):
+    try:
+        pattern_info = pattern_id_get_tp_pattern_ws_info(body['pattern_id'],db)
+        if pattern_info.create_by!=user_id:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,您无权限,请联系系统管理员'
+                            })
+        user = user_id_and_pattern_id_get_tp_pattern_ws_user_info(body['user_id'],body['pattern_id'],db)
+        user.del_flag='2'
+        user.update_by = user_id
+        db.commit()
+        return {"code": 200, "msg": "关闭协同成功", "data": None}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+@router.put("/ws/delete_user")
+async def rollback_pattern(
+    user_id=Depends(valid_access_token),
+    body = Depends(remove_xss_json),
+    db: Session = Depends(get_db)
+):
+    try:
+        pattern_info = pattern_id_get_tp_pattern_ws_info(body['pattern_id'],db)
+        if pattern_info.create_by!=user_id:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,您无权限,请联系系统管理员'
+                            })
+        new_pattern_ws = TpPatternWSUserList(
+            id=new_guid(),
+            pattern_id=body['pattern_id'],
+            user_id=body['user_id'],
+            create_dept=user_id
+        )
+        db.add(new_pattern_ws)
+        db.commit()
+        return {"code": 200, "msg": "开启协同成功", "data": None}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")