migrated api to database-config (Site, Instance)

This commit is contained in:
Oscar Krause 2023-06-12 15:19:06 +02:00
parent 18807401e4
commit 39a2408d8d
3 changed files with 138 additions and 62 deletions

View File

@ -9,18 +9,17 @@ from dotenv import load_dotenv
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.requests import Request from fastapi.requests import Request
from json import loads as json_loads from json import loads as json_loads
from datetime import datetime, timedelta from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from calendar import timegm from calendar import timegm
from jose import jws, jwk, jwt, JWTError from jose import jws, jwt, JWTError
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from util import load_key, load_file from orm import init as db_init, migrate, Site, Instance, Origin, Lease
from orm import Origin, Lease, init as db_init, migrate
load_dotenv('../version.env') load_dotenv('../version.env')
@ -41,20 +40,9 @@ db_init(db), migrate(db)
# DLS setup (static) # DLS setup (static)
DLS_URL = str(env('DLS_URL', 'localhost')) DLS_URL = str(env('DLS_URL', 'localhost'))
DLS_PORT = int(env('DLS_PORT', '443')) DLS_PORT = int(env('DLS_PORT', '443'))
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
INSTANCE_KEY_RSA = load_key(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
INSTANCE_KEY_PUB = load_key(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0)))
LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}'] CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256) ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001')) # todo
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
# fastapi middleware # fastapi middleware
app.debug = DEBUG app.debug = DEBUG
@ -72,7 +60,19 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG if DEBUG else logging.INFO) logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)
def __get_token(request: Request) -> dict: def validate_settings():
session = sessionmaker(bind=db)()
lease_expire_delta_min, lease_expire_delta_max = 86_400, 7_776_000
for instance in session.query(Instance).all():
lease_expire_delta = instance.lease_expire_delta
if lease_expire_delta < 86_400 or lease_expire_delta > 7_776_000:
logging.warning(f'> [ instance ]: {instance.instance_ref}: "lease_expire_delta" should be between {lease_expire_delta_min} and {lease_expire_delta_max}')
session.close()
def __get_token(request: Request, jwt_decode_key: "jose.jwt") -> dict:
authorization_header = request.headers.get('authorization') authorization_header = request.headers.get('authorization')
token = authorization_header.split(' ')[1] token = authorization_header.split(' ')[1]
return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False}) return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
@ -95,18 +95,20 @@ async def _health():
@app.get('/-/config', summary='* Config', description='returns environment variables.') @app.get('/-/config', summary='* Config', description='returns environment variables.')
async def _config(): async def _config():
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
return JSONr({ return JSONr({
'VERSION': str(VERSION), 'VERSION': str(VERSION),
'COMMIT': str(COMMIT), 'COMMIT': str(COMMIT),
'DEBUG': str(DEBUG), 'DEBUG': str(DEBUG),
'DLS_URL': str(DLS_URL), 'DLS_URL': str(DLS_URL),
'DLS_PORT': str(DLS_PORT), 'DLS_PORT': str(DLS_PORT),
'SITE_KEY_XID': str(SITE_KEY_XID), 'SITE_KEY_XID': str(default_site.site_key),
'INSTANCE_REF': str(INSTANCE_REF), 'INSTANCE_REF': str(default_instance.instance_ref),
'ALLOTMENT_REF': [str(ALLOTMENT_REF)], 'ALLOTMENT_REF': [str(ALLOTMENT_REF)],
'TOKEN_EXPIRE_DELTA': str(TOKEN_EXPIRE_DELTA), 'TOKEN_EXPIRE_DELTA': str(default_instance.get_token_expire_delta()),
'LEASE_EXPIRE_DELTA': str(LEASE_EXPIRE_DELTA), 'LEASE_EXPIRE_DELTA': str(default_instance.get_lease_expire_delta()),
'LEASE_RENEWAL_PERIOD': str(LEASE_RENEWAL_PERIOD), 'LEASE_RENEWAL_PERIOD': str(default_instance.lease_renewal_period),
'CORS_ORIGINS': str(CORS_ORIGINS), 'CORS_ORIGINS': str(CORS_ORIGINS),
'TZ': str(TZ), 'TZ': str(TZ),
}) })
@ -166,8 +168,7 @@ async def _origins(request: Request, leases: bool = False):
for origin in session.query(Origin).all(): for origin in session.query(Origin).all():
x = origin.serialize() x = origin.serialize()
if leases: if leases:
serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA) x['leases'] = list(map(lambda _: _.serialize(), Lease.find_by_origin_ref(db, origin.origin_ref)))
x['leases'] = list(map(lambda _: _.serialize(**serialize), Lease.find_by_origin_ref(db, origin.origin_ref)))
response.append(x) response.append(x)
session.close() session.close()
return JSONr(response) return JSONr(response)
@ -184,8 +185,7 @@ async def _leases(request: Request, origin: bool = False):
session = sessionmaker(bind=db)() session = sessionmaker(bind=db)()
response = [] response = []
for lease in session.query(Lease).all(): for lease in session.query(Lease).all():
serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA) x = lease.serialize()
x = lease.serialize(**serialize)
if origin: if origin:
lease_origin = session.query(Origin).filter(Origin.origin_ref == lease.origin_ref).first() lease_origin = session.query(Origin).filter(Origin.origin_ref == lease.origin_ref).first()
if lease_origin is not None: if lease_origin is not None:
@ -206,7 +206,13 @@ async def _lease_delete(request: Request, lease_ref: str):
@app.get('/-/client-token', summary='* Client-Token', description='creates a new messenger token for this service instance') @app.get('/-/client-token', summary='* Client-Token', description='creates a new messenger token for this service instance')
async def _client_token(): async def _client_token():
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
exp_time = cur_time + CLIENT_TOKEN_EXPIRE_DELTA
default_instance = Instance.get_default_instance(db)
public_key = default_instance.get_public_key()
# todo: implemented request parameter to support different instances
jwt_encode_key = default_instance.get_jwt_encode_key()
exp_time = cur_time + default_instance.get_client_token_expire_delta()
payload = { payload = {
"jti": str(uuid4()), "jti": str(uuid4()),
@ -219,7 +225,7 @@ async def _client_token():
"scope_ref_list": [ALLOTMENT_REF], "scope_ref_list": [ALLOTMENT_REF],
"fulfillment_class_ref_list": [], "fulfillment_class_ref_list": [],
"service_instance_configuration": { "service_instance_configuration": {
"nls_service_instance_ref": INSTANCE_REF, "nls_service_instance_ref": default_instance.instance_ref,
"svc_port_set_list": [ "svc_port_set_list": [
{ {
"idx": 0, "idx": 0,
@ -231,10 +237,10 @@ async def _client_token():
}, },
"service_instance_public_key_configuration": { "service_instance_public_key_configuration": {
"service_instance_public_key_me": { "service_instance_public_key_me": {
"mod": hex(INSTANCE_KEY_PUB.public_key().n)[2:], "mod": hex(public_key.public_key().n)[2:],
"exp": int(INSTANCE_KEY_PUB.public_key().e), "exp": int(public_key.public_key().e),
}, },
"service_instance_public_key_pem": INSTANCE_KEY_PUB.export_key().decode('utf-8'), "service_instance_public_key_pem": public_key.export_key().decode('utf-8'),
"key_retention_mode": "LATEST_ONLY" "key_retention_mode": "LATEST_ONLY"
}, },
} }
@ -316,13 +322,16 @@ async def auth_v1_code(request: Request):
delta = relativedelta(minutes=15) delta = relativedelta(minutes=15)
expires = cur_time + delta expires = cur_time + delta
default_site = Site.get_default_site(db)
jwt_encode_key = Instance.get_default_instance(db).get_jwt_encode_key()
payload = { payload = {
'iat': timegm(cur_time.timetuple()), 'iat': timegm(cur_time.timetuple()),
'exp': timegm(expires.timetuple()), 'exp': timegm(expires.timetuple()),
'challenge': j.get('code_challenge'), 'challenge': j.get('code_challenge'),
'origin_ref': j.get('origin_ref'), 'origin_ref': j.get('origin_ref'),
'key_ref': SITE_KEY_XID, 'key_ref': default_site.site_key,
'kid': SITE_KEY_XID 'kid': default_site.site_key,
} }
auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256) auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
@ -342,8 +351,11 @@ async def auth_v1_code(request: Request):
async def auth_v1_token(request: Request): async def auth_v1_token(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow() j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
jwt_encode_key, jwt_decode_key = default_instance.get_jwt_encode_key(), default_instance.get_jwt_decode_key()
try: try:
payload = jwt.decode(token=j.get('auth_code'), key=jwt_decode_key) payload = jwt.decode(token=j.get('auth_code'), key=jwt_decode_key, algorithms=[ALGORITHMS.RS256])
except JWTError as e: except JWTError as e:
return JSONr(status_code=400, content={'status': 400, 'title': 'invalid token', 'detail': str(e)}) return JSONr(status_code=400, content={'status': 400, 'title': 'invalid token', 'detail': str(e)})
@ -355,7 +367,7 @@ async def auth_v1_token(request: Request):
if payload.get('challenge') != challenge: if payload.get('challenge') != challenge:
return JSONr(status_code=401, content={'status': 401, 'detail': 'expected challenge did not match verifier'}) return JSONr(status_code=401, content={'status': 401, 'detail': 'expected challenge did not match verifier'})
access_expires_on = cur_time + TOKEN_EXPIRE_DELTA access_expires_on = cur_time + default_instance.get_token_expire_delta()
new_payload = { new_payload = {
'iat': timegm(cur_time.timetuple()), 'iat': timegm(cur_time.timetuple()),
@ -364,8 +376,8 @@ async def auth_v1_token(request: Request):
'aud': 'https://cls.nvidia.org', 'aud': 'https://cls.nvidia.org',
'exp': timegm(access_expires_on.timetuple()), 'exp': timegm(access_expires_on.timetuple()),
'origin_ref': origin_ref, 'origin_ref': origin_ref,
'key_ref': SITE_KEY_XID, 'key_ref': default_site.site_key,
'kid': SITE_KEY_XID, 'kid': default_site.site_key,
} }
auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256) auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
@ -382,10 +394,13 @@ async def auth_v1_token(request: Request):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py # venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.post('/leasing/v1/lessor', description='request multiple leases (borrow) for current origin') @app.post('/leasing/v1/lessor', description='request multiple leases (borrow) for current origin')
async def leasing_v1_lessor(request: Request): async def leasing_v1_lessor(request: Request):
j, token, cur_time = json_loads((await request.body()).decode('utf-8')), __get_token(request), datetime.utcnow() j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()
default_instance = Instance.get_default_instance(db)
jwt_decode_key = default_instance.get_jwt_decode_key()
try: try:
token = __get_token(request) token = __get_token(request, jwt_decode_key)
except JWTError: except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'}) return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
@ -399,7 +414,7 @@ async def leasing_v1_lessor(request: Request):
# return JSONr(status_code=500, detail=f'no service instances found for scopes: ["{scope_ref}"]') # return JSONr(status_code=500, detail=f'no service instances found for scopes: ["{scope_ref}"]')
lease_ref = str(uuid4()) lease_ref = str(uuid4())
expires = cur_time + LEASE_EXPIRE_DELTA expires = cur_time + default_instance.get_lease_expire_delta()
lease_result_list.append({ lease_result_list.append({
"ordinal": 0, "ordinal": 0,
# https://docs.nvidia.com/license-system/latest/nvidia-license-system-user-guide/index.html # https://docs.nvidia.com/license-system/latest/nvidia-license-system-user-guide/index.html
@ -407,13 +422,13 @@ async def leasing_v1_lessor(request: Request):
"ref": lease_ref, "ref": lease_ref,
"created": cur_time.isoformat(), "created": cur_time.isoformat(),
"expires": expires.isoformat(), "expires": expires.isoformat(),
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD, "recommended_lease_renewal": default_instance.lease_renewal_period,
"offline_lease": "true", "offline_lease": "true",
"license_type": "CONCURRENT_COUNTED_SINGLE" "license_type": "CONCURRENT_COUNTED_SINGLE"
} }
}) })
data = Lease(origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires) data = Lease(instance_ref=default_instance.instance_ref, origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
Lease.create_or_update(db, data) Lease.create_or_update(db, data)
response = { response = {
@ -430,7 +445,14 @@ async def leasing_v1_lessor(request: Request):
# venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql # venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
@app.get('/leasing/v1/lessor/leases', description='get active leases for current origin') @app.get('/leasing/v1/lessor/leases', description='get active leases for current origin')
async def leasing_v1_lessor_lease(request: Request): async def leasing_v1_lessor_lease(request: Request):
token, cur_time = __get_token(request), datetime.utcnow() cur_time = datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref') origin_ref = token.get('origin_ref')
@ -450,7 +472,15 @@ async def leasing_v1_lessor_lease(request: Request):
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py # venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
@app.put('/leasing/v1/lease/{lease_ref}', description='renew a lease') @app.put('/leasing/v1/lease/{lease_ref}', description='renew a lease')
async def leasing_v1_lease_renew(request: Request, lease_ref: str): async def leasing_v1_lease_renew(request: Request, lease_ref: str):
token, cur_time = __get_token(request), datetime.utcnow() cur_time = datetime.utcnow()
default_instance = Instance.get_default_instance(db)
jwt_decode_key = default_instance.get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref') origin_ref = token.get('origin_ref')
logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}') logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}')
@ -459,11 +489,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
if entity is None: if entity is None:
return JSONr(status_code=404, content={'status': 404, 'detail': 'requested lease not available'}) return JSONr(status_code=404, content={'status': 404, 'detail': 'requested lease not available'})
expires = cur_time + LEASE_EXPIRE_DELTA expires = cur_time + default_instance.get_lease_expire_delta()
response = { response = {
"lease_ref": lease_ref, "lease_ref": lease_ref,
"expires": expires.isoformat(), "expires": expires.isoformat(),
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD, "recommended_lease_renewal": default_instance.lease_renewal_period,
"offline_lease": True, "offline_lease": True,
"prompts": None, "prompts": None,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
@ -477,7 +507,14 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_single_controller.py # venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_single_controller.py
@app.delete('/leasing/v1/lease/{lease_ref}', description='release (return) a lease') @app.delete('/leasing/v1/lease/{lease_ref}', description='release (return) a lease')
async def leasing_v1_lease_delete(request: Request, lease_ref: str): async def leasing_v1_lease_delete(request: Request, lease_ref: str):
token, cur_time = __get_token(request), datetime.utcnow() cur_time = datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref') origin_ref = token.get('origin_ref')
logging.info(f'> [ return ]: {origin_ref}: return {lease_ref}') logging.info(f'> [ return ]: {origin_ref}: return {lease_ref}')
@ -503,7 +540,14 @@ async def leasing_v1_lease_delete(request: Request, lease_ref: str):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py # venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.delete('/leasing/v1/lessor/leases', description='release all leases') @app.delete('/leasing/v1/lessor/leases', description='release all leases')
async def leasing_v1_lessor_lease_remove(request: Request): async def leasing_v1_lessor_lease_remove(request: Request):
token, cur_time = __get_token(request), datetime.utcnow() cur_time = datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref') origin_ref = token.get('origin_ref')
@ -525,6 +569,8 @@ async def leasing_v1_lessor_lease_remove(request: Request):
async def leasing_v1_lessor_shutdown(request: Request): async def leasing_v1_lessor_shutdown(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow() j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
token = j.get('token') token = j.get('token')
token = jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False}) token = jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
origin_ref = token.get('origin_ref') origin_ref = token.get('origin_ref')
@ -545,15 +591,23 @@ async def leasing_v1_lessor_shutdown(request: Request):
@app.on_event('startup') @app.on_event('startup')
async def app_on_startup(): async def app_on_startup():
default_instance = Instance.get_default_instance(db)
lease_renewal_period = default_instance.lease_renewal_period
lease_renewal_delta = default_instance.get_lease_renewal_delta()
client_token_expire_delta = default_instance.get_client_token_expire_delta()
logger.info(f''' logger.info(f'''
Using timezone: {str(TZ)}. Make sure this is correct and match your clients! Using timezone: {str(TZ)}. Make sure this is correct and match your clients!
Your clients renew their license every {str(Lease.calculate_renewal(LEASE_RENEWAL_PERIOD, LEASE_RENEWAL_DELTA))}. Your clients will renew their license every {str(Lease.calculate_renewal(lease_renewal_period, lease_renewal_delta))}.
If the renewal fails, the license is {str(LEASE_RENEWAL_DELTA)} valid. If the renewal fails, the license is valid for {str(lease_renewal_delta)}.
Your client-token file (.tok) is valid for {str(CLIENT_TOKEN_EXPIRE_DELTA)}. Your client-token file (.tok) is valid for {str(client_token_expire_delta)}.
''') ''')
validate_settings()
if __name__ == '__main__': if __name__ == '__main__':
import uvicorn import uvicorn

View File

@ -16,6 +16,18 @@ def load_key(filename) -> "RsaKey":
return RSA.import_key(extern_key=load_file(filename), passphrase=None) return RSA.import_key(extern_key=load_file(filename), passphrase=None)
def parse_key(content: bytes) -> "RsaKey":
try:
# Crypto | Cryptodome on Debian
from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey
except ModuleNotFoundError:
from Cryptodome.PublicKey import RSA
from Cryptodome.PublicKey.RSA import RsaKey
return RSA.import_key(extern_key=content, passphrase=None)
def generate_key() -> "RsaKey": def generate_key() -> "RsaKey":
try: try:
# Crypto | Cryptodome on Debian # Crypto | Cryptodome on Debian

View File

@ -1,14 +1,15 @@
from os import getenv as env
from base64 import b64encode as b64enc from base64 import b64encode as b64enc
from hashlib import sha256 from hashlib import sha256
from calendar import timegm from calendar import timegm
from datetime import datetime from datetime import datetime
from os.path import dirname, join from uuid import UUID, uuid4
from uuid import uuid4, UUID
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from jose import jwt, jwk from jose import jwt
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from starlette.testclient import TestClient from starlette.testclient import TestClient
from sqlalchemy import create_engine
import sys import sys
# add relative path to use packages as they were in the app/ dir # add relative path to use packages as they were in the app/ dir
@ -16,20 +17,23 @@ sys.path.append('../')
sys.path.append('../app') sys.path.append('../app')
from app import main from app import main
from app.util import load_key from app.orm import init as db_init, migrate, Site, Instance
client = TestClient(main.app)
ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-000000000001', 'HelloWorld' ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-000000000001', 'HelloWorld'
# INSTANCE_KEY_RSA = generate_key() # fastapi setup
# INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key() client = TestClient(main.app)
INSTANCE_KEY_RSA = load_key(str(join(dirname(__file__), '../app/cert/instance.private.pem'))) # database setup
INSTANCE_KEY_PUB = load_key(str(join(dirname(__file__), '../app/cert/instance.public.pem'))) db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite')))
db_init(db), migrate(db)
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256) # test vars
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256) DEFAULT_SITE, DEFAULT_INSTANCE = Site.get_default_site(db), Instance.get_default_instance(db)
SITE_KEY = DEFAULT_SITE.site_key
jwt_encode_key, jwt_decode_key = DEFAULT_INSTANCE.get_jwt_encode_key(), DEFAULT_INSTANCE.get_jwt_decode_key()
def __bearer_token(origin_ref: str) -> str: def __bearer_token(origin_ref: str) -> str:
@ -38,6 +42,12 @@ def __bearer_token(origin_ref: str) -> str:
return token return token
def test_initial_default_site_and_instance():
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
assert default_site.site_key == Site.INITIAL_SITE_KEY_XID
assert default_instance.instance_ref == Instance.DEFAULT_INSTANCE_REF
def test_index(): def test_index():
response = client.get('/') response = client.get('/')
assert response.status_code == 200 assert response.status_code == 200