diff --git a/auth.py b/auth.py index 590f4ee..7b13f3c 100644 --- a/auth.py +++ b/auth.py @@ -9,6 +9,8 @@ from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer from jwt.exceptions import InvalidTokenError +import queries + ACCESS_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes ALGORITHM = "HS256" JWT_SECRET_KEY = 'abcdefghijklmnopqrstuvwxyz' @@ -47,19 +49,14 @@ def get_current_user(token: str = Depends(security)) -> dict: except InvalidTokenError: raise credentials_exception - with open('database/users.json', 'r') as f: - text = f.read() - if text: - data = json.loads(text) - else: - raise credentials_exception - user = [i for i in data if i['id']==user_id] - if not user: + user = queries.GET_USER_BY_ID({'user_id': user_id}) + + if user is None: raise credentials_exception cur_user = {'id': user_id} - cur_user['username'] = user[0]['username'] + cur_user['username'] = user['username'] cur_user['encryption_key'] = payload['key'] return cur_user diff --git a/main.py b/main.py index 962c4e4..c33aa89 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,10 @@ import json from sqlite3 import IntegrityError +from typing import Literal from fastapi import Depends, FastAPI, HTTPException, status from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware -from fastapi.security import OAuth2PasswordBearer import queries from auth import Hasher, create_access_token, get_current_user @@ -42,26 +42,23 @@ async def register(user: User): plain_text_password = user.password user.password = Hasher.get_password_hash(plain_text_password) - conn = queries.connect_db() try: - cur = conn.execute(queries.CREATE_USER_QUERY, jsonable_encoder(user)) - user_id = cur.fetchall()[0][0] - conn.commit() + user = queries.CREATE_USER(user.model_dump()) + user_id = user['id'] except IntegrityError as e: raise HTTPException(status_code=400, detail="Username already in use") encryption_key = generate_random_encryption_key() - salt, master_key = generate_user_passkey(plain_text_password) key = Key( user_id=user_id, encryption_key=fernet_encrypt(encryption_key, master_key).decode('utf-8'), - encryption_key_salt = serialize_bytes(salt) + encryption_key_salt=serialize_bytes(salt) ) try: - conn.execute(queries.CREATE_KEY_QUERY, jsonable_encoder(key)) + queries.CREATE_KEY(key.model_dump()) except Exception as e: print('failed to create key', e) @@ -72,30 +69,21 @@ async def register(user: User): async def login(user: UserLogin): """logs in the user""" - users = [] - with open('database/users.json', 'r') as f: - text = f.read() - if text: - users.extend(json.loads(text)) + login_exception = HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="username or password is incorrect" + ) - cur_user = [i for i in users if i['username']==user.username] - if not cur_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="username or password is incorrect" - ) - else: - cur_user = cur_user[0] + cur_user = queries.GET_USER_WITH_KEY(user.model_dump()) + if cur_user is None: + raise login_exception password_match = Hasher.verify_password(user.password, cur_user['password']) if not password_match: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="username or password is incorrect" - ) + raise login_exception encrypted_encryption_key = cur_user['encryption_key'].encode() - salt = deserialize_into_bytes(cur_user['salt']) + salt = deserialize_into_bytes(cur_user['encryption_key_salt']) _, master_key = generate_user_passkey(user.password, salt) encryption_key = fernet_decrypt(encrypted_encryption_key, master_key) access_token = create_access_token(subject=cur_user['id'], encryption_key=encryption_key) @@ -114,27 +102,15 @@ async def create_secret(secret: Secret, current_user: dict = Depends(get_current Stores an encrypted secret for the user. """ - data = [] - with open('database/secrets.json', 'r') as f: - text = f.read() - if text: - data.extend(json.loads(text)) - - if data: - secret_id = max(i['id'] for i in data) + 1 - else: - secret_id = 0 - secret.id = secret_id secret.user_id = current_user['id'] encryption_key = current_user['encryption_key'].encode() encrypted_data = fernet_encrypt(secret.data.encode(), encryption_key) secret.data = encrypted_data.decode('utf-8') - data.append(jsonable_encoder(secret)) - with open('database/secrets.json', 'w') as f: - json.dump(data, f) + queries.CREATE_SECRET(secret.model_dump()) + return secret @@ -176,16 +152,9 @@ async def update_secret(secret: Secret, current_user: dict = Depends(get_current async def list_secret(current_user: dict = Depends(get_current_user)): """Returns the encrypted secrets of the user.""" - data = [] - with open('database/secrets.json', 'r') as f: - text = f.read() - if text: - data.extend(json.loads(text)) - - user_id = current_user['id'] + user_secrets = queries.GET_SECRETS({'user_id': current_user['id']}) encryption_key = current_user['encryption_key'].encode() - user_secrets = [i for i in data if i['user_id']==user_id and i['active']] for secret in user_secrets: cur_data = secret['data'] decrypted_data = fernet_decrypt(cur_data, encryption_key) @@ -197,8 +166,7 @@ async def list_secret(current_user: dict = Depends(get_current_user)): @app.get('/validate-token') async def validate_token(current_user: dict = Depends(get_current_user)): user_id = current_user['id'] - print("user_id: ", user_id) - if user_id is not None: - return {'message': 'authenticated'} + if user_id is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) \ No newline at end of file + return {'message': 'authenticated'}