Merge branch 'dev' into 'main'

v0.2

See merge request oscar.krause/fastapi-dls!3
This commit is contained in:
Oscar Krause 2022-12-20 15:08:56 +01:00
commit 68aeeb785d
3 changed files with 101 additions and 71 deletions

View File

@ -16,6 +16,14 @@ openssl req -x509 -nodes -days 3650 -newkey rsa:2048 -keyout $WORKING_DIR/webse
docker run -e DLS_URL=`hostname -i` -e DLS_PORT=443 -p 443:443 -v $WORKING_DIR:/app/cert collinwebdesigns/fastapi-dls:latest docker run -e DLS_URL=`hostname -i` -e DLS_PORT=443 -p 443:443 -v $WORKING_DIR:/app/cert collinwebdesigns/fastapi-dls:latest
``` ```
# Configuration
| Variable | Default | Usage |
|---------------------|-------------|---------------------------------------------------------------------------|
| `DLS_URL` | `localhost` | Used in client-token to tell guest driver where dls instance is reachable |
| `DLS_PORT` | `443` | Used in client-token to tell guest driver where dls instance is reachable |
| `LEASE_EXPIRE_DAYS` | `90` | Lease time in days |
# Installation # Installation
**The token file has to be copied! It's not enough to C&P file contents, because there can be special characters.** **The token file has to be copied! It's not enough to C&P file contents, because there can be special characters.**

View File

@ -1,20 +0,0 @@
from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey
def load_file(filename) -> bytes:
with open(filename, 'rb') as file:
content = file.read()
return content
def load_key(filename) -> RsaKey:
return RSA.import_key(extern_key=load_file(filename), passphrase=None)
def private_bytes(rsa: RsaKey) -> bytes:
return rsa.export_key(format='PEM', passphrase=None, protection=None)
def public_key(rsa: RsaKey) -> bytes:
return rsa.public_key().export_key(format='PEM')

View File

@ -1,4 +1,4 @@
from base64 import b64encode from base64 import b64encode as b64enc
from hashlib import sha256 from hashlib import sha256
from uuid import uuid4 from uuid import uuid4
from os.path import join, dirname from os.path import join, dirname
@ -12,14 +12,26 @@ from calendar import timegm
from jose import jws, jwk, jwt from jose import jws, jwk, jwt
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from starlette.responses import StreamingResponse, JSONResponse from starlette.responses import StreamingResponse, JSONResponse
import dataset
from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey
def load_file(filename) -> bytes:
with open(filename, 'rb') as file:
content = file.read()
return content
def load_key(filename) -> RsaKey:
return RSA.import_key(extern_key=load_file(filename), passphrase=None)
from helper import load_key, private_bytes, public_key
# todo: initialize certificate (or should be done by user, and passed through "volumes"?) # todo: initialize certificate (or should be done by user, and passed through "volumes"?)
app = FastAPI() app, db = FastAPI(), dataset.connect('sqlite:///db.sqlite')
LEASE_EXPIRE_DELTA = relativedelta(minutes=15) # days=90 LEASE_EXPIRE_DELTA = relativedelta(days=int(getenv('LEASE_EXPIRE_DAYS', 90)))
DLS_URL = str(getenv('DLS_URL', 'localhost')) DLS_URL = str(getenv('DLS_URL', 'localhost'))
DLS_PORT = int(getenv('DLS_PORT', '443')) DLS_PORT = int(getenv('DLS_PORT', '443'))
@ -27,6 +39,9 @@ SITE_KEY_XID = getenv('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000')
INSTANCE_KEY_RSA = load_key(join(dirname(__file__), 'cert/instance.private.pem')) INSTANCE_KEY_RSA = load_key(join(dirname(__file__), 'cert/instance.private.pem'))
INSTANCE_KEY_PUB = load_key(join(dirname(__file__), 'cert/instance.public.pem')) INSTANCE_KEY_PUB = load_key(join(dirname(__file__), 'cert/instance.public.pem'))
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
@app.get('/') @app.get('/')
async def index(): async def index():
@ -80,8 +95,7 @@ async def client_token():
"service_instance_public_key_configuration": service_instance_public_key_configuration, "service_instance_public_key_configuration": service_instance_public_key_configuration,
} }
key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256) data = jws.sign(payload, key=jwt_encode_key, headers=None, algorithm='RS256')
data = jws.sign(payload, key=key, headers=None, algorithm='RS256')
response = StreamingResponse(iter([data]), media_type="text/plain") response = StreamingResponse(iter([data]), media_type="text/plain")
filename = f'client_configuration_token_{datetime.now().strftime("%d-%m-%y-%H-%M-%S")}' filename = f'client_configuration_token_{datetime.now().strftime("%d-%m-%y-%H-%M-%S")}'
@ -90,17 +104,25 @@ async def client_token():
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py # venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py
# {"candidate_origin_ref":"00112233-4455-6677-8899-aabbccddeeff","environment":{"fingerprint":{"mac_address_list":["ff:ff:ff:ff:ff:ff"]},"hostname":"my-hostname","ip_address_list":["192.168.178.123","fe80::","fe80::1%enp6s18"],"guest_driver_version":"510.85.02","os_platform":"Debian GNU/Linux 11 (bullseye) 11","os_version":"11 (bullseye)"},"registration_pending":false,"update_pending":false}
@app.post('/auth/v1/origin') @app.post('/auth/v1/origin')
async def auth_origin(request: Request): async def auth_origin(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8')
j = json.loads(body) candidate_origin_ref = j['candidate_origin_ref']
# {"candidate_origin_ref":"00112233-4455-6677-8899-aabbccddeeff","environment":{"fingerprint":{"mac_address_list":["ff:ff:ff:ff:ff:ff"]},"hostname":"my-hostname","ip_address_list":["192.168.178.123","fe80::","fe80::1%enp6s18"],"guest_driver_version":"510.85.02","os_platform":"Debian GNU/Linux 11 (bullseye) 11","os_version":"11 (bullseye)"},"registration_pending":false,"update_pending":false} print(f'> [ origin ]: {candidate_origin_ref}: {j}')
print(f'> [ origin ]: {j}')
data = dict(
candidate_origin_ref=candidate_origin_ref,
hostname=j['environment']['hostname'],
guest_driver_version=j['environment']['guest_driver_version'],
os_platform=j['environment']['os_platform'], os_version=j['environment']['os_version']
)
db['origin'].insert_ignore(data, ['candidate_origin_ref'])
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {
"origin_ref": j['candidate_origin_ref'], "origin_ref": candidate_origin_ref,
"environment": j['environment'], "environment": j['environment'],
"svc_port_set_list": None, "svc_port_set_list": None,
"node_url_list": None, "node_url_list": None,
@ -113,13 +135,13 @@ async def auth_origin(request: Request):
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py # venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - CodeResponse # venv/lib/python3.9/site-packages/nls_core_auth/auth.py - CodeResponse
# {"code_challenge":"...","origin_ref":"00112233-4455-6677-8899-aabbccddeeff"}
@app.post('/auth/v1/code') @app.post('/auth/v1/code')
async def auth_code(request: Request): async def auth_code(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8')
j = json.loads(body) origin_ref = j['origin_ref']
# {"code_challenge":"...","origin_ref":"00112233-4455-6677-8899-aabbccddeeff"} print(f'> [ code ]: {origin_ref}: {j}')
print(f'> [ code ]: {j}')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
expires = cur_time + relativedelta(days=1) expires = cur_time + relativedelta(days=1)
@ -133,12 +155,7 @@ async def auth_code(request: Request):
'kid': SITE_KEY_XID 'kid': SITE_KEY_XID
} }
headers = None auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm='RS256')
kid = payload.get('kid')
if kid:
headers = {'kid': kid}
key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
auth_code = jws.sign(payload, key, headers=headers, algorithm='RS256')
response = { response = {
"auth_code": auth_code, "auth_code": auth_code,
@ -150,19 +167,17 @@ async def auth_code(request: Request):
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py # venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - TokenResponse # venv/lib/python3.9/site-packages/nls_core_auth/auth.py - TokenResponse
# {"auth_code":"...","code_verifier":"..."}
@app.post('/auth/v1/token') @app.post('/auth/v1/token')
async def auth_token(request: Request): async def auth_token(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8') payload = jwt.decode(token=j['auth_code'], key=jwt_decode_key)
j = json.loads(body)
# {"auth_code":"...","code_verifier":"..."}
# payload = self._security.get_valid_payload(req.auth_code) # todo origin_ref = payload['origin_ref']
key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512) print(f'> [ auth ]: {origin_ref}: {j}')
payload = jwt.decode(token=j['auth_code'], key=key)
# validate the code challenge # validate the code challenge
if payload['challenge'] != b64encode(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'): if payload['challenge'] != b64enc(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'):
raise HTTPException(status_code=403, detail='expected challenge did not match verifier') raise HTTPException(status_code=403, detail='expected challenge did not match verifier')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
@ -179,12 +194,7 @@ async def auth_token(request: Request):
'kid': SITE_KEY_XID, 'kid': SITE_KEY_XID,
} }
headers = None auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm='RS256')
kid = payload.get('kid')
if kid:
headers = {'kid': kid}
key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
auth_token = jwt.encode(new_payload, key=key, headers=headers, algorithm='RS256')
response = { response = {
"expires": access_expires_on.isoformat(), "expires": access_expires_on.isoformat(),
@ -195,18 +205,20 @@ async def auth_token(request: Request):
return JSONResponse(response) return JSONResponse(response)
# {'fulfillment_context': {'fulfillment_class_ref_list': []}, 'lease_proposal_list': [{'license_type_qualifiers': {'count': 1}, 'product': {'name': 'NVIDIA RTX Virtual Workstation'}}], 'proposal_evaluation_mode': 'ALL_OF', 'scope_ref_list': ['00112233-4455-6677-8899-aabbccddeeff']}
@app.post('/leasing/v1/lessor') @app.post('/leasing/v1/lessor')
async def leasing_lessor(request: Request): async def leasing_lessor(request: Request):
body = await request.body() j = json.loads((await request.body()).decode('utf-8'))
body = body.decode('utf-8') token = jwt.decode(request.headers['authorization'].split(' ')[1], key=jwt_decode_key, algorithms='RS256', options={'verify_aud': False})
j = json.loads(body)
# {'fulfillment_context': {'fulfillment_class_ref_list': []}, 'lease_proposal_list': [{'license_type_qualifiers': {'count': 1}, 'product': {'name': 'NVIDIA RTX Virtual Workstation'}}], 'proposal_evaluation_mode': 'ALL_OF', 'scope_ref_list': ['00112233-4455-6677-8899-aabbccddeeff']} code_challenge = token['origin_ref']
print(f'> [ lessor ]: {j}') scope_ref_list = j['scope_ref_list']
print(f'> [ lessor ]: {code_challenge}: {j}')
print(f'> {code_challenge}: create leases for scope_ref_list {scope_ref_list}')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
# todo: keep track of leases, to return correct list on '/leasing/v1/lessor/leases'
lease_result_list = [] lease_result_list = []
for scope_ref in j['scope_ref_list']: for scope_ref in scope_ref_list:
lease_result_list.append({ lease_result_list.append({
"ordinal": 0, "ordinal": 0,
"lease": { "lease": {
@ -218,6 +230,8 @@ async def leasing_lessor(request: Request):
"license_type": "CONCURRENT_COUNTED_SINGLE" "license_type": "CONCURRENT_COUNTED_SINGLE"
} }
}) })
data = dict(origin_ref=code_challenge, lease_ref=scope_ref, expires=None, last_update=None)
db['leases'].insert_ignore(data, ['origin_ref', 'lease_ref'])
response = { response = {
"lease_result_list": lease_result_list, "lease_result_list": lease_result_list,
@ -232,13 +246,23 @@ async def leasing_lessor(request: Request):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py # venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.get('/leasing/v1/lessor/leases') @app.get('/leasing/v1/lessor/leases')
async def leasing_lessor_lease(request: Request): async def leasing_lessor_lease(request: Request):
token = jwt.decode(request.headers['authorization'].split(' ')[1], key=key, algorithms='RS256', options={'verify_aud': False})
code_challenge = token['origin_ref']
active_lease_list = list(map(lambda x: x['lease_ref'], db['leases'].find(origin_ref=code_challenge)))
print(f'> {code_challenge}: found {len(active_lease_list)} active leases')
if len(active_lease_list) == 0:
raise HTTPException(status_code=400, detail="No leases available")
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
# venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql # venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
response = { response = {
# GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE # "active_lease_list": [
"active_lease_list": [ # # "BE276D7B-2CDB-11EC-9838-061A22468B59" # (works on Linux) GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE // 'NVIDIA Virtual PC','NVIDIA Virtual PC'
"BE276D7B-2CDB-11EC-9838-061A22468B59" # "BE276EFE-2CDB-11EC-9838-061A22468B59" # GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE // 'NVIDIA RTX Virtual Workstation','NVIDIA RTX Virtual Workstation
], # ],
"active_lease_list": active_lease_list,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
"prompts": None "prompts": None
} }
@ -249,26 +273,44 @@ async def leasing_lessor_lease(request: Request):
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py # venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
@app.put('/leasing/v1/lease/{lease_ref}') @app.put('/leasing/v1/lease/{lease_ref}')
async def leasing_lease_renew(request: Request, lease_ref: str): async def leasing_lease_renew(request: Request, lease_ref: str):
print(f'> [ renew ]: lease: {lease_ref}') token = jwt.decode(request.headers['authorization'].split(' ')[1], key=jwt_decode_key, algorithms='RS256', options={'verify_aud': False})
code_challenge = token['origin_ref']
print(f'> {code_challenge}: renew {lease_ref}')
if db['leases'].count(lease_ref=lease_ref) == 0:
raise HTTPException(status_code=400, detail="No leases available")
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
expires = cur_time + LEASE_EXPIRE_DELTA
response = { response = {
"lease_ref": lease_ref, "lease_ref": lease_ref,
"expires": (cur_time + LEASE_EXPIRE_DELTA).isoformat(), "expires": expires.isoformat(),
"recommended_lease_renewal": 0.16, "recommended_lease_renewal": 0.16,
# 0.16 = 10 min, 0.25 = 15 min, 0.33 = 20 min, 0.5 = 30 min (should be lower than "LEASE_EXPIRE_DELTA")
"offline_lease": True, "offline_lease": True,
"prompts": None, "prompts": None,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
} }
data = dict(lease_ref=lease_ref, origin_ref=code_challenge, expires=expires, last_update=cur_time)
db['leases'].update(data, ['lease_ref'])
return JSONResponse(response) return JSONResponse(response)
@app.delete('/leasing/v1/lessor/leases') @app.delete('/leasing/v1/lessor/leases')
async def leasing_lessor_lease_remove(request: Request): async def leasing_lessor_lease_remove(request: Request):
token = jwt.decode(request.headers['authorization'].split(' ')[1], key=jwt_decode_key, algorithms='RS256', options={'verify_aud': False})
code_challenge = token['origin_ref']
released_lease_list = list(map(lambda x: x['lease_ref'], db['leases'].find(origin_ref=code_challenge)))
deletions = db['leases'].delete(origin_ref=code_challenge)
print(f'> {code_challenge}: removed {deletions} leases')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {
"released_lease_list": None, "released_lease_list": released_lease_list,
"release_failure_list": None, "release_failure_list": None,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
"prompts": None "prompts": None