fastapi-dls/app/main.py

295 lines
10 KiB
Python

from base64 import b64encode
from hashlib import sha256
from uuid import uuid4
from os.path import join, dirname
from os import getenv
from fastapi import FastAPI, HTTPException
from fastapi.requests import Request
import json
from datetime import datetime
from dateutil.relativedelta import relativedelta
from calendar import timegm
from jose import jws, jwk, jwt
from jose.constants import ALGORITHMS
from starlette.responses import StreamingResponse, JSONResponse
from helper import load_key, private_bytes, public_key
# todo: initialize certificate (or should be done by user, and passed through "volumes"?)
# ASGI application object; every route in this module is registered on it.
app = FastAPI()

# Validity period applied to new and renewed leases.
# NOTE(review): the trailing "days=90" presumably records the intended production value - confirm.
LEASE_EXPIRE_DELTA = relativedelta(minutes=15)  # days=90

# Host and port that are written into the client configuration token.
DLS_URL = getenv('DLS_URL', 'localhost')
DLS_PORT = int(getenv('DLS_PORT', '443'))

# Identifier used as "key_ref" / "kid" in the auth code and auth token.
SITE_KEY_XID = getenv('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000')

# Instance key pair, loaded once at import time from the "cert" folder next to this file.
_APP_DIR = dirname(__file__)
INSTANCE_KEY_RSA = load_key(join(_APP_DIR, 'cert/instance.private.pem'))
INSTANCE_KEY_PUB = load_key(join(_APP_DIR, 'cert/instance.public.pem'))
@app.get('/')
async def index():
    # Root route: answers with a fixed greeting document.
    greeting = {'hello': 'world'}
    return greeting
@app.get('/status')
async def status(request: Request):
    # Status route: always reports "up"; the request object is accepted but not inspected.
    return JSONResponse(content={'status': 'up'})
# venv/lib/python3.9/site-packages/nls_core_service_instance/service_instance_token_manager.py
@app.get('/client-token')
async def client_token():
    # Build the client configuration token, sign it with the instance private key (RS256)
    # and return it as a plain-text file download.
    issued = datetime.utcnow()
    issued_ts = timegm(issued.timetuple())
    expiry_ts = timegm((issued + relativedelta(years=12)).timetuple())

    # Public half of the instance key, published both as modulus/exponent and as PEM.
    rsa_public = INSTANCE_KEY_PUB.public_key()
    public_key_configuration = {
        "service_instance_public_key_me": {
            "mod": hex(rsa_public.n)[2:],  # hex digits without the "0x" prefix
            "exp": rsa_public.e,
        },
        "service_instance_public_key_pem": INSTANCE_KEY_PUB.export_key().decode('utf-8'),
        "key_retention_mode": "LATEST_ONLY"
    }

    # Both services are advertised on the same port.
    port_map = [{"service": service, "port": DLS_PORT} for service in ("auth", "lease")]
    instance_configuration = {
        "nls_service_instance_ref": "00000000-0000-0000-0000-000000000000",
        "svc_port_set_list": [
            {
                "idx": 0,
                "d_name": "DLS",
                "svc_port_map": port_map
            }
        ],
        "node_url_list": [{"idx": 0, "url": DLS_URL, "url_qr": DLS_URL, "svc_port_set_idx": 0}]
    }

    claims = {
        "jti": str(uuid4()),
        "iss": "NLS Service Instance",
        "aud": "NLS Licensed Client",
        "iat": issued_ts,
        "nbf": issued_ts,
        "exp": expiry_ts,
        "update_mode": "ABSOLUTE",
        "scope_ref_list": [str(uuid4())],
        "fulfillment_class_ref_list": [],
        "service_instance_configuration": instance_configuration,
        "service_instance_public_key_configuration": public_key_configuration,
    }

    signing_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
    signed_token = jws.sign(claims, key=signing_key, headers=None, algorithm='RS256')

    response = StreamingResponse(iter([signed_token]), media_type="text/plain")
    # The download name carries the server's local time (datetime.now()), not UTC.
    download_name = f'client_configuration_token_{datetime.now().strftime("%d-%m-%y-%H-%M-%S")}'
    response.headers["Content-Disposition"] = f'attachment; filename={download_name}'
    return response
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py
@app.post('/auth/v1/origin')
async def auth(request: Request, status_code=201):
    # Register an origin: echo the candidate origin reference and environment back to the caller.
    #
    # NOTE(review): "status_code" is an ordinary function parameter here and is never read;
    # presumably it was meant to set the response status - confirm before changing.
    raw_body = await request.body()
    origin_request = json.loads(raw_body.decode('utf-8'))
    # {"candidate_origin_ref":"00112233-4455-6677-8899-aabbccddeeff","environment":{"fingerprint":{"mac_address_list":["ff:ff:ff:ff:ff:ff"]},"hostname":"my-hostname","ip_address_list":["192.168.178.123","fe80::","fe80::1%enp6s18"],"guest_driver_version":"510.85.02","os_platform":"Debian GNU/Linux 11 (bullseye) 11","os_version":"11 (bullseye)"},"registration_pending":false,"update_pending":false}
    print(f'> [ origin ]: {origin_request}')

    now = datetime.utcnow()
    return {
        "origin_ref": origin_request['candidate_origin_ref'],
        "environment": origin_request['environment'],
        "svc_port_set_list": None,
        "node_url_list": None,
        "node_query_order": None,
        "prompts": None,
        "sync_timestamp": now
    }
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - CodeResponse
@app.post('/auth/v1/code')
async def code(request: Request):
    # Issue a signed auth code that binds the caller's code challenge to its origin.
    #
    # Request body (JSON): {"code_challenge": "...", "origin_ref": "..."}
    # Response: {"auth_code": <signed JWS>, "sync_timestamp": <UTC now>, "prompts": None}
    # A missing "code_challenge" or "origin_ref" key raises KeyError.
    body = await request.body()
    j = json.loads(body.decode('utf-8'))
    # {"code_challenge":"...","origin_ref":"00112233-4455-6677-8899-aabbccddeeff"}
    print(f'> [ code ]: {j}')

    cur_time = datetime.utcnow()
    expires = cur_time + relativedelta(days=1)

    payload = {
        'iat': timegm(cur_time.timetuple()),
        'exp': timegm(expires.timetuple()),
        'challenge': j['code_challenge'],
        # Fixed: this used to copy the code challenge, so /auth/v1/token handed the
        # challenge on as the origin reference.
        'origin_ref': j['origin_ref'],
        'key_ref': SITE_KEY_XID,
        'kid': SITE_KEY_XID
    }

    # Only advertise a key id in the header when one is configured.
    kid = payload.get('kid')
    headers = {'kid': kid} if kid else None

    # NOTE(review): the key object is built for RS512 while the header advertises RS256.
    # /auth/v1/token verifies with an RS512 key as well, so the two only work as a pair;
    # switch both to RS256 together, never just one of them.
    key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
    auth_code = jws.sign(payload, key, headers=headers, algorithm='RS256')

    return {
        "auth_code": auth_code,
        # Reuse cur_time so the timestamp matches the token's "iat", as the other handlers do.
        "sync_timestamp": cur_time,
        "prompts": None
    }
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - TokenResponse
@app.post('/auth/v1/token')
async def token(request: Request):
    # Exchange an auth code (issued by /auth/v1/code) plus its code verifier for an auth token.
    #
    # Request body (JSON): {"auth_code": "...", "code_verifier": "..."}
    # Response: {"expires": <UTC now + 1 day>, "auth_token": <signed JWT>, "sync_timestamp": <UTC now>}
    # Raises HTTPException 401 when the auth code cannot be verified (bad signature, expired, malformed)
    # and HTTPException 403 when the verifier does not hash to the challenge stored in the auth code.
    from jose.exceptions import JWTError  # local import: only this handler needs it

    body = await request.body()
    j = json.loads(body.decode('utf-8'))
    # {"auth_code":"...","code_verifier":"..."}

    # payload = self._security.get_valid_payload(req.auth_code) # todo
    # NOTE(review): the verification key stays RS512 on purpose - /auth/v1/code signs with a key
    # object built for RS512, so both sides must be changed together.
    decode_key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS512)
    try:
        payload = jwt.decode(token=j['auth_code'], key=decode_key)
    except JWTError as err:
        # Previously this escaped as an unhandled exception (HTTP 500).
        raise HTTPException(status_code=401, detail='auth code is invalid or expired') from err

    # validate the code challenge: unpadded base64 of sha256(code_verifier)
    verifier_digest = sha256(j['code_verifier'].encode('utf-8')).digest()
    expected_challenge = b64encode(verifier_digest).rstrip(b'=').decode('utf-8')
    if payload['challenge'] != expected_challenge:
        raise HTTPException(status_code=403, detail='expected challenge did not match verifier')

    cur_time = datetime.utcnow()
    access_expires_on = cur_time + relativedelta(days=1)

    new_payload = {
        'iat': timegm(cur_time.timetuple()),
        'nbf': timegm(cur_time.timetuple()),
        'iss': 'https://cls.nvidia.org',
        'aud': 'https://cls.nvidia.org',
        'exp': timegm(access_expires_on.timetuple()),
        'origin_ref': payload['origin_ref'],
        'key_ref': SITE_KEY_XID,
        'kid': SITE_KEY_XID,
    }

    # Carry the key id of the auth code over into the header of the new token, if there is one.
    kid = payload.get('kid')
    headers = {'kid': kid} if kid else None

    # Fixed: the signing key is now built for the same algorithm the token header advertises
    # (RS256); it used to be built for RS512.
    encode_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
    auth_token = jwt.encode(new_payload, key=encode_key, headers=headers, algorithm=ALGORITHMS.RS256)

    return {
        "expires": access_expires_on,
        "auth_token": auth_token,
        "sync_timestamp": cur_time,
    }
@app.post('/leasing/v1/lessor')
async def lessor(request: Request):
    # Grant one lease per entry of the request's "scope_ref_list"; every proposal succeeds.
    raw_body = await request.body()
    proposal = json.loads(raw_body.decode('utf-8'))
    # {'fulfillment_context': {'fulfillment_class_ref_list': []}, 'lease_proposal_list': [{'license_type_qualifiers': {'count': 1}, 'product': {'name': 'NVIDIA RTX Virtual Workstation'}}], 'proposal_evaluation_mode': 'ALL_OF', 'scope_ref_list': ['00112233-4455-6677-8899-aabbccddeeff']}
    print(f'> [ lessor ]: {proposal}')

    now = datetime.utcnow()
    lease_expiry = now + LEASE_EXPIRE_DELTA

    # todo: keep track of leases, to return correct list on '/leasing/v1/lessor/leases'
    # The scope reference doubles as the lease reference.
    granted = [
        {
            "ordinal": 0,
            "lease": {
                "ref": scope_ref,
                "created": now,
                "expires": lease_expiry,
                "recommended_lease_renewal": 0.15,
                "offline_lease": "true",
                "license_type": "CONCURRENT_COUNTED_SINGLE"
            }
        }
        for scope_ref in proposal['scope_ref_list']
    ]

    return {
        "lease_result_list": granted,
        "result_code": "SUCCESS",
        "sync_timestamp": now,
        "prompts": None
    }
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.get('/leasing/v1/lessor/leases')
async def lease(request: Request):
    # List active leases. The list is hard-coded; the request is not inspected.
    # venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
    active_leases = [
        "BE276D7B-2CDB-11EC-9838-061A22468B59",  # GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE 2
        "BE276EFE-2CDB-11EC-9838-061A22468B59",  # GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE 1
    ]
    now = datetime.utcnow()
    return {
        "active_lease_list": active_leases,
        "sync_timestamp": now,
        "prompts": None
    }
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
@app.put('/leasing/v1/lease/{lease_ref}')
async def lease_renew(request: Request, lease_ref: str):
    # Renew the lease named in the path: its expiry is pushed to now + LEASE_EXPIRE_DELTA.
    print(f'> [ renew ]: {lease_ref}')

    now = datetime.utcnow()
    renewed_until = now + LEASE_EXPIRE_DELTA
    return {
        "lease_ref": lease_ref,
        "expires": renewed_until,
        "recommended_lease_renewal": 0.16,
        "offline_lease": True,
        "prompts": None,
        "sync_timestamp": now
    }
@app.delete('/leasing/v1/lessor/leases')
async def lease_remove(request: Request, status_code=200):
    # Release leases: nothing is tracked, so both result lists are reported as None.
    # "status_code" is accepted but never read inside this handler.
    now = datetime.utcnow()
    return {
        "released_lease_list": None,
        "release_failure_list": None,
        "sync_timestamp": now,
        "prompts": None
    }
if __name__ == '__main__':
    import uvicorn

    ###
    #
    # Running `python app/main.py` assumes that the user created a keypair, e.g. with openssl.
    #
    # openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout app/cert/webserver.key -out app/cert/webserver.crt
    #
    ###

    print('> Starting dev-server ...')

    # TLS material is expected in the "cert" folder next to this file.
    cert_base = dirname(__file__)
    server_options = {
        'host': '0.0.0.0',
        'port': 443,
        'ssl_keyfile': join(cert_base, 'cert/webserver.key'),
        'ssl_certfile': join(cert_base, 'cert/webserver.crt'),
        'reload': True,
    }
    uvicorn.run('main:app', **server_options)