fastapi-dls/app/main.py

272 lines
10 KiB
Python

import json
from base64 import b64encode
from calendar import timegm
from datetime import datetime, timedelta
from hashlib import sha256
from os import getenv
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi.requests import Request
from jose import jws, jwk, jwt
from jose.constants import ALGORITHMS
from jose.exceptions import JWTError
from starlette.responses import StreamingResponse, JSONResponse

from app.helper import load_key, private_bytes, public_key
# todo: initialize certificate (or should be done by user, and passed through "volumes"?)
app = FastAPI()

# Address (IP or host name) advertised to licensed clients in the client token's
# "node_url_list"; it must be reachable from those clients. Configurable through
# the DLS_URL environment variable (default: 192.168.178.196).
URL = getenv('DLS_URL', '192.168.178.196')

# Private key of this service instance, loaded via app.helper.load_key. The
# client token, auth codes and auth tokens issued below are all signed with it.
# The file path can be overridden through the DLS_SITE_KEY environment variable.
SITE_KEY_FILE = load_key(getenv('DLS_SITE_KEY', '/opt/fastapi-dls/site.key'))
# Key id written to the "kid" header and "key_ref" claim of auth codes/tokens.
SITE_KEY_XID = '00000000-0000-0000-0000-000000000000'
# Private / public key material passed to jose's jwk.construct
# (presumably PEM-encoded bytes -- see app.helper).
SITE_KEY_RSA = private_bytes(SITE_KEY_FILE)
SITE_KEY_PUB = public_key(SITE_KEY_FILE)
@app.get('/')
async def index():
    """Landing route; answers every request with a fixed greeting payload."""
    greeting = dict(hello='world')
    return greeting
@app.get('/status')
async def status(request: Request):
    """Liveness probe: unconditionally reports the service as up."""
    body = {'status': 'up'}
    return JSONResponse(content=body)
# venv/lib/python3.9/site-packages/nls_core_service_instance/service_instance_token_manager.py
@app.get('/client-token')
async def client_token():
    """Issue a signed client configuration token as a downloadable file.

    The token is a JWT signed with the site key (RS256). It tells the licensed
    client where to reach this service (``URL``) and carries the service
    instance's public key, both as modulus/exponent and as PEM.

    Returns:
        A streamed ``text/plain`` response with a
        ``Content-Disposition: attachment; filename=client_token.tok`` header.
    """
    public_numbers = SITE_KEY_FILE.public_key().public_numbers()
    # Must stay a dict: a trailing comma after the closing brace would silently
    # wrap it in a 1-tuple, which is serialized as a JSON array.
    service_instance_public_key_me = {
        "mod": hex(public_numbers.n)[2:],  # modulus as hex digits, without "0x"
        "exp": public_numbers.e,
    }

    cur_time = datetime.utcnow()
    exp_time = cur_time + timedelta(days=1)
    # python-jose's jwt.encode converts datetime iat/nbf/exp claims to epoch seconds.
    payload = {
        "jti": str(uuid4()),
        "iss": "NLS Service Instance",
        "aud": "NLS Licensed Client",
        "iat": cur_time,
        "nbf": cur_time,
        "exp": exp_time,
        "update_mode": "ABSOLUTE",
        "scope_ref_list": [
            "482f24b5-0a60-4ec2-a63a-9ed00bc2534e"
            # todo: "scope_ref_list" should be a unique client id (which identifies leases, etc.)
        ],
        "fulfillment_class_ref_list": [],
        "service_instance_configuration": {
            "nls_service_instance_ref": "b43d6e46-d6d0-4943-8b8d-c66a5f6e0d38",
            "svc_port_set_list": [
                {
                    "idx": 0,
                    "d_name": "DLS",
                    "svc_port_map": [
                        {"service": "auth", "port": 443},
                        {"service": "lease", "port": 443}
                    ]
                }
            ],
            "node_url_list": [{"idx": 0, "url": URL, "url_qr": URL, "svc_port_set_idx": 0}]
        },
        "service_instance_public_key_configuration": {
            "service_instance_public_key_me": service_instance_public_key_me,
            "service_instance_public_key_pem": SITE_KEY_PUB.decode('utf-8'),
            "key_retention_mode": "LATEST_ONLY"
        }
    }

    # The key must be constructed for the same algorithm the header declares:
    # python-jose signs with the key object's own hash, so an RS512 key would
    # produce an RSA-SHA512 signature under an "RS256" header.
    signing_key = jwk.construct(SITE_KEY_RSA, algorithm=ALGORITHMS.RS256)
    data = jwt.encode(payload, key=signing_key, headers=None, algorithm=ALGORITHMS.RS256)

    response = StreamingResponse(iter([data]), media_type="text/plain")
    response.headers["Content-Disposition"] = "attachment; filename=client_token.tok"
    return response
# venv/lib/python3.9/site-packages/nls_services_auth/test/test_origins_controller.py
@app.post('/auth/v1/origin')
async def auth(request: Request, status_code=201):
    """Register (find or create) a client origin and echo it back.

    NOTE(review): ``status_code`` does not set the response status. FastAPI
    exposes an unannotated handler parameter as an optional query parameter,
    so the response is sent with the default status (200). If 201 was
    intended, it belongs in the route decorator. The parameter is kept so the
    handler signature stays unchanged.

    Raises:
        HTTPException(400): if the body is not a JSON object containing
            ``candidate_origin_ref``.
    """
    body = await request.body()
    j = json.loads(body.decode('utf-8'))
    # {"candidate_origin_ref":"00112233-4455-6677-8899-aabbccddeeff","environment":{"fingerprint":{"mac_address_list":["fa:52:16:65:c5:28"]},"hostname":"debian-grid-test","ip_address_list":["192.168.178.12","fdfe:7fcd:e30f:40f5:f852:16ff:fe65:c528","fe80::f852:16ff:fe65:c528%enp6s18"],"guest_driver_version":"510.85.02","os_platform":"Debian GNU/Linux 11 (bullseye) 11","os_version":"11 (bullseye)"},"registration_pending":false,"update_pending":false}
    if not isinstance(j, dict) or 'candidate_origin_ref' not in j:
        raise HTTPException(status_code=400, detail='missing candidate_origin_ref')

    cur_time = datetime.utcnow()
    return {
        "origin_ref": j['candidate_origin_ref'],
        # Echo the environment reported by the registering client rather than
        # fixed placeholder values describing an unrelated host.
        "environment": j.get('environment'),
        "svc_port_set_list": None,
        "node_url_list": None,
        "node_query_order": None,
        "prompts": None,
        "sync_timestamp": cur_time
    }
def _json_fields(raw: bytes, *names: str) -> list:
    """Decode a JSON-object request body and return the values of *names*, in order.

    Raises:
        HTTPException(400): if the body is not UTF-8 JSON, is not a JSON object,
            or lacks one of the requested fields.
    """
    try:
        body = json.loads(raw.decode('utf-8'))
    except ValueError as e:  # covers UnicodeDecodeError and json.JSONDecodeError
        raise HTTPException(status_code=400, detail='request body must be UTF-8 encoded JSON') from e
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail='request body must be a JSON object')
    missing = [name for name in names if name not in body]
    if missing:
        raise HTTPException(status_code=400, detail='missing field(s): ' + ', '.join(missing))
    return [body[name] for name in names]


# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - CodeResponse
@app.post('/auth/v1/code')
async def code(request: Request):
    """Issue a short-lived signed auth code bound to a PKCE code challenge.

    The auth code is a JWS (RS256, site key, ``kid`` header) carrying the
    client's ``code_challenge`` and ``origin_ref``; it is valid for one day and
    is exchanged for an auth token at ``/auth/v1/token``.

    Raises:
        HTTPException(400): if the body is malformed or lacks ``code_challenge``
            or ``origin_ref``.
    """
    # {"code_challenge":"QhDaArKDQwFeQ5Jq4Dn5hy37ODF8Jq3igXCXvWEgs5I","origin_ref":"00112233-4455-6677-8899-aabbccddeeff"}
    code_challenge, origin_ref = _json_fields(await request.body(), 'code_challenge', 'origin_ref')

    cur_time = datetime.utcnow()
    expires = cur_time + timedelta(days=1)
    payload = {
        'iat': timegm(cur_time.timetuple()),
        'exp': timegm(expires.timetuple()),
        'challenge': code_challenge,
        'origin_ref': origin_ref,
        'key_ref': SITE_KEY_XID,
        'kid': SITE_KEY_XID
    }
    headers = {'kid': SITE_KEY_XID}

    # Key and header algorithm must agree; python-jose signs with the key's own hash.
    signing_key = jwk.construct(SITE_KEY_RSA, algorithm=ALGORITHMS.RS256)
    auth_code = jws.sign(payload, signing_key, headers=headers, algorithm=ALGORITHMS.RS256)

    return {
        "auth_code": auth_code,
        "sync_timestamp": cur_time,
        "prompts": None
    }


# venv/lib/python3.9/site-packages/nls_services_auth/test/test_auth_controller.py
# venv/lib/python3.9/site-packages/nls_core_auth/auth.py - TokenResponse
@app.post('/auth/v1/token')
async def token(request: Request):
    """Exchange a signed auth code plus its code verifier for an auth token.

    Verifies the auth code's signature (RS256, site public key) and expiry,
    then checks that the unpadded base64 SHA-256 of ``code_verifier`` equals
    the challenge stored in the code. Returns a one-day auth token (JWT,
    RS256) for the code's ``origin_ref``.

    Raises:
        HTTPException(400): malformed body or missing/non-string fields.
        HTTPException(401): the auth code is invalid, forged or expired.
        HTTPException(403): the code verifier does not match the challenge.
    """
    # {"auth_code":"eyJhbGciOiJSUzI1NiIsImtpZCI6IjAwMDAwMDAwLTAwMDAtMDAwMC0wMDAwLTAwMDAwMDAwMDAwMCIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE2NzExODI5MTQsImV4cCI6MTY3MTI2OTMxNCwiY2hhbGxlbmdlIjoiaXdZdFpIME03K0ZZUWdRQXEwbjhabThWcFpJbWdtV1NDSXI1MkdTSlMxayIsIm9yaWdpbl9yZWYiOiJpd1l0WkgwTTcrRllRZ1FBcTBuOFptOFZwWkltZ21XU0NJcjUyR1NKUzFrIiwia2V5X3JlZiI6IjAwMDAwMDAwLTAwMDAtMDAwMC0wMDAwLTAwMDAwMDAwMDAwMCIsImtpZCI6IjAwMDAwMDAwLTAwMDAtMDAwMC0wMDAwLTAwMDAwMDAwMDAwMCJ9.hkBPQx7UbXqwRzpTSp5fASwLg7rJOgjDOGD98Zh6pEkPW09KjxcsaHKeR8KIZmDS1S_kLed93-UzUY4wXAylFBlM-daL-TEbHJau2muZGWXPrtdsGLI9CLFcc0dmocq1_5rnRV3liqjdZwL8djK9Fx_5tOzEfeI9oCJ49Sh2LD_p1vkFcqUv9z9mVL9IGsoRM6y4hJ2YKBloijzhMLp5E7nojyD6Z8PQZ0mOIOc3tncAaXQS47JhgGsJPUDR-YoLF5uNpAlJKZP2eZWJt3P7MvhIz3lxFPUJ5jHX64Vf0Ds10-GBctZuy1-eCLBXj74uQy_U4KlnCif-5N8bPTvgxw","code_verifier":"CgnDPaugQCb4U6l3EfJSFsA/JxMqNO4TqONeb9yl8EVRWU88yTPlEeJgZQO0f/JVnScYOsvwa0jcvTAMBulEKgucfxDDVL1cBOylGugQ0QlJsXU5hJ8VLAQtOyPthnVyEutERNyOKVwl3YI5Z5EfUcfuhDqmxBUpnAFtQ9H3R3g"}
    auth_code, code_verifier = _json_fields(await request.body(), 'auth_code', 'code_verifier')
    if not isinstance(auth_code, str) or not isinstance(code_verifier, str):
        raise HTTPException(status_code=400, detail='auth_code and code_verifier must be strings')

    # Equivalent of nls_core_auth's `self._security.get_valid_payload(req.auth_code)`:
    # verify signature and expiry, accepting only the algorithm we sign with.
    verify_key = jwk.construct(SITE_KEY_PUB, algorithm=ALGORITHMS.RS256)
    try:
        payload = jwt.decode(token=auth_code, key=verify_key, algorithms=[ALGORITHMS.RS256])
    except JWTError as e:
        raise HTTPException(status_code=401, detail='invalid auth code') from e

    # Uses the standard base64 alphabet (not RFC 7636 base64url), unpadded --
    # kept as-is; presumably this is what the NVIDIA client computes.
    expected_challenge = b64encode(sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
    if payload.get('challenge') != expected_challenge:
        raise HTTPException(status_code=403, detail='expected challenge did not match verifier')

    cur_time = datetime.utcnow()
    access_expires_on = cur_time + timedelta(days=1)
    new_payload = {
        'iat': timegm(cur_time.timetuple()),
        'nbf': timegm(cur_time.timetuple()),
        'iss': 'https://cls.nvidia.org',
        'aud': 'https://cls.nvidia.org',
        'exp': timegm(access_expires_on.timetuple()),
        'origin_ref': payload['origin_ref'],
        'key_ref': SITE_KEY_XID,
        'kid': SITE_KEY_XID,
    }
    headers = {'kid': SITE_KEY_XID}

    signing_key = jwk.construct(SITE_KEY_RSA, algorithm=ALGORITHMS.RS256)
    auth_token = jwt.encode(new_payload, key=signing_key, headers=headers, algorithm=ALGORITHMS.RS256)

    return {
        "expires": access_expires_on,
        "auth_token": auth_token,
        "sync_timestamp": cur_time,
    }
@app.post('/leasing/v1/lessor')
async def lessor(request: Request):
    """Grant one lease per entry of the request's ``scope_ref_list``.

    Each lease uses the scope ref as its ``ref``, starts now and expires in
    90 days. The overall result is always reported as ``SUCCESS``.
    """
    raw = await request.body()
    requested = json.loads(raw.decode('utf-8'))
    # {'fulfillment_context': {'fulfillment_class_ref_list': []}, 'lease_proposal_list': [{'license_type_qualifiers': {'count': 1}, 'product': {'name': 'NVIDIA RTX Virtual Workstation'}}], 'proposal_evaluation_mode': 'ALL_OF', 'scope_ref_list': ['482f24b5-0a60-4ec2-a63a-9ed00bc2534e']}
    now = datetime.utcnow()
    lease_expiry = now + timedelta(days=90)

    # todo: keep track of leases, to return correct list on '/leasing/v1/lessor/leases'
    granted = [
        {
            "ordinal": 0,
            "lease": {
                "ref": ref,
                "created": now,
                "expires": lease_expiry,
                "recommended_lease_renewal": 0.15,
                "offline_lease": "true",
                "license_type": "CONCURRENT_COUNTED_SINGLE",
            },
        }
        for ref in requested['scope_ref_list']
    ]

    return dict(
        lease_result_list=granted,
        result_code="SUCCESS",
        sync_timestamp=now,
        prompts=None,
    )
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.get('/leasing/v1/lessor/leases')
async def lease(request: Request):
    """Return the active lease list; currently a single hard-coded entry.

    NOTE(review): the listed id comes from the product mapping table (see the
    SQL reference below), not from a lease granted by /leasing/v1/lessor.
    """
    # venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
    # GRID-Virtual-WS 2.0 CONCURRENT_COUNTED_SINGLE
    active = ["BE276D7B-2CDB-11EC-9838-061A22468B59"]
    return {
        "active_lease_list": active,
        "sync_timestamp": datetime.utcnow(),
        "prompts": None,
    }
@app.delete('/leasing/v1/lessor/leases')
async def lease_remove(request: Request, status_code=200):
    """Release leases; currently a no-op reporting no released or failed leases.

    NOTE(review): ``status_code`` is exposed by FastAPI as an optional query
    parameter and does not influence the response status.
    """
    return dict(
        released_lease_list=None,
        release_failure_list=None,
        sync_timestamp=datetime.utcnow(),
        prompts=None,
    )
if __name__ == '__main__':
    import uvicorn

    # Development entry point; expects a TLS key/certificate pair named
    # key.pem / cert.pem in the current working directory.
    ssl_keyfile = 'key.pem'
    ssl_certfile = 'cert.pem'
    # Pass the application object itself rather than the 'main:app' import
    # string. The string form makes uvicorn import this file a second time
    # under the name 'main', re-running the module-level key loading. It also
    # only resolves when the directory containing main.py is on sys.path.
    uvicorn.run(app, host='0.0.0.0', port=443, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile)