# FastAuth/auth.py
# 2024-06-10 20:40:22 +05:30
#
# 82 lines
# 2.5 KiB
# Python

import datetime
import json
import os
from typing import Any, Union

import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials
from fastapi.security import HTTPBearer
from jwt.exceptions import InvalidTokenError
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # default access-token lifetime, in minutes
ALGORITHM = "HS256"  # used for both signing and verifying tokens
# Secret used to sign and verify every token. Take it from the environment so
# real deployments never run with a value committed to source control. The
# literal fallback only keeps existing local/dev setups working and must NOT
# be relied on in production. An empty env value also falls back.
JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY") or 'abcdefghijklmnopqrstuvwxyz'
# FastAPI bearer-auth dependency; get_current_user reads `.credentials` from
# the object it yields.
security = HTTPBearer()
def create_access_token(
    subject: str | int,
    encryption_key: str,
    expires_delta: datetime.timedelta | None = None,
) -> str:
    """Create a signed JWT for a logged-in user.

    Args:
        subject: Value stored in the ``sub`` claim; ``get_current_user`` later
            compares it with the ``id`` field of stored user records.
        encryption_key: Value stored in the custom ``key`` claim; it is handed
            back unchanged by ``get_current_user``.
        expires_delta: How long the token stays valid. When omitted, the
            lifetime is ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. It is added
            to a ``datetime``, so it must be a ``timedelta``.

    Returns:
        The encoded token, signed with ``JWT_SECRET_KEY`` using ``ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = datetime.timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware UTC timestamp for the ``exp`` claim.
    expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta
    # NOTE(review): RFC 7519 defines ``sub`` as a string, and recent PyJWT
    # releases reject non-string ``sub`` on decode. Confirm that int subjects
    # still round-trip with the pinned PyJWT version.
    to_encode = {"exp": expire, "sub": subject, "key": encryption_key}
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
def get_current_user(
    token: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """Resolve the authenticated user from a bearer JWT.

    The token is decoded and verified with ``JWT_SECRET_KEY`` /
    ``ALGORITHM``. Its ``sub`` claim is then looked up among the records in
    ``database/users.json``.

    Args:
        token: Bearer credentials from the ``security`` dependency; the raw
            JWT is read from its ``credentials`` attribute.

    Returns:
        A dict with three keys:
        ``id`` is the token's ``sub`` claim, ``username`` comes from the
        matching user record, and ``encryption_key`` is the token's ``key``
        claim.

    Raises:
        HTTPException: 401 in any of these cases:
            - the token fails verification;
            - the token lacks a ``sub`` or ``key`` claim;
            - the user store is empty;
            - no stored user has that id.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={'WWW-Authenticate': 'Bearer'},
    )
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except InvalidTokenError as err:
        raise credentials_exception from err

    user_id = payload.get("sub")
    # A correctly signed token that lacks either claim is still unusable.
    # Reject it with 401 instead of letting payload['key'] raise KeyError,
    # which would surface as a 500.
    if user_id is None or "key" not in payload:
        raise credentials_exception

    with open('database/users.json', 'r', encoding='utf-8') as f:
        text = f.read()
    # An empty (or whitespace-only) store contains no users to match.
    if not text.strip():
        raise credentials_exception
    data = json.loads(text)

    # Stop at the first record whose id matches the token subject.
    user = next((record for record in data if record.get('id') == user_id), None)
    if user is None:
        raise credentials_exception

    return {
        'id': user_id,
        'username': user['username'],
        'encryption_key': payload['key'],
    }
class Hasher:
    """Hash and verify passwords with bcrypt."""

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Return True if *plain_password* matches the stored bcrypt hash.

        Both strings are UTF-8 encoded before comparison. A stored hash that
        bcrypt rejects as malformed is treated as a non-match and returns
        False instead of raising.
        """
        try:
            return bcrypt.checkpw(
                plain_password.encode('utf-8'),
                hashed_password.encode('utf-8'),
            )
        except ValueError:
            # bcrypt raises ValueError (e.g. "Invalid salt") for malformed
            # hashes. A corrupt stored record should fail the login, not
            # crash the request with a 500.
            return False

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Return a freshly salted bcrypt hash of *password* as a UTF-8 string."""
        # `hashed` rather than `hash`, to avoid shadowing the builtin.
        hashed = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
        return hashed.decode('utf-8')