Jelajahi Sumber

241112-1代码。

baoyubo 7 bulan lalu
induk
melakukan
d975389ec4
2 mengubah file dengan 39 tambahan dan 11 penghapusan
  1. 2 2
      common/security.py
  2. 37 9
      routers/api/pattern/__init__.py

+ 2 - 2
common/security.py

@@ -33,11 +33,11 @@ def valid_access_token(Authorization: str = Header(..., alias="Authorization"))
     return int(user_id)
 
 
-def valid_websocket_token(Authorization: str = Header(..., alias="sec-websocket-protocol")) -> int:
+def valid_websocket_token(Authorization: str ) -> int:  #= Header(..., alias="sec-websocket-protocol")
     # 目前小屏测试还不能用登录功能,暂时先这样 2024/11/03
     # def valid_access_token(Authorization: str = Header("Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIiwiZXhwIjoyMDM5Njk2ODMzfQ.Rhd38oo_S1odjg0xnT4n31cCWCAAPXGb8y_V2XcgqzQ"))->int:
     try:
-        access_token = Authorization.removeprefix("Bearer ")
+        access_token = Authorization.removeprefix("Authorization: Bearer ")
 
         token_exception = TokenException()
         payload = jwt.decode(access_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])

+ 37 - 9
routers/api/pattern/__init__.py

@@ -214,6 +214,7 @@ class ConnectionManager:
             self.active_connections[pattern_id] = [websocket]
         else:
             self.active_connections[pattern_id].append(websocket)
+        print('连接成功')
         # data = pattern_id_get_tp_pattern_ws_list(pattern_id, db)
         # message = []
         # for info in data:
@@ -229,9 +230,10 @@ class ConnectionManager:
         #                     "dept_name": dept.dept_name})
         #     message = json.dumps(message)
         #     await websocket.send_text(message)
-            await websocket.send_text('连接成功')
+        #     await websocket.send_text('连接成功')
     def disconnect(self, websocket: WebSocket,pattern_id:str):
-        await websocket.send_text('已断开')
+        # await websocket.send_text('已断开')
+        print('已断开')
         self.active_connections[pattern_id].remove(websocket)
         if not self.active_connections[pattern_id]:
             del self.active_connections[pattern_id]
@@ -257,15 +259,21 @@ class ConnectionManager:
 manager = ConnectionManager()
 
 @router.websocket("/{pattern_id}/ws")
-async def websocket_endpoint(pattern_id:str ,websocket: WebSocket,user_id=Depends(valid_access_token),db: Session = Depends(get_db)):
-    user_list = [i.user_id for i in pattern_id_get_tp_pattern_ws_user_list(pattern_id, db)]
-    if user_id not in user_list:
-        return JSONResponse(status_code=404, content={
-            'code': 404,
-            'msg': '抱歉,您无权限,请联系系统管理员'
-        })
+async def websocket_endpoint(pattern_id:str ,websocket: WebSocket,db: Session = Depends(get_db)):
+
     await manager.connect(websocket,pattern_id,db)
+
+        # return JSONResponse(status_code=404, content={
+        #     'code': 404,
+        #     'msg': '抱歉,您无权限,请联系系统管理员'
+        # })
     try:
+        data = await websocket.receive_text()
+        user_id = valid_websocket_token(data)
+        user_list = [i.user_id for i in pattern_id_get_tp_pattern_ws_user_list(pattern_id, db)]
+        if user_id not in user_list:
+            # await websocket.send_text('抱歉,您无权限,请联系系统管理员')
+            manager.disconnect(websocket, pattern_id)
         # await manager.broadcast(pattern_id, db)  # 广播消息给所有连接
         # now_max_time = pattern_id_get_tp_pattern_ws_max_time(pattern_id, db).update_time
         while True:
@@ -473,6 +481,11 @@ async def rollback_pattern(
 ):
     try:
         pattern_info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db)
+        if pattern_info is None:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,pattern_id不存在,请联系系统管理员'
+                            })
         if pattern_info.create_by!=user_id:
             return JSONResponse(status_code=404, content={
                                 'code': 404,
@@ -495,6 +508,11 @@ async def rollback_pattern(
 ):
     try:
         pattern_info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db)
+        if pattern_info is None:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,pattern_id不存在,请联系系统管理员'
+                            })
         if pattern_info.create_by!=user_id:
             return JSONResponse(status_code=404, content={
                                 'code': 404,
@@ -526,6 +544,11 @@ async def add_group_pattern(
 ):
     try:
         pattern_info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db)
+        if pattern_info is None:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,pattern_id不存在,请联系系统管理员'
+                            })
         if pattern_info.create_by!=user_id:
             return JSONResponse(status_code=404, content={
                                 'code': 404,
@@ -553,6 +576,11 @@ async def add_group_pattern(
 ):
     try:
         pattern_info = pattern_id_get_tp_pattern_ws_group_def_info(body['pattern_id'],db)
+        if pattern_info is None:
+            return JSONResponse(status_code=404, content={
+                                'code': 404,
+                                'msg': '抱歉,pattern_id不存在,请联系系统管理员'
+                            })
         if pattern_info.create_by!=user_id:
             return JSONResponse(status_code=404, content={
                                 'code': 404,