Explorar o código

241111-代码。

baoyubo hai 8 meses
pai
achega
e72800d213
Modificáronse 3 ficheiros con 167 adicións e 39 borrados
  1. 2 2
      common/security.py
  2. 1 0
      models/pattern_base.py
  3. 164 37
      routers/api/pattern/__init__.py

+ 2 - 2
common/security.py

@@ -37,8 +37,8 @@ def valid_websocket_token(Authorization: str ) -> int:  #= Header(..., alias="se
     # 目前小屏测试还不能用登录功能,暂时先这样 2024/11/03
     # def valid_access_token(Authorization: str = Header("Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIiwiZXhwIjoyMDM5Njk2ODMzfQ.Rhd38oo_S1odjg0xnT4n31cCWCAAPXGb8y_V2XcgqzQ"))->int:
     try:
-        access_token = Authorization.removeprefix("Authorization: Bearer ")
-
+        access_token = Authorization.replace("Authorization: Bearer ","")
+        # print(access_token)
         token_exception = TokenException()
         payload = jwt.decode(access_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
         print(payload, payload.get("sub"))

+ 1 - 0
models/pattern_base.py

@@ -42,6 +42,7 @@ class TpPatternWSUserList(Base):
     pattern_id = Column(String(255), nullable=True, comment='图案id')
     pattern_name = Column(String(255), nullable=True, comment='图案名称')
     user_id = Column(BigInteger, nullable=True, comment='用户id')
+    ws_flag = Column(String(255), default='true', comment='协同标志')
     del_flag = Column(String(1), default='0', comment='删除标志(0代表存在 2代表删除)')
     create_time = Column(DateTime, default=datetime.now, comment='数据创建时间')
     update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment='数据更新时间')

+ 164 - 37
routers/api/pattern/__init__.py

@@ -254,34 +254,93 @@ class ConnectionManager:
         message = json.dumps(message)
         for connection in self.active_connections[pattern_id]:
             await connection.send_text(message)
-
-
+    async def ownmessage(self,websocket: WebSocket,pattern_id:str,db: Session):#, message: str
+        data = pattern_id_get_tp_pattern_ws_list(pattern_id,db)
+        message = []
+        for info in data:
+            user= user_id_get_user_info(db,info.create_by)
+            dept = dept_id_get_dept_info(db,user.dept_id)
+            message.append({"id":info.id,
+                            "name":info.name,
+                            "pattern_id":info.pattern_id,
+                            "content":info.content,
+                            "visible":info.visible,
+                            "user_id":info.create_by,
+                            "nick_name":user.nick_name,
+                            "dept_name":dept.dept_name})
+        message = json.dumps(message)
+        await websocket.send_text(message)
 manager = ConnectionManager()
 
 @router.websocket("/{pattern_id}/ws")
 async def websocket_endpoint(pattern_id:str ,websocket: WebSocket,db: Session = Depends(get_db)):
 
     await manager.connect(websocket,pattern_id,db)
-
-        # return JSONResponse(status_code=404, content={
-        #     'code': 404,
-        #     'msg': '抱歉,您无权限,请联系系统管理员'
-        # })
+    print(manager.active_connections.keys())
     try:
         data = await websocket.receive_text()
         user_id = valid_websocket_token(data)
-        user_list = [i.user_id for i in pattern_id_get_tp_pattern_ws_user_list(pattern_id, db)]
+        user_list = [ ]
+        for i in pattern_id_get_tp_pattern_ws_user_list(pattern_id, db):
+            if i.ws_flag=='true':
+                user_list.append(i.user_id)
+        # print(user_id,user_list)
         if user_id not in user_list:
-            # await websocket.send_text('抱歉,您无权限,请联系系统管理员')
-            manager.disconnect(websocket, pattern_id)
+            await websocket.send_text('{"code":500,"msg":"抱歉,您无权限,请联系系统管理员"')
+            # manager.disconnect(websocket, pattern_id)
+            raise HTTPException(status_code=500, detail="抱歉,您无权限,请联系系统管理员")
+        await websocket.send_text('{"code":200,"msg":"连接成功"')
+        await manager.ownmessage(websocket,pattern_id,db)
         # await manager.broadcast(pattern_id, db)  # 广播消息给所有连接
-        # now_max_time = pattern_id_get_tp_pattern_ws_max_time(pattern_id, db).update_time
         while True:
-            # data = await websocket.receive_text()
-            # time.sleep(0.5)
-            # max_time = pattern_id_get_tp_pattern_ws_max_time(pattern_id, db).update_time
-            # if now_max_time<max_time:
-            await manager.broadcast(pattern_id,db)  # 广播消息给所有连接
+            data = await websocket.receive_text()
+            data = json.loads(data)
+            if 'operation' in data:
+                if data['operation'] == 'add':
+                    if 'name' in data  and 'content'in data and 'visible' in data:
+                        new_pattern_ws = TpPatternWSList(
+                            id=new_guid(),
+                            name=data['name'],
+                            pattern_id=pattern_id,
+                            content=data['content'],
+                            visible=data['visible'],
+                            create_by=user_id
+                        )
+                        db.add(new_pattern_ws)
+                        db.commit()
+                        await websocket.send_text('{"code": 200, "msg": "创建成功", "data": None}')
+                        await manager.broadcast(pattern_id, db)
+                    else:
+                        await websocket.send_text('{"code": 404, "msg": "新增数据 name/content/visible为必填"}')
+                elif data['operation'] == 'update' :
+                    if 'id' in data:
+                        id = data['id']
+                        query = db.query(TpPatternWSList)
+                        query = query.filter(TpPatternWSList.del_flag != '2')
+                        query = query.filter(TpPatternWSList.id == id)
+                        info = query.first()
+                        if 'name' in data:
+                            info.name = data['name']
+                            info.update_by = user_id
+                        # if 'pattern_id' in data:
+                        #     info.pattern_id = data['pattern_id']
+                        #     info.update_by = user_id
+                        if 'content' in data:
+                            info.content = data['content']
+                            info.update_by = user_id
+                        if 'visible' in data:
+                            info.visible = data['visible']
+                            info.update_by = user_id
+                        db.commit()
+                        await websocket.send_text('{"code": 200, "msg": "更新成功", "data": None}')
+                        await manager.broadcast(pattern_id, db)
+                    else:
+                        await websocket.send_text('{"code": 404, "msg": "更新数据 id为必填"}')
+                else:
+                    await websocket.send_text('{"code":404,"msg":"operation入参add or update"')
+            else:
+                await websocket.send_text('{"code":404,"msg":"需包含operation参数"')
+         # 广播消息给所有连接
     except WebSocketDisconnect:
         manager.disconnect(websocket,pattern_id)
 
@@ -333,6 +392,8 @@ async def update_pattern(
             info.visible=body['visible']
             info.update_by=user_id
         db.commit()
+        # print(manager.active_connections.keys())
+        # print(manager.active_connections)
         # await manager.broadcast(body['pattern_id'], db)
         return {"code": 200, "msg": "更新成功", "data": None}
     except Exception as e:
@@ -347,8 +408,14 @@ async def create_pattern(
     db: Session = Depends(get_db)
 ):
     try:
+        if pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db):
+            # return JSONResponse(
+            #     status_code=500, content={"code":500,"msg":"pattern_id已存在"}
+            # )
+            body['pattern_id']=new_guid()
         user_id_list = body['user_id_list']
-        if str(user_id) not in user_id_list:
+        if str(user_id) not in user_id_list and user_id not in user_id_list:
+            print(str(user_id),user_id_list)
             user_id_list.append(str(user_id))
         for user in user_id_list:
             new_pattern_user_ws = TpPatternWSUserList(
@@ -359,14 +426,15 @@ async def create_pattern(
                 create_by=user_id
             )
             db.add(new_pattern_user_ws)
-        new_pattern_ws = TpPatternWSList(
-            id = new_guid(),
-            name=body['name'],
-            pattern_id=body['pattern_id'],
-            content=body['content'],
-            visible = body['visible'],
-            create_by = user_id
-        )
+        if 'name' in body and 'content'  in body and 'visible'  in body:
+            new_pattern_ws = TpPatternWSList(
+                id = new_guid(),
+                name=body['name'],
+                pattern_id=body['pattern_id'],
+                content=body['content'],
+                visible = body['visible'],
+                create_by = user_id
+            )
         db.add(new_pattern_ws)
         new_pattern_group_ws = TpPatternWSGroupList(
             group_id=body['pattern_id'],
@@ -413,8 +481,8 @@ async def create_pattern(
 @router.get("/ws/list")
 async def get_pattern_list(
     pattern_name: str = Query(None, description='预案名称'),
-    page: int = Query(1, gt=0, description='页码'),
-    pageSize: int = Query(5, gt=0, description='每页条目数量'),
+    # page: int = Query(1, gt=0, description='页码'),
+    # pageSize: int = Query(5, gt=0, description='每页条目数量'),
     user_id=Depends(valid_access_token),
     db: Session = Depends(get_db)
 ):
@@ -428,12 +496,12 @@ async def get_pattern_list(
         # 排序
         query = query.order_by(TpPatternWSUserList.create_time.desc())
         # 执行分页查询
-        patterns = query.offset((page - 1) * pageSize).limit(pageSize).all()
+        patterns = query.all() # .offset((page - 1) * pageSize).limit(pageSize)
         return {"code": 200, "msg": "查询成功", "data": [{"id": p.id,"pattern_id":p.pattern_id, "pattern_name": p.pattern_name} for p in patterns],
-                "total": total_items,
-                "page": page,
-                "pageSize": pageSize,
-                "totalPages": (total_items + pageSize - 1) // pageSize
+                # "total": total_items,
+                # "page": page,
+                # "pageSize": pageSize,
+                # "totalPages": (total_items + pageSize - 1) // pageSize
                 }
     except Exception as e:
         traceback.print_exc()
@@ -461,7 +529,7 @@ async def get_pattern_list(
         for p in patterns:
             user = user_id_get_user_info(db, p.user_id)
             dept = dept_id_get_dept_info(db, user.dept_id)
-            data.append({"id": p.id,"pattern_id":p.pattern_id, "pattern_name": p.pattern_name,"nick_name":user.nick_name,"dept_name":dept.dept_name})
+            data.append({"id": p.id,"pattern_id":p.pattern_id,"ws_flag":p.ws_flag, "pattern_name": p.pattern_name,"nick_name":user.nick_name,"dept_name":dept.dept_name})
         return {"code": 200, "msg": "查询成功", "data": data,
                 "total": total_items,
                 "page": page,
@@ -492,7 +560,7 @@ async def rollback_pattern(
                                 'msg': '抱歉,您无权限,请联系系统管理员'
                             })
         user = user_id_and_pattern_id_get_tp_pattern_ws_user_info(body['user_id'],body['pattern_id'],db)
-        user.del_flag='2'
+        user.ws_flag='false'
         user.update_by = user_id
         db.commit()
         return {"code": 200, "msg": "关闭协同成功", "data": None}
@@ -521,7 +589,7 @@ async def rollback_pattern(
                             })
         users = pattern_id_get_tp_pattern_ws_user_list(body['pattern_id'],db)
         for user in users:
-            user.del_flag='2'
+            user.ws_flag='false'
             user.update_by = user_id
         db.commit()
         return {"code": 200, "msg": "关闭协同成功", "data": None}
@@ -530,6 +598,63 @@ async def rollback_pattern(
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 
+@router.put("/ws/reset_user")
+async def rollback_pattern(
+    user_id=Depends(valid_access_token),
+    body = Depends(remove_xss_json),
+    db: Session = Depends(get_db)
+):
+    try:
+        pattern_info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db)
+        if pattern_info is None:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,pattern_id不存在,请联系系统管理员'
+                            })
+        if pattern_info.create_by!=user_id:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,您无权限,请联系系统管理员'
+                            })
+        user = user_id_and_pattern_id_get_tp_pattern_ws_user_info(body['user_id'],body['pattern_id'],db)
+        user.ws_flag='true'
+        user.update_by = user_id
+        db.commit()
+        return {"code": 200, "msg": "开启协同成功", "data": None}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+
+@router.put("/ws/reset_all_user")
+async def rollback_pattern(
+    user_id=Depends(valid_access_token),
+    body = Depends(remove_xss_json),
+    db: Session = Depends(get_db)
+):
+    try:
+        pattern_info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db)
+        if pattern_info is None:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,pattern_id不存在,请联系系统管理员'
+                            })
+        if pattern_info.create_by!=user_id:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,您无权限,请联系系统管理员'
+                            })
+        users = pattern_id_get_tp_pattern_ws_user_list(body['pattern_id'],db)
+        for user in users:
+            user.ws_flag='true'
+            user.update_by = user_id
+        db.commit()
+        return {"code": 200, "msg": "开启协同成功", "data": None}
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+
 
 @router.post("/ws/add_user")
 async def rollback_pattern(
@@ -552,11 +677,12 @@ async def rollback_pattern(
         user_id_list = body['user_id_list']
         if str(user_id) not in user_id_list:
             user_id_list.append(str(user_id))
+        info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db)
         for user in user_id_list:
             new_pattern_ws = TpPatternWSUserList(
                 id=new_guid(),
                 pattern_id=body['pattern_id'],
-                pattern_name=body['pattern_name'],
+                pattern_name=info.pattern_name,
                 user_id=user,
                 create_by=user_id
             )
@@ -585,11 +711,12 @@ async def add_group_pattern(
                                 'code': 404,
                                 'msg': '抱歉,您无权限,请联系系统管理员'
                             })
+        info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'], db)
         new_pattern_ws_group = TpPatternWSGroupList(
             group_id=new_guid(),
             group_name = body['group_name'],
             pattern_id=body['pattern_id'],
-            pattern_name=body['pattern_name'],
+            pattern_name=info.pattern_name,
             create_by=user_id
         )
         db.add(new_pattern_ws_group)