79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
import datetime
import json
import os
from typing import Any, Union

import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt.exceptions import InvalidTokenError

import queries
|
|
|
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # default access-token lifetime, in minutes

# Algorithm used when signing and verifying access tokens. HS256 is a
# symmetric HMAC, so the same JWT_SECRET_KEY is used for both operations.
ALGORITHM = "HS256"

# HMAC secret for access tokens. Anyone who knows it can forge a token for
# any user, so real deployments must provide it via the JWT_SECRET_KEY
# environment variable. The literal below is only a fallback that keeps the
# previous behaviour for local development; an empty variable also falls back
# rather than signing with an empty key.
_INSECURE_DEV_JWT_SECRET_KEY = 'abcdefghijklmnopqrstuvwxyz'
JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY') or _INSECURE_DEV_JWT_SECRET_KEY
|
|
|
|
# Shared FastAPI dependency that extracts the bearer credentials from the
# request's Authorization header. auto_error=True is HTTPBearer's default and
# is spelled out here so the rejection behaviour is visible at the call site.
security = HTTPBearer(auto_error=True)
|
|
|
|
|
|
def create_access_token(
    subject: str | int,
    encryption_key: str,
    expires_delta: datetime.timedelta | None = None,
) -> str:
    """Create a signed JWT access token for a logged-in user.

    Args:
        subject: Identifier of the user, stored in the ``sub`` claim. It is
            converted to ``str`` because RFC 7519 defines ``sub`` as a string,
            and PyJWT 2.10+ rejects non-string subjects when decoding.
        encryption_key: Value stored in the custom ``key`` claim.
        expires_delta: Token lifetime as a ``timedelta``. When omitted, the
            token expires after ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token, signed with ``JWT_SECRET_KEY`` using ``ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = datetime.timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Keep the lifetime and the absolute expiry in separate names; the
    # original reused `expires_delta` for the resulting datetime.
    expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta

    # NOTE(review): HS256 tokens are signed, not encrypted. Anyone holding the
    # token can base64-decode the payload and read this key. Consider keeping
    # it server-side or using an encrypted token format instead.
    to_encode = {"exp": expire, "sub": str(subject), "key": encryption_key}
    return jwt.encode(to_encode, JWT_SECRET_KEY, ALGORITHM)
|
|
|
|
|
|
def get_current_user(token: HTTPAuthorizationCredentials = Depends(security)) -> dict:
    """Resolve the authenticated user from a bearer JWT.

    Used as a FastAPI dependency. ``security`` supplies the bearer
    credentials, the token is verified with ``JWT_SECRET_KEY``/``ALGORITHM``,
    and the user named by its ``sub`` claim is looked up with
    ``queries.GET_USER_BY_ID``.

    Args:
        token: Bearer credentials from the ``security`` dependency. The raw
            JWT string is ``token.credentials``.

    Returns:
        A dict with ``id`` (the ``sub`` claim), ``username`` (from the user
        record) and ``encryption_key`` (the token's ``key`` claim).

    Raises:
        HTTPException: 401 if the token fails verification (bad signature,
            expired, malformed), lacks the ``sub`` or ``key`` claim, or names
            a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={'WWW-Authenticate': 'Bearer'},
    )

    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except InvalidTokenError as err:
        raise credentials_exception from err

    user_id: str | None = payload.get("sub")
    # A correctly signed token missing these claims must still be rejected
    # with 401; previously a missing "key" surfaced as an unhandled KeyError.
    if user_id is None or "key" not in payload:
        raise credentials_exception

    user = queries.GET_USER_BY_ID({'user_id': user_id})
    if user is None:
        raise credentials_exception

    return {
        'id': user_id,
        'username': user['username'],
        'encryption_key': payload['key'],
    }
|
|
|
|
|
|
class Hasher:
    """Password hashing and verification backed by bcrypt.

    Passwords and hashes are passed in and returned as ``str``; they are
    UTF-8 encoded to ``bytes`` at the bcrypt boundary.

    NOTE(review): bcrypt only uses the first 72 bytes of its input, and newer
    bcrypt releases raise ValueError for longer passwords. Confirm how long
    passphrases should be handled.
    """

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Return whether *plain_password* matches the bcrypt *hashed_password*."""
        return bcrypt.checkpw(
            plain_password.encode('utf-8'),
            hashed_password.encode('utf-8'),
        )

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Return a bcrypt hash of *password* with a freshly generated salt."""
        fresh_salt = bcrypt.gensalt()
        digest = bcrypt.hashpw(password.encode('utf-8'), fresh_salt)
        return digest.decode('utf-8')