# sso.py
  1. # -*- coding: utf-8 -*-
  2. from fastapi import APIRouter, Depends, Request, Header, Form, Body
  3. from fastapi.responses import FileResponse, StreamingResponse
  4. from sqlalchemy.orm import Session
  5. from fastapi.responses import JSONResponse
  6. from database import get_db
  7. from utils import *
  8. from utils.vcode import *
  9. from utils.redis_util import *
  10. import base64
  11. from common.const import *
  12. from io import BytesIO
  13. from utils.StripTagsHTMLParser import *
  14. from common import security
  15. from datetime import timedelta
  16. from common.security import valid_access_token
  17. from common.auth_user import *
  18. from common import YzyApi
  19. from models import *
  20. from urllib.parse import quote
  21. import requests
  22. import jwt
  23. import traceback
  24. from common.enc import mpfun
  25. router = APIRouter()
  26. # 提供给数科使用的单点登录token
  27. @router.get('/token/create')
  28. def sso_token(request: Request,
  29. user_id: int = Depends(valid_access_token),
  30. db: Session = Depends(get_db)):
  31. sso_token_expires = timedelta(seconds = 3600 * 24)
  32. sso_token = security.create_access_token(
  33. data={"sub": user_id}, expires_delta = sso_token_expires
  34. )
  35. print('sso_token:', sso_token)
  36. return {
  37. "code": 200,
  38. "msg": "操作成功",
  39. "data": {
  40. "sso_token": sso_token
  41. }
  42. }
  43. # token校验
  44. @router.get('/token/valid')
  45. def sso_token(request: Request,
  46. sso_token: str,
  47. db: Session = Depends(get_db)):
  48. try:
  49. payload = jwt.decode(sso_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
  50. # print(payload,payload.get("sub"))
  51. user_id: str = payload.get("sub")
  52. row = db.query(SysUser).filter(SysUser.user_id == int(user_id)).first()
  53. if row is None:
  54. return {
  55. "code": 500,
  56. "msg": "token异常"
  57. }
  58. # 角色信息
  59. roles = []
  60. role_ids = db.query(SysUserRole).filter(SysUserRole.user_id == int(user_id)).all()
  61. for role in role_ids:
  62. role_info = db.query(SysRole).filter(SysRole.role_id == role.role_id).first()
  63. roles.append(
  64. {
  65. "roleId": role_info.role_id,
  66. "roleName": role_info.role_name,
  67. "roleKey": role_info.role_key
  68. }
  69. )
  70. role_keys = [
  71. n['roleKey']
  72. for n in roles
  73. ]
  74. data = {
  75. "userId": row.user_id,
  76. "userName": mpfun.dec_data(row.user_name),
  77. "nickName": row.nick_name,
  78. "roles": role_keys
  79. }
  80. return {
  81. "code": 200,
  82. "msg": "操作成功",
  83. "data": data
  84. }
  85. except Exception:
  86. traceback.print_exc()
  87. return {
  88. "code": 500,
  89. "msg": "token异常"
  90. }