oauth.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
import asyncio
import os
from fastapi import APIRouter, Request, Depends, Form, Header, HTTPException
from database import get_db
from utils.StripTagsHTMLParser import *
from sqlalchemy.orm import Session
from datetime import datetime, timedelta, timezone
import jwt
from passlib.context import CryptContext
from models import *
from sqlalchemy import text, exists, and_, or_, not_
from sqlalchemy.sql import func
from models import *
from extensions import logger
from utils import *
import traceback
from exceptions import TokenException
  18. router = APIRouter()
  19. SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3ff"
  20. ALGORITHM = "HS256"
  21. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  22. @router.post("/create/token")
  23. async def login(
  24. request: Request,
  25. client_id: str = Form(..., description=''),
  26. client_secret: str = Form(..., description=''),
  27. db: Session = Depends(get_db)
  28. ):
  29. hash_passwd = pwd_context.hash(client_secret)
  30. return {
  31. "code": 1,
  32. "msg": "",
  33. "data": hash_passwd
  34. }
  35. @router.post('/token')
  36. async def login(
  37. request: Request,
  38. client_id: str = Form(..., description=''),
  39. client_secret: str = Form(..., description=''),
  40. grant_type: str = Form(..., description=''),
  41. scope: str = Form(..., description=''),
  42. db: Session = Depends(get_db)
  43. ):
  44. app = authenticate_app(db, client_id, client_secret)
  45. if not app:
  46. return {"code": 0, "msg": "client_id not exists", "data": {}}
  47. expires_in = 7200
  48. access_token_expires = timedelta(seconds=expires_in)
  49. access_token = create_access_token(
  50. data={"sub": client_id}, expires_delta=access_token_expires
  51. )
  52. return {
  53. "code": 1,
  54. "msg": "成功",
  55. "data": {
  56. "access_token": access_token,
  57. "expires_in": expires_in,
  58. "token_type": "Bearer",
  59. "scope": "all"
  60. }
  61. }
  62. def verify_secret(plain_secret, hashed_secret):
  63. return pwd_context.verify(plain_secret, hashed_secret)
  64. def get_app(db: Session, client_id: str):
  65. app = db.query(DangerAppInfo).filter(DangerAppInfo.client_id == client_id).first()
  66. return app
  67. def authenticate_app(db: Session, client_id: str, client_secret: str):
  68. app = get_app(db, client_id)
  69. if not app:
  70. return False
  71. if not verify_secret(client_secret, app.client_secret):
  72. return False
  73. return app
  74. def create_access_token(*, data: dict, expires_delta: timedelta = None):
  75. to_encode = data.copy()
  76. if expires_delta:
  77. expire = datetime.now() + expires_delta
  78. else:
  79. expire = datetime.now() + timedelta(seconds=7200)
  80. to_encode.update({"exp": expire})
  81. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  82. return encoded_jwt
  83. def valid_access_token(Authorization: str = Header(..., alias="Authorization"), db: Session = Depends(get_db)) -> str:
  84. try:
  85. access_token = Authorization.removeprefix("Bearer ")
  86. payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
  87. client_id: str = payload.get("sub")
  88. app = get_app(db, client_id)
  89. if not app:
  90. raise HTTPException(status_code=401, detail="access_token已失效")
  91. except Exception:
  92. # 处理异常
  93. traceback.print_exc()
  94. raise HTTPException(status_code=401, detail="access_token已失效")
  95. return client_id