Merge branch 'sqlalchemy' into 'dev'

Sqlalchemy

See merge request oscar.krause/fastapi-dls!10
This commit is contained in:
Oscar Krause 2022-12-28 09:15:03 +01:00
commit 943786099b
5 changed files with 148 additions and 26 deletions

View File

@ -18,6 +18,8 @@ build:
test: test:
image: python:3.10-slim-bullseye image: python:3.10-slim-bullseye
stage: test stage: test
variables:
DATABASE: sqlite:///../app/db.sqlite
before_script: before_script:
- pip install -r requirements.txt - pip install -r requirements.txt
- pip install pytest httpx - pip install pytest httpx

View File

@ -17,16 +17,18 @@ from jose import jws, jwk, jwt
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse, HTMLResponse from starlette.responses import StreamingResponse, JSONResponse, HTMLResponse
import dataset from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey from Crypto.PublicKey.RSA import RsaKey
from orm import Origin, Lease, init as db_init
logger = logging.getLogger() logger = logging.getLogger()
load_dotenv('../version.env') load_dotenv('../version.env')
VERSION, COMMIT, DEBUG = getenv('VERSION', 'unknown'), getenv('COMMIT', 'unknown'), bool(getenv('DEBUG', False)) VERSION, COMMIT, DEBUG = getenv('VERSION', 'unknown'), getenv('COMMIT', 'unknown'), bool(getenv('DEBUG', False))
def load_file(filename) -> bytes: def load_file(filename) -> bytes:
with open(filename, 'rb') as file: with open(filename, 'rb') as file:
content = file.read() content = file.read()
@ -45,7 +47,8 @@ __details = dict(
version=VERSION, version=VERSION,
) )
app, db = FastAPI(**__details), dataset.connect(str(getenv('DATABASE', 'sqlite:///db.sqlite'))) app, db = FastAPI(**__details), create_engine(url=str(getenv('DATABASE', 'sqlite:///db.sqlite')))
db_init(db)
TOKEN_EXPIRE_DELTA = relativedelta(hours=1) # days=1 TOKEN_EXPIRE_DELTA = relativedelta(hours=1) # days=1
LEASE_EXPIRE_DELTA = relativedelta(days=int(getenv('LEASE_EXPIRE_DAYS', 90))) LEASE_EXPIRE_DELTA = relativedelta(days=int(getenv('LEASE_EXPIRE_DAYS', 90)))
@ -94,13 +97,17 @@ async def status(request: Request):
@app.get('/-/origins') @app.get('/-/origins')
async def _origins(request: Request): async def _origins(request: Request):
response = list(map(lambda x: jsonable_encoder(x), db['origin'].all())) session = sessionmaker(bind=db)()
response = list(map(lambda x: jsonable_encoder(x), session.query(Origin).all()))
session.close()
return JSONResponse(response) return JSONResponse(response)
@app.get('/-/leases') @app.get('/-/leases')
async def _leases(request: Request): async def _leases(request: Request):
response = list(map(lambda x: jsonable_encoder(x), db['lease'].all())) session = sessionmaker(bind=db)()
response = list(map(lambda x: jsonable_encoder(x), session.query(Lease).all()))
session.close()
return JSONResponse(response) return JSONResponse(response)
@ -159,14 +166,14 @@ async def auth_v1_origin(request: Request):
origin_ref = j['candidate_origin_ref'] origin_ref = j['candidate_origin_ref']
logging.info(f'> [ origin ]: {origin_ref}: {j}') logging.info(f'> [ origin ]: {origin_ref}: {j}')
data = dict( data = Origin(
origin_ref=origin_ref, origin_ref=origin_ref,
hostname=j['environment']['hostname'], hostname=j['environment']['hostname'],
guest_driver_version=j['environment']['guest_driver_version'], guest_driver_version=j['environment']['guest_driver_version'],
os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version'], os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version'],
) )
db['origin'].upsert(data, ['origin_ref']) Origin.create_or_update(db, data)
response = { response = {
"origin_ref": origin_ref, "origin_ref": origin_ref,
@ -190,14 +197,14 @@ async def auth_v1_origin_update(request: Request):
origin_ref = j['origin_ref'] origin_ref = j['origin_ref']
logging.info(f'> [ update ]: {origin_ref}: {j}') logging.info(f'> [ update ]: {origin_ref}: {j}')
data = dict( data = Origin(
origin_ref=origin_ref, origin_ref=origin_ref,
hostname=j['environment']['hostname'], hostname=j['environment']['hostname'],
guest_driver_version=j['environment']['guest_driver_version'], guest_driver_version=j['environment']['guest_driver_version'],
os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version'], os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version'],
) )
db['origin'].upsert(data, ['origin_ref']) Origin.create_or_update(db, data)
response = { response = {
"environment": j['environment'], "environment": j['environment'],
@ -306,8 +313,8 @@ async def leasing_v1_lessor(request: Request):
} }
}) })
data = dict(origin_ref=origin_ref, lease_ref=scope_ref, lease_created=cur_time, lease_expires=expires) data = Lease(origin_ref=origin_ref, lease_ref=scope_ref, lease_created=cur_time, lease_expires=expires)
db['lease'].insert_ignore(data, ['origin_ref', 'lease_ref']) # todo: handle update Lease.create_or_update(db, data)
response = { response = {
"lease_result_list": lease_result_list, "lease_result_list": lease_result_list,
@ -327,7 +334,7 @@ async def leasing_v1_lessor_lease(request: Request):
origin_ref = token['origin_ref'] origin_ref = token['origin_ref']
active_lease_list = list(map(lambda x: x['lease_ref'], db['lease'].find(origin_ref=origin_ref))) active_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref)))
logging.info(f'> [ leases ]: {origin_ref}: found {len(active_lease_list)} active leases') logging.info(f'> [ leases ]: {origin_ref}: found {len(active_lease_list)} active leases')
response = { response = {
@ -347,14 +354,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
origin_ref = token['origin_ref'] origin_ref = token['origin_ref']
logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}') logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}')
if db['lease'].count(origin_ref=origin_ref, lease_ref=lease_ref) == 0: entity = Lease.find_by_origin_ref_and_lease_ref(db, origin_ref, lease_ref)
if entity is None:
raise HTTPException(status_code=404, detail='requested lease not available') raise HTTPException(status_code=404, detail='requested lease not available')
expires = cur_time + LEASE_EXPIRE_DELTA expires = cur_time + LEASE_EXPIRE_DELTA
data = dict(origin_ref=origin_ref, lease_ref=lease_ref, lease_expires=expires, lease_last_update=cur_time)
db['lease'].update(data, ['origin_ref', 'lease_ref'])
response = { response = {
"lease_ref": lease_ref, "lease_ref": lease_ref,
"expires": expires.isoformat(), "expires": expires.isoformat(),
@ -364,6 +368,8 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
} }
Lease.renew(db, entity, expires, cur_time)
return JSONResponse(response) return JSONResponse(response)
@ -373,8 +379,8 @@ async def leasing_v1_lessor_lease_remove(request: Request):
origin_ref = token['origin_ref'] origin_ref = token['origin_ref']
released_lease_list = list(map(lambda x: x['lease_ref'], db['lease'].find(origin_ref=origin_ref))) released_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref)))
deletions = db['lease'].delete(origin_ref=origin_ref) deletions = Lease.cleanup(db, origin_ref)
logging.info(f'> [ remove ]: {origin_ref}: removed {deletions} leases') logging.info(f'> [ remove ]: {origin_ref}: removed {deletions} leases')
response = { response = {

116
app/orm.py Normal file
View File

@ -0,0 +1,116 @@
import datetime
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, UniqueConstraint, update, and_, delete, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import Engine
from sqlalchemy.orm import sessionmaker
Base = declarative_base()
class Origin(Base):
__tablename__ = "origin"
origin_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4
hostname = Column(VARCHAR(length=256), nullable=True)
guest_driver_version = Column(VARCHAR(length=10), nullable=True)
os_platform = Column(VARCHAR(length=256), nullable=True)
os_version = Column(VARCHAR(length=256), nullable=True)
def __repr__(self):
return f'Origin(origin_ref={self.origin_ref}, hostname={self.hostname})'
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Origin.__table__).compile(engine)
@staticmethod
def create_or_update(engine: Engine, origin: "Origin"):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
entity = session.query(Origin).filter(Origin.origin_ref == origin.origin_ref).first()
print(entity)
if entity is None:
session.add(origin)
else:
values = dict(
hostname=origin.hostname,
guest_driver_version=origin.guest_driver_version,
os_platform=origin.os_platform,
os_version=origin.os_version,
)
session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(**values))
session.flush()
session.close()
class Lease(Base):
__tablename__ = "lease"
origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref), primary_key=True, nullable=False, index=True) # uuid4
lease_ref = Column(CHAR(length=36), primary_key=True, nullable=False, index=True) # uuid4
lease_created = Column(DATETIME(), nullable=False)
lease_expires = Column(DATETIME(), nullable=False)
lease_updated = Column(DATETIME(), nullable=False)
def __repr__(self):
return f'Lease(origin_ref={self.origin_ref}, lease_ref={self.lease_ref}, expires={self.lease_expires})'
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Lease.__table__).compile(engine)
@staticmethod
def create_or_update(engine: Engine, lease: "Lease"):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
entity = session.query(Lease).filter(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).first()
if entity is None:
if lease.lease_updated is None:
lease.lease_updated = lease.lease_created
session.add(lease)
else:
values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated)
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**values))
session.flush()
session.close()
@staticmethod
def find_by_origin_ref(engine: Engine, origin_ref: str) -> ["Lease"]:
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
entities = session.query(Lease).filter(Lease.origin_ref == origin_ref).all()
session.close()
return entities
@staticmethod
def find_by_origin_ref_and_lease_ref(engine: Engine, origin_ref: str, lease_ref: str) -> "Lease":
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
entity = session.query(Lease).filter(and_(Lease.origin_ref == origin_ref, Lease.lease_ref == lease_ref)).first()
session.close()
return entity
@staticmethod
def renew(engine: Engine, lease: "Lease", lease_expires: datetime.datetime, lease_updated: datetime.datetime):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated)
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**values))
session.close()
@staticmethod
def cleanup(engine: Engine, origin_ref: str) -> int:
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
deletions = session.query(Lease).filter(Lease.origin_ref == origin_ref).delete()
session.close()
return deletions
def init(engine: Engine):
tables = [Origin, Lease]
db = inspect(engine)
session = sessionmaker(bind=engine)()
for table in tables:
if not db.dialect.has_table(engine.connect(), table.__tablename__):
session.execute(str(table.create_statement(engine)))
session.close()

View File

@ -3,6 +3,6 @@ uvicorn[standard]==0.20.0
python-jose==3.3.0 python-jose==3.3.0
pycryptodome==3.16.0 pycryptodome==3.16.0
python-dateutil==2.8.2 python-dateutil==2.8.2
dataset==1.5.2 sqlalchemy==1.4.45
markdown==3.4.1 markdown==3.4.1
python-dotenv==0.21.0 python-dotenv==0.21.0

View File

@ -2,15 +2,13 @@ from uuid import uuid4
from jose import jwt from jose import jwt
from starlette.testclient import TestClient from starlette.testclient import TestClient
import importlib.util
import sys import sys
MODULE, PATH = 'main.app', '../app/main.py' # add relative path to use packages as they were in the app/ dir
sys.path.append('../')
sys.path.append('../app')
spec = importlib.util.spec_from_file_location(MODULE, PATH) from app import main
main = importlib.util.module_from_spec(spec)
sys.modules[MODULE] = main
spec.loader.exec_module(main)
client = TestClient(main.app) client = TestClient(main.app)