diff --git a/main.py b/main.py index 9ddce9d..962c4e4 100644 --- a/main.py +++ b/main.py @@ -1,20 +1,17 @@ import json +from sqlite3 import IntegrityError from fastapi import Depends, FastAPI, HTTPException, status from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware from fastapi.security import OAuth2PasswordBearer +import queries from auth import Hasher, create_access_token, get_current_user -from crypto import ( - deserialize_into_bytes, - fernet_decrypt, - fernet_encrypt, - generate_random_encryption_key, - generate_user_passkey, - serialize_bytes, -) -from models import Secret, User, UserLogin +from crypto import (deserialize_into_bytes, fernet_decrypt, fernet_encrypt, + generate_random_encryption_key, generate_user_passkey, + serialize_bytes) +from models import Key, Secret, User, UserLogin app = FastAPI() @@ -42,42 +39,33 @@ async def root(): async def register(user: User): """Registers a user""" - users = [] - with open('database/users.json', 'r') as f: - text = f.read() - if text: - users = json.loads(text) + plain_text_password = user.password + user.password = Hasher.get_password_hash(plain_text_password) - if user.id is not None: - raise HTTPException( - status_code=400, - detail="User id shall be auto generated, cannot be provided in request" - ) - if not users: - user.id = 0 - else: - max_user_id = max([i['id'] for i in users]) - user.id = max_user_id + 1 - - user_exists = [i for i in users if i['username'] == user.username] - if user_exists: + conn = queries.connect_db() + try: + cur = conn.execute(queries.CREATE_USER_QUERY, jsonable_encoder(user)) + user_id = cur.fetchall()[0][0] + conn.commit() + except IntegrityError as e: raise HTTPException(status_code=400, detail="Username already in use") encryption_key = generate_random_encryption_key() - salt, master_key = generate_user_passkey(user.password) - encrypted_encryption_key = fernet_encrypt(encryption_key, master_key) + salt, master_key = generate_user_passkey(plain_text_password) - user.password = Hasher.get_password_hash(user.password) - user.encryption_key = encrypted_encryption_key.decode('utf-8') - user.salt = serialize_bytes(salt) + key = Key( + user_id=user_id, + encryption_key=fernet_encrypt(encryption_key, master_key).decode('utf-8'), + encryption_key_salt = serialize_bytes(salt) + ) - users.append(jsonable_encoder(user)) - # print(f"{salt=}\n{user.salt=}\n{encrypted_encryption_key=}\n{user.encryption_key=}\n{master_key=}") - with open('database/users.json', 'w') as f: - json.dump(users, f) + try: + conn.execute(queries.CREATE_KEY_QUERY, jsonable_encoder(key)) + except Exception as e: + print('failed to create key', e) - return {'user_id': user.id} + return {'user_id': user_id} @app.post('/login') diff --git a/queries.py b/queries.py new file mode 100644 index 0000000..d391cce --- /dev/null +++ b/queries.py @@ -0,0 +1,25 @@ +import sqlite3 + + +def connect_db(): + conn = sqlite3.connect('database/database.sqlite') + conn.execute("PRAGMA foreign_keys = 1") + return conn + + +CREATE_USER_QUERY = """ + INSERT INTO + users ( + name, email, username, password + ) + VALUES ( + :name, :email, :username, :password + ) + returning id +""" + +CREATE_KEY_QUERY = """ + INSERT INTO + keys (user_id, encryption_key, encryption_key_salt) + VALUES (:user_id, :encryption_key, :encryption_key_salt) +""" \ No newline at end of file