diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 49e8c5f..20ae5f0 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -18,6 +18,8 @@ build: test: image: python:3.10-slim-bullseye stage: test + variables: + DATABASE: sqlite:///../app/db.sqlite before_script: - pip install -r requirements.txt - pip install pytest httpx diff --git a/app/main.py b/app/main.py index 1ac62e6..5afa0f6 100644 --- a/app/main.py +++ b/app/main.py @@ -17,16 +17,18 @@ from jose import jws, jwk, jwt from jose.constants import ALGORITHMS from starlette.middleware.cors import CORSMiddleware from starlette.responses import StreamingResponse, JSONResponse, HTMLResponse -import dataset +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker from Crypto.PublicKey import RSA from Crypto.PublicKey.RSA import RsaKey +from orm import Origin, Lease, init as db_init + logger = logging.getLogger() load_dotenv('../version.env') VERSION, COMMIT, DEBUG = getenv('VERSION', 'unknown'), getenv('COMMIT', 'unknown'), bool(getenv('DEBUG', False)) - def load_file(filename) -> bytes: with open(filename, 'rb') as file: content = file.read() @@ -45,7 +47,8 @@ __details = dict( version=VERSION, ) -app, db = FastAPI(**__details), dataset.connect(str(getenv('DATABASE', 'sqlite:///db.sqlite'))) +app, db = FastAPI(**__details), create_engine(url=str(getenv('DATABASE', 'sqlite:///db.sqlite'))) +db_init(db) TOKEN_EXPIRE_DELTA = relativedelta(hours=1) # days=1 LEASE_EXPIRE_DELTA = relativedelta(days=int(getenv('LEASE_EXPIRE_DAYS', 90))) @@ -94,13 +97,17 @@ async def status(request: Request): @app.get('/-/origins') async def _origins(request: Request): - response = list(map(lambda x: jsonable_encoder(x), db['origin'].all())) + session = sessionmaker(bind=db)() + response = list(map(lambda x: jsonable_encoder(x), session.query(Origin).all())) + session.close() return JSONResponse(response) @app.get('/-/leases') async def _leases(request: Request): - response = list(map(lambda x: jsonable_encoder(x), db['lease'].all())) + session = sessionmaker(bind=db)() + response = list(map(lambda x: jsonable_encoder(x), session.query(Lease).all())) + session.close() return JSONResponse(response) @@ -159,14 +166,14 @@ async def auth_v1_origin(request: Request): origin_ref = j['candidate_origin_ref'] logging.info(f'> [ origin ]: {origin_ref}: {j}') - data = dict( + data = Origin( origin_ref=origin_ref, hostname=j['environment']['hostname'], guest_driver_version=j['environment']['guest_driver_version'], os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version'], ) - db['origin'].upsert(data, ['origin_ref']) + Origin.create_or_update(db, data) response = { "origin_ref": origin_ref, @@ -190,14 +197,14 @@ async def auth_v1_origin_update(request: Request): origin_ref = j['origin_ref'] logging.info(f'> [ update ]: {origin_ref}: {j}') - data = dict( + data = Origin( origin_ref=origin_ref, hostname=j['environment']['hostname'], guest_driver_version=j['environment']['guest_driver_version'], os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version'], ) - db['origin'].upsert(data, ['origin_ref']) + Origin.create_or_update(db, data) response = { "environment": j['environment'], @@ -306,8 +313,8 @@ async def leasing_v1_lessor(request: Request): } }) - data = dict(origin_ref=origin_ref, lease_ref=scope_ref, lease_created=cur_time, lease_expires=expires) - db['lease'].insert_ignore(data, ['origin_ref', 'lease_ref']) # todo: handle update + data = Lease(origin_ref=origin_ref, lease_ref=scope_ref, lease_created=cur_time, lease_expires=expires) + Lease.create_or_update(db, data) response = { "lease_result_list": lease_result_list, @@ -327,7 +334,7 @@ async def leasing_v1_lessor_lease(request: Request): origin_ref = token['origin_ref'] - active_lease_list = list(map(lambda x: x['lease_ref'], db['lease'].find(origin_ref=origin_ref))) + active_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref))) logging.info(f'> [ leases ]: {origin_ref}: found {len(active_lease_list)} active leases') response = { @@ -347,14 +354,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str): origin_ref = token['origin_ref'] logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}') - if db['lease'].count(origin_ref=origin_ref, lease_ref=lease_ref) == 0: + entity = Lease.find_by_origin_ref_and_lease_ref(db, origin_ref, lease_ref) + if entity is None: raise HTTPException(status_code=404, detail='requested lease not available') expires = cur_time + LEASE_EXPIRE_DELTA - - data = dict(origin_ref=origin_ref, lease_ref=lease_ref, lease_expires=expires, lease_last_update=cur_time) - db['lease'].update(data, ['origin_ref', 'lease_ref']) - response = { "lease_ref": lease_ref, "expires": expires.isoformat(), @@ -364,6 +368,8 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str): "sync_timestamp": cur_time.isoformat(), } + Lease.renew(db, entity, expires, cur_time) + return JSONResponse(response) @@ -373,8 +379,8 @@ async def leasing_v1_lessor_lease_remove(request: Request): origin_ref = token['origin_ref'] - released_lease_list = list(map(lambda x: x['lease_ref'], db['lease'].find(origin_ref=origin_ref))) - deletions = db['lease'].delete(origin_ref=origin_ref) + released_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref))) + deletions = Lease.cleanup(db, origin_ref) logging.info(f'> [ remove ]: {origin_ref}: removed {deletions} leases') response = { diff --git a/app/orm.py b/app/orm.py new file mode 100644 index 0000000..26b869a --- /dev/null +++ b/app/orm.py @@ -0,0 +1,116 @@ +import datetime + +from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, UniqueConstraint, update, and_, delete, inspect +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.future import Engine +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() + + +class Origin(Base): + __tablename__ = "origin" + + origin_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4 + + hostname = Column(VARCHAR(length=256), nullable=True) + guest_driver_version = Column(VARCHAR(length=10), nullable=True) + os_platform = Column(VARCHAR(length=256), nullable=True) + os_version = Column(VARCHAR(length=256), nullable=True) + + def __repr__(self): + return f'Origin(origin_ref={self.origin_ref}, hostname={self.hostname})' + + @staticmethod + def create_statement(engine: Engine): + from sqlalchemy.schema import CreateTable + return CreateTable(Origin.__table__).compile(engine) + + @staticmethod + def create_or_update(engine: Engine, origin: "Origin"): + session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() + entity = session.query(Origin).filter(Origin.origin_ref == origin.origin_ref).first() + print(entity) + if entity is None: + session.add(origin) + else: + values = dict( + hostname=origin.hostname, + guest_driver_version=origin.guest_driver_version, + os_platform=origin.os_platform, + os_version=origin.os_version, + ) + session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(**values)) + session.flush() + session.close() + + +class Lease(Base): + __tablename__ = "lease" + + origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref), primary_key=True, nullable=False, index=True) # uuid4 + lease_ref = Column(CHAR(length=36), primary_key=True, nullable=False, index=True) # uuid4 + + lease_created = Column(DATETIME(), nullable=False) + lease_expires = Column(DATETIME(), nullable=False) + lease_updated = Column(DATETIME(), nullable=False) + + def __repr__(self): + return f'Lease(origin_ref={self.origin_ref}, lease_ref={self.lease_ref}, expires={self.lease_expires})' + + @staticmethod + def create_statement(engine: Engine): + from sqlalchemy.schema import CreateTable + return CreateTable(Lease.__table__).compile(engine) + + @staticmethod + def create_or_update(engine: Engine, lease: "Lease"): + session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() + entity = session.query(Lease).filter(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).first() + if entity is None: + if lease.lease_updated is None: + lease.lease_updated = lease.lease_created + session.add(lease) + else: + values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) + session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**values)) + session.flush() + session.close() + + @staticmethod + def find_by_origin_ref(engine: Engine, origin_ref: str) -> ["Lease"]: + session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() + entities = session.query(Lease).filter(Lease.origin_ref == origin_ref).all() + session.close() + return entities + + @staticmethod + def find_by_origin_ref_and_lease_ref(engine: Engine, origin_ref: str, lease_ref: str) -> "Lease": + session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() + entity = session.query(Lease).filter(and_(Lease.origin_ref == origin_ref, Lease.lease_ref == lease_ref)).first() + session.close() + return entity + + @staticmethod + def renew(engine: Engine, lease: "Lease", lease_expires: datetime.datetime, lease_updated: datetime.datetime): + session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() + values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) + session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**values)) + session.close() + + @staticmethod + def cleanup(engine: Engine, origin_ref: str) -> int: + session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() + deletions = session.query(Lease).filter(Lease.origin_ref == origin_ref).delete() + session.close() + return deletions + + +def init(engine: Engine): + tables = [Origin, Lease] + db = inspect(engine) + session = sessionmaker(bind=engine)() + for table in tables: + if not db.dialect.has_table(engine.connect(), table.__tablename__): + session.execute(str(table.create_statement(engine))) + session.close() diff --git a/requirements.txt b/requirements.txt index 413b6d1..cceb6a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,6 @@ uvicorn[standard]==0.20.0 python-jose==3.3.0 pycryptodome==3.16.0 python-dateutil==2.8.2 -dataset==1.5.2 +sqlalchemy==1.4.45 markdown==3.4.1 python-dotenv==0.21.0 diff --git a/test/main.py b/test/main.py index 0711de5..22c1c6b 100644 --- a/test/main.py +++ b/test/main.py @@ -2,15 +2,13 @@ from uuid import uuid4 from jose import jwt from starlette.testclient import TestClient -import importlib.util import sys -MODULE, PATH = 'main.app', '../app/main.py' +# add relative path to use packages as they were in the app/ dir +sys.path.append('../') +sys.path.append('../app') -spec = importlib.util.spec_from_file_location(MODULE, PATH) -main = importlib.util.module_from_spec(spec) -sys.modules[MODULE] = main -spec.loader.exec_module(main) +from app import main client = TestClient(main.app)