import hashlib
import math
import time
import typing
import jwt
import pydantic
from fastapi.exceptions import HTTPException
from fastapi.requests import Request
from fastapi.security import OAuth2PasswordBearer
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from fastapi_token.encrypt import gen_key, gen_nonce_from_timestamp, encrypt
from fastapi_token.schemas import EncryptAuth, GrantToken, Auth, AccessField
[docs]class TimeExpireError(HTTPException):
""" 当前的token过期"""
def __init__(self, msg):
super(TimeExpireError, self).__init__(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Not authenticated, auth fail timestamp not allowed"
f", Error msg : {msg}",
)
[docs]class VerifyError(HTTPException):
""" 验证不通过 """
def __init__(self, msg):
super(VerifyError, self).__init__(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Not authenticated, auth fail signature not correct"
f", Error msg : {msg}",
)
[docs]class TokenExpireError(HTTPException):
""" user_token过期 """
def __init__(self, msg):
super(TokenExpireError, self).__init__(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Not authenticated, auth fail token expire"
f", Error msg : {msg}",
)
[docs]class TokenBase:
"""
token 生成基类
token生成中使用的变量:
- user_id 用户id
- user_token 用户获得的认证token, 用于生成最终的在请求中使用的token, 为字符串
token生成和认证过程:
1. 利用 user_id 以及其他信息生成 user_token 使用函数 :func:`gen_user_token`
2. 客户端使用 :func:`gen_auth_token` 中的编码方式生成 :class:`fastapi_token_gen.schemas.Auth` 形式的数据
3. 客户端使用 jwt以及约定的参数对上述生成的数据进行编码, 并组成 OAuth2 Bearer Token 形式发送给服务端
4. 服务端获取 jwt编码的token后, 利用函数 :func:`auth` 对token进行认证
"""
[docs] def gen_user_token(self, user_id: str, **config) -> str:
"""
生成用户的token, 用于生成最终认证token
:param user_id 用户ID用于生成认证token
:param config
:return:
"""
raise NotImplementedError
[docs] def gen_auth_token(self, user_id: str, user_token: str, **config) -> typing.Tuple[Auth, str]:
"""
根据 user_token 生成最终的认证access_token
:return:
"""
raise NotImplementedError
[docs] def auth(self, authorization: str) -> Auth:
"""
认证, 利用access_token 进行认证
:return:
"""
raise NotImplementedError
[docs]class OAuth2(OAuth2PasswordBearer):
def __init__(self, token_instance: TokenBase, **args):
super().__init__(**args)
self.token_instance = token_instance
async def __call__(self, request: Request) -> Auth:
authorization = await super().__call__(request)
if not authorization:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
return self.token_instance.auth(authorization)
[docs]class EncryptToken(TokenBase):
"""
在HTTP非加密环境下实现认证过程, 并使得认证token的生成不依赖服务端分配而是一次性分配一个密钥,在不暴露此密钥的情况下进行认证. 此过程中
服务端也是无状态的,也就是不需要存储分配给客户的密钥.
利用对称加密方式生成,利用JWS自带签名方式验证,支持增加 ``user_token`` 的过期时间和权限管理
``user_token`` 分发和认证过程:
1. 利用 :class:`fastapi_token.schemas.AccessField` 中的信息生成 `key`, `nonce`,使用用chacha20ietf对内置文明进行加密
获得密文作为客户端JWT加密密钥, 利用JWT生成包含上述生成信息和密文的token作为 ``user_token``
2. 客户端解码后得到作为加密密钥的密文和生成信息, 使用加密密钥使用JWT编码 :class:`fastapi_token.schemas.EncryptAuth`,
发送给服务端
3. 服务端解码token后获得生成密钥的信息, 并重新生成密和初始向量并加密内置明文获取客户端JWT加密的密钥, 并利用此密钥验证客户端发送的token
的签名, 从而验证客户端的 ``user_token``
由于在上述过程中,客户端或者中间攻击者若修改发送的 :class:`fastapi_token.schemas.AccessField` 中的字段会导致最终服务端还原的密钥
发生改变从而阻止对于 ``user_token`` 的修改, 重放攻击可以通过验证客户端发送的token中的时间戳部分防止.
"""
def __init__(
self,
secret_key: str,
algorithm_jwt: str,
salt_jwt: str,
salt_grand: str,
access_token_expire_second: int,
):
"""
:param secret_key: 总密钥,用于内部各种密钥的生成
:param algorithm_jwt: jwt编码使用的算法
:param salt_jwt: jwt编码使用的密钥的加盐内容
:param salt_grand: user_token 生成的加盐内容
:param access_token_expire_second: 客户端认证内容的过期时间
"""
self.secret_key = secret_key
self.secret_key_grand = hashlib.md5((self.secret_key + salt_grand).encode("utf-8")).hexdigest()
self.secret_key_jwt = hashlib.md5((self.secret_key + salt_jwt).encode("utf-8")).hexdigest()
self.algorithm_jwt = algorithm_jwt
self.access_token_expire_second = access_token_expire_second
self.secret_str = "衬衫的价格是九磅十五便士".encode("utf-8")
[docs] @staticmethod
def gen_key(salt: str = "", secret_key="") -> bytes:
"""
生成用于对称加密的密钥,从 secret_key 生成
:return:
"""
return gen_key(
(secret_key + salt if salt is not None else "").encode("utf-8")
)
[docs] def auth(self, authorization: str) -> EncryptAuth:
try:
payload = EncryptAuth(**jwt.decode(authorization, options={'verify_signature': False}))
except pydantic.ValidationError as e:
raise VerifyError(f"JWT token missing filed, mes: {e.errors()}")
except jwt.DecodeError:
raise VerifyError(f"This string is not a valid JWT token")
access_field = AccessField(**payload.dict())
key = self.gen_key(secret_key=self.secret_key_grand, salt=access_field.gen_salt())
nonce = gen_nonce_from_timestamp(access_field.token_expire)
encrypt_key = encrypt(self.secret_str, key=key, nonce=nonce).hex()
try:
payload = EncryptAuth(
**jwt.decode(authorization, key=encrypt_key, algorithms=[self.algorithm_jwt]))
except jwt.InvalidSignatureError:
raise VerifyError(f"This token is invalid, use a valid token")
current_timestamp = time.time()
if math.fabs(
payload.timestamp - current_timestamp + self.access_token_expire_second / 2
) > self.access_token_expire_second:
raise TimeExpireError(
f"current time is: {current_timestamp}, token time is : {payload.timestamp}, "
f"access token expire second is : {self.access_token_expire_second}")
if payload.token_expire < current_timestamp:
raise TokenExpireError(f"user token is expired. current time is : {current_timestamp}, "
f"user token expired time is : {payload.token_expire}")
return payload
[docs] def check_user_token(self, user_token: str):
try:
grant_token = GrantToken(
**jwt.decode(user_token, key=self.secret_key_jwt, algorithms=[self.algorithm_jwt])
)
return grant_token
except jwt.InvalidSignatureError:
raise VerifyError(f"User token verify fail, this token may not the key in this system,"
f"Info in this token is : {jwt.decode(user_token, options={'verify_signature': False})}")
except jwt.DecodeError:
raise VerifyError(f"This string is not a valid JWT token")
[docs] def gen_user_token(self, user_id: str, access_field: typing.Optional[AccessField] = None, **config) -> str:
"""
生成用户的token, 用于生成最终认证token
:param user_id :用户ID
:param access_field : 生成的token的权限,不指定则生成最大权限的token
:return: jwt 格式的 user_token
"""
if not access_field:
access_field = AccessField(
token_expire=config.get(
"expire_timestamp",
(int(time.time()) + self.access_token_expire_second)
),
allow_method=["*"]
)
key = self.gen_key(secret_key=self.secret_key_grand, salt=access_field.gen_salt())
nonce = gen_nonce_from_timestamp(access_field.token_expire)
grand_token = GrantToken(
jwt_algorithm=self.algorithm_jwt,
user_id=user_id,
verify_token=self.gen_key(secret_key=self.secret_key_grand, salt=key.hex()).hex(),
encrypt_key=encrypt(self.secret_str, key=key, nonce=nonce).hex(),
**access_field.dict(),
)
return jwt.encode(grand_token.dict(), self.secret_key_jwt, self.algorithm_jwt)
[docs] @staticmethod
def gen_auth_token(user_id: str, user_token: str, **config) -> typing.Tuple[EncryptAuth, str]:
"""
这里 user_token 为生成认证的jwt代码
根据 user_token 生成最终的认证access_token
:param user_id
:param user_token
:param config
:return: 认证内容以及jwt加密后内容
"""
grand_token = GrantToken(**jwt.decode(user_token, options={"verify_signature": False}))
access_field = AccessField(**grand_token.dict())
timestamp = config.get("timestamp", int(time.time()))
encrypt_auth = EncryptAuth(user_id=user_id, timestamp=timestamp, **access_field.dict())
return encrypt_auth, jwt.encode(
encrypt_auth.dict(),
key=grand_token.encrypt_key,
algorithm=grand_token.jwt_algorithm,
)
[docs]class EncryptPlainToken(EncryptToken):
[docs] def auth(self, authorization: str) -> GrantToken:
try:
grant_token: GrantToken = self.check_user_token(authorization)
except pydantic.ValidationError as e:
raise VerifyError(f"JWT token missing filed, mes: {e.errors()}")
current_timestamp = time.time()
if grant_token.token_expire < current_timestamp:
raise TokenExpireError(f"user token is expired. current time is : {current_timestamp}, "
f"user token expired time is : {grant_token.token_expire}")
return grant_token