From 922dc9f5a7a4f4a3204d865fb9e413b18d9ba6ad Mon Sep 17 00:00:00 2001 From: Oscar Krause Date: Thu, 29 Dec 2022 09:40:36 +0100 Subject: [PATCH] refactored database structure and created migration script --- app/main.py | 4 ++-- app/orm.py | 35 +++++++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/app/main.py b/app/main.py index 540aace..a50ed45 100644 --- a/app/main.py +++ b/app/main.py @@ -20,7 +20,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from util import load_key, load_file -from orm import Origin, Lease, init as db_init +from orm import Origin, Lease, init as db_init, migrate logger = logging.getLogger() load_dotenv('../version.env') @@ -29,7 +29,7 @@ VERSION, COMMIT, DEBUG = env('VERSION', 'unknown'), env('COMMIT', 'unknown'), bo app = FastAPI(title='FastAPI-DLS', description='Minimal Delegated License Service (DLS).', version=VERSION) db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite'))) -db_init(db) +db_init(db), migrate(db) DLS_URL = str(env('DLS_URL', 'localhost')) DLS_PORT = int(env('DLS_PORT', '443')) diff --git a/app/orm.py b/app/orm.py index c451b29..aa3ba47 100644 --- a/app/orm.py +++ b/app/orm.py @@ -1,6 +1,6 @@ import datetime -from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, UniqueConstraint, update, and_, delete, inspect +from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -43,13 +43,13 @@ class Origin(Base): if entity is None: session.add(origin) else: - values = dict( + x = dict( hostname=origin.hostname, guest_driver_version=origin.guest_driver_version, os_platform=origin.os_platform, - os_version=origin.os_version, + os_version=origin.os_version ) - session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(**values)) + session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(**x)) session.commit() session.flush() session.close() @@ -58,9 +58,9 @@ class Origin(Base): class Lease(Base): __tablename__ = "lease" - origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref), primary_key=True, nullable=False, index=True) # uuid4 lease_ref = Column(CHAR(length=36), primary_key=True, nullable=False, index=True) # uuid4 + origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref), nullable=False, index=True) # uuid4 lease_created = Column(DATETIME(), nullable=False) lease_expires = Column(DATETIME(), nullable=False) lease_updated = Column(DATETIME(), nullable=False) @@ -85,14 +85,14 @@ class Lease(Base): @staticmethod def create_or_update(engine: Engine, lease: "Lease"): session = sessionmaker(bind=engine)() - entity = session.query(Lease).filter(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).first() + entity = session.query(Lease).filter(Lease.lease_ref == lease.lease_ref).first() if entity is None: if lease.lease_updated is None: lease.lease_updated = lease.lease_created session.add(lease) else: - values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) - session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**values)) + x = dict(origin_ref=lease.origin_ref, lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) + session.execute(update(Lease).where(Lease.lease_ref == lease.lease_ref).values(**x)) session.commit() session.flush() session.close() @@ -114,8 +114,8 @@ class Lease(Base): @staticmethod def renew(engine: Engine, lease: "Lease", lease_expires: datetime.datetime, lease_updated: datetime.datetime): session = sessionmaker(bind=engine)() - values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) - session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**values)) + x = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) + session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**x)) session.commit() session.close() @@ -137,3 +137,18 @@ def init(engine: Engine): session.execute(str(table.create_statement(engine))) session.commit() session.close() + + +def migrate(engine: Engine): + db = inspect(engine) + + def upgrade_1_0_to_1_1(): + x = db.dialect.get_columns(engine.connect(), Lease.__tablename__) + x = next(_ for _ in x if _['name'] == 'origin_ref') + if x['primary_key'] > 0: + print('Found old database schema with "origin_ref" as primary-key in "lease" table. Dropping table!') + print(' Your leases are recreated on next renewal!') + print(' If an error message appears on the client, you can ignore it.') + Lease.__table__.drop(bind=engine) + + upgrade_1_0_to_1_1()