Compare commits

..

76 Commits
main ... db

Author SHA1 Message Date
cd4674caad fixes 2024-06-21 19:35:42 +02:00
b0b627a3f0 Merge branch 'refs/heads/dev' into db
# Conflicts:
#	.gitlab-ci.yml
#	Dockerfile
#	README.md
#	app/main.py
#	app/orm.py
#	requirements.txt
2024-06-21 19:28:23 +02:00
c79455b84d added way to include driver version in api
use `create_driver_matrix_json.py` to generate file in `static/driver_matrix.json`. Logging currently is disabled to not confuse users when file is missing. This is optional!
2024-06-21 19:01:33 +02:00
35fc5ea6b0 set log format 2024-06-21 18:59:23 +02:00
c9ac915055 code styling & comments 2024-06-13 20:34:45 +02:00
8c5850beda migrated from deprecated "startup" to "lifespan" hook (fastapi) 2024-06-13 20:34:27 +02:00
0d9e814d0d show if debug is enabled on app startup 2024-06-13 20:18:33 +02:00
5438317fb7 rearranged imports 2024-06-13 20:18:18 +02:00
21f19be8ab code styling & improved logging 2024-06-13 20:16:51 +02:00
eff6aae25d code styling 2024-06-13 19:53:57 +02:00
6473655e57 import fixes 2024-06-13 19:24:46 +02:00
c45aa1a2a8 fixed [[__TOC__]] to [TOC] to support "markdown" package 2024-06-12 21:27:56 +02:00
16f80cd78b added "16.3" support 2024-03-04 21:19:59 +01:00
07aec53787 requirements.txt updated 2024-03-04 21:19:59 +01:00
3e87820f63 removed todo 2024-03-04 21:19:59 +01:00
a927e291b5 fixes 2024-03-04 21:19:59 +01:00
72054d30c4 make tests interruptible 2024-03-04 21:19:59 +01:00
00dc848083 only run test matrix when "app" or "test" changes 2024-03-04 21:19:59 +01:00
78b6fe52c7 fixed CI/CD path from "/builds" to "/tmp/builds" 2024-03-04 21:19:59 +01:00
f82d73bb01 run different jobs on "$CI_DEFAULT_BRANCH" 2024-03-04 21:19:59 +01:00
416df311b8 removed pylint 2024-03-04 21:19:59 +01:00
a6ea8241c2 disabled pylint 2024-03-04 21:19:59 +01:00
e70f70d806 disabled code_quality debug 2024-03-04 21:19:59 +01:00
77be5772c4 Update .codeclimate.yml 2024-03-04 21:19:59 +01:00
6c1b05c66a fixed test_coverage (fail on matrix) 2024-03-04 21:19:59 +01:00
a54411a957 added code_quality debug 2024-03-04 21:19:59 +01:00
90e0cb8e84 added code_quality “SOURCE_CODE” variable 2024-03-04 21:19:59 +01:00
eecb59e2e4 removed "cython" from "test" 2024-03-04 21:19:59 +01:00
4c0f65faec removed tests for "23.04"
> gcc -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -fPIC -I/tmp/pip-install-sazb8fvo/httptools_694f06fa2e354ed9ba9f5c167df7fce4/vendor/llhttp/include -I/tmp/pip-install-sazb8fvo/httptools_694f06fa2e354ed9ba9f5c167df7fce4/vendor/llhttp/src -I/usr/local/include/python3.11 -c httptools/parser/parser.c -o build/temp.linux-x86_64-cpython-311/httptools/parser/parser.o -O2
      httptools/parser/parser.c:212:12: fatal error: longintrepr.h: No such file or directory
2024-03-04 21:19:59 +01:00
e41084f5c5 added tests for Ubuntu "Mantic Minotaur" 2024-03-04 21:19:59 +01:00
36d5b83fb8 requirements.txt updated 2024-03-04 21:19:59 +01:00
11138c2191 updated debian bookworm 12 dependencies 2024-03-04 21:19:59 +01:00
ff9e85e32b updated test to debian bookworm 2024-03-04 21:19:59 +01:00
cb6a089678 fixed testing dependency 2024-03-04 21:19:59 +01:00
085186f82a added gcc as dependency 2024-03-04 21:19:59 +01:00
f77d3feee1 fixes 2024-03-04 21:19:59 +01:00
f2721c7663 fixed debian package versions 2024-03-04 21:19:59 +01:00
40cb5518cb fixed versions & added 16.2 as supported 2024-03-04 21:19:59 +01:00
021c0ac38d added os specific requirements.txt 2024-03-04 21:19:59 +01:00
9c22628b4e implemented python test matrix for different python dependencies on different os releases 2024-03-04 21:19:59 +01:00
966b421dad README.md updated 2024-03-04 21:19:59 +01:00
7f8752a93d updated ubuntu from 22.10 (EOL) to 23.04 2024-03-04 21:19:59 +01:00
30979fd18e requirements.txt updated 2024-03-04 21:19:59 +01:00
72965cc879 added 16.1 as supported nvidia driver release 2024-03-04 21:19:59 +01:00
1887cbc534 added macOS as supported host (using python-venv) 2024-03-04 21:19:59 +01:00
2e942f4553 added Docker supported system architectures 2024-03-04 21:19:59 +01:00
3dda920a52 added linkt to driver compatibility section 2024-03-04 21:19:59 +01:00
765a994d83 requirements.txt updated 2024-03-04 21:19:59 +01:00
23488f94d4 added support for 16.0 drivers to readme 2024-03-04 21:19:59 +01:00
f9341cdab4 fixed docker image name (gitlab registry) 2024-03-04 21:19:59 +01:00
cad81ad1d6 fixed deploy docker 2024-03-04 21:19:59 +01:00
b07b7da2f3 fixed new docker registry image path 2024-03-04 21:19:59 +01:00
1ef7dd82f6 toggle api endpoints 2024-03-04 21:19:25 +01:00
5a1b1a5950 typos 2024-03-04 21:19:25 +01:00
83f4b42f01 added information about ipv6 may be must disabled 2024-03-04 21:19:25 +01:00
a3baaab26f removed mysql from included docker drivers 2024-03-04 21:19:25 +01:00
aa4ebfce73 added docker command to logging section
thanks to @libreshare (#2)
2024-03-04 21:19:25 +01:00
aa746feb13 improvements
thanks to @AbsolutelyFree (#1)
2024-03-04 21:19:25 +01:00
fce0eb6d74 fixed "deploy:pacman" 2024-03-04 21:19:25 +01:00
32806e5cca push multiarch image to docker-hub 2024-03-04 21:19:25 +01:00
50eddeecfc fixed mariadb-client installation
ref. https://github.com/PyMySQL/mysqlclient/discussions/624
2024-03-04 21:19:25 +01:00
092e6186ab added missing "pkg-config" for "mysqlclient==2.2.0"
ref. https://stackoverflow.com/questions/76533384/docker-alpine-build-fails-on-mysqlclient-installation-with-error-exception-can
2024-03-04 21:19:25 +01:00
acbe889fd9 fixed versions 2024-03-04 21:19:25 +01:00
05cad95c2a refactored docker-compose.yml so very simple example, and moved proxy to "examples" directory 2024-03-04 21:19:25 +01:00
c9e36759e3 added 15.3 to supported drivers list 2024-03-04 21:19:25 +01:00
d116eec626 updated compatibility list 2024-03-04 21:19:25 +01:00
b1620154db docker-compose.yml - added note for TZ 2024-03-04 21:19:25 +01:00
4181095791 requirements.txt updated 2024-03-04 21:19:25 +01:00
248c70a862 Merge branch 'dev' into db 2023-06-12 15:19:28 +02:00
39a2408d8d migrated api to database-config (Site, Instance) 2023-06-12 15:19:06 +02:00
18807401e4 orm improvements & fixes 2023-06-12 15:14:12 +02:00
5e47ad7729 fastapi openapi url 2023-06-12 15:13:53 +02:00
20448bc587 improved relationships 2023-06-12 15:13:29 +02:00
5e945bc43a code styling 2023-06-12 14:47:32 +02:00
b4150fa527 implemented Site and Instance orm models including initialization 2023-06-12 12:42:13 +02:00
38e1a1725c code styling 2023-06-12 12:40:10 +02:00
7 changed files with 616 additions and 122 deletions

View File

@ -297,17 +297,13 @@ gemnasium-python-dependency_scanning:
deploy:docker:
extends: .deploy
image: docker:dind
stage: deploy
tags: [ docker ]
before_script:
- echo "Building docker image for commit $CI_COMMIT_SHA with version $CI_COMMIT_REF_NAME"
- docker buildx inspect
- docker buildx create --use
script:
- echo "========== GitLab-Registry =========="
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
- IMAGE=$CI_REGISTRY/$CI_PROJECT_PATH
- IMAGE=$CI_REGISTRY/$CI_PROJECT_PATH/$CI_COMMIT_REF_NAME
- docker buildx build --progress=plain --platform $DOCKER_BUILDX_PLATFORM --build-arg VERSION=$CI_COMMIT_REF_NAME --build-arg COMMIT=$CI_COMMIT_SHA --tag $IMAGE:$CI_COMMIT_REF_NAME --push .
- docker buildx build --progress=plain --platform $DOCKER_BUILDX_PLATFORM --build-arg VERSION=$CI_COMMIT_REF_NAME --build-arg COMMIT=$CI_COMMIT_SHA --tag $IMAGE:latest --push .
- echo "========== Docker-Hub =========="

View File

@ -26,7 +26,7 @@ Only the clients need a connection to this service on configured port.
---
[[_TOC_]]
[TOC]
# Setup (Service)

View File

@ -1,55 +1,80 @@
import logging
from base64 import b64encode as b64enc
from calendar import timegm
from contextlib import asynccontextmanager
from datetime import datetime
from hashlib import sha256
from uuid import uuid4
from os.path import join, dirname
from json import loads as json_loads
from os import getenv as env
from os.path import join, dirname
from uuid import uuid4
from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.requests import Request
from json import loads as json_loads
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from calendar import timegm
from jose import jws, jwk, jwt, JWTError
from jose import jws, jwt, JWTError
from jose.constants import ALGORITHMS
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, \
RedirectResponse
from util import load_key, load_file
from orm import Origin, Lease, init as db_init, migrate
from orm import init as db_init, migrate, Site, Instance, Origin, Lease
# Load variables
load_dotenv('../version.env')
# Get current timezone
TZ = datetime.now().astimezone().tzinfo
# Load basic variables
VERSION, COMMIT, DEBUG = env('VERSION', 'unknown'), env('COMMIT', 'unknown'), bool(env('DEBUG', False))
config = dict(openapi_url=None, docs_url=None, redoc_url=None) # dict(openapi_url='/-/openapi.json', docs_url='/-/docs', redoc_url='/-/redoc')
app = FastAPI(title='FastAPI-DLS', description='Minimal Delegated License Service (DLS).', version=VERSION, **config)
# Database connection
db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite')))
db_init(db), migrate(db)
# everything prefixed with "INSTANCE_*" is used as "SERVICE_INSTANCE_*" or "SI_*" in official dls service
# Load DLS variables (all prefixed with "INSTANCE_*" is used as "SERVICE_INSTANCE_*" or "SI_*" in official dls service)
DLS_URL = str(env('DLS_URL', 'localhost'))
DLS_PORT = int(env('DLS_PORT', '443'))
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
INSTANCE_KEY_RSA = load_key(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
INSTANCE_KEY_PUB = load_key(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0)))
LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001')) # todo
# FastAPI
@asynccontextmanager
async def lifespan(_: FastAPI):
# on startup
default_instance = Instance.get_default_instance(db)
lease_renewal_period = default_instance.lease_renewal_period
lease_renewal_delta = default_instance.get_lease_renewal_delta()
client_token_expire_delta = default_instance.get_client_token_expire_delta()
logger.info(f'''
Using timezone: {str(TZ)}. Make sure this is correct and match your clients!
Your clients will renew their license every {str(Lease.calculate_renewal(lease_renewal_period, lease_renewal_delta))}.
If the renewal fails, the license is valid for {str(lease_renewal_delta)}.
Your client-token file (.tok) is valid for {str(client_token_expire_delta)}.
''')
logger.info(f'Debug is {"enabled" if DEBUG else "disabled"}.')
validate_settings()
yield
# on shutdown
logger.info(f'Shutting down ...')
config = dict(openapi_url=None, docs_url=None, redoc_url=None) # dict(openapi_url='/-/openapi.json', docs_url='/-/docs', redoc_url='/-/redoc')
app = FastAPI(title='FastAPI-DLS', description='Minimal Delegated License Service (DLS).', version=VERSION, lifespan=lifespan, **config)
app.debug = DEBUG
app.add_middleware(
@ -60,17 +85,36 @@ app.add_middleware(
allow_headers=['*'],
)
logging.basicConfig()
# Logging
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
logging.basicConfig(format='[{levelname:^7}] [{module:^15}] {message}', style='{')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)
logger.setLevel(LOG_LEVEL)
logging.getLogger('util').setLevel(LOG_LEVEL)
logging.getLogger('NV').setLevel(LOG_LEVEL)
def __get_token(request: Request) -> dict:
# Helper
def __get_token(request: Request, jwt_decode_key: "jose.jwt") -> dict:
authorization_header = request.headers.get('authorization')
token = authorization_header.split(' ')[1]
return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
def validate_settings():
session = sessionmaker(bind=db)()
lease_expire_delta_min, lease_expire_delta_max = 86_400, 7_776_000
for instance in session.query(Instance).all():
lease_expire_delta = instance.lease_expire_delta
if lease_expire_delta < 86_400 or lease_expire_delta > 7_776_000:
logging.warning(f'> [ instance ]: {instance.instance_ref}: "lease_expire_delta" should be between {lease_expire_delta_min} and {lease_expire_delta_max}')
session.close()
# Endpoints
@app.get('/', summary='Index')
async def index():
return RedirectResponse('/-/readme')
@ -88,18 +132,20 @@ async def _health():
@app.get('/-/config', summary='* Config', description='returns environment variables.')
async def _config():
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
return JSONr({
'VERSION': str(VERSION),
'COMMIT': str(COMMIT),
'DEBUG': str(DEBUG),
'DLS_URL': str(DLS_URL),
'DLS_PORT': str(DLS_PORT),
'SITE_KEY_XID': str(SITE_KEY_XID),
'INSTANCE_REF': str(INSTANCE_REF),
'SITE_KEY_XID': str(default_site.site_key),
'INSTANCE_REF': str(default_instance.instance_ref),
'ALLOTMENT_REF': [str(ALLOTMENT_REF)],
'TOKEN_EXPIRE_DELTA': str(TOKEN_EXPIRE_DELTA),
'LEASE_EXPIRE_DELTA': str(LEASE_EXPIRE_DELTA),
'LEASE_RENEWAL_PERIOD': str(LEASE_RENEWAL_PERIOD),
'TOKEN_EXPIRE_DELTA': str(default_instance.get_token_expire_delta()),
'LEASE_EXPIRE_DELTA': str(default_instance.get_lease_expire_delta()),
'LEASE_RENEWAL_PERIOD': str(default_instance.lease_renewal_period),
'CORS_ORIGINS': str(CORS_ORIGINS),
'TZ': str(TZ),
})
@ -108,7 +154,8 @@ async def _config():
@app.get('/-/readme', summary='* Readme')
async def _readme():
from markdown import markdown
content = load_file('../README.md').decode('utf-8')
from util import load_file
content = load_file(join(dirname(__file__), '../README.md')).decode('utf-8')
return HTMLr(markdown(text=content, extensions=['tables', 'fenced_code', 'md_in_html', 'nl2br', 'toc']))
@ -157,8 +204,7 @@ async def _origins(request: Request, leases: bool = False):
for origin in session.query(Origin).all():
x = origin.serialize()
if leases:
serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA)
x['leases'] = list(map(lambda _: _.serialize(**serialize), Lease.find_by_origin_ref(db, origin.origin_ref)))
x['leases'] = list(map(lambda _: _.serialize(), Lease.find_by_origin_ref(db, origin.origin_ref)))
response.append(x)
session.close()
return JSONr(response)
@ -175,8 +221,7 @@ async def _leases(request: Request, origin: bool = False):
session = sessionmaker(bind=db)()
response = []
for lease in session.query(Lease).all():
serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA)
x = lease.serialize(**serialize)
x = lease.serialize()
if origin:
lease_origin = session.query(Origin).filter(Origin.origin_ref == lease.origin_ref).first()
if lease_origin is not None:
@ -203,7 +248,13 @@ async def _lease_delete(request: Request, lease_ref: str):
@app.get('/-/client-token', summary='* Client-Token', description='creates a new messenger token for this service instance')
async def _client_token():
cur_time = datetime.utcnow()
exp_time = cur_time + CLIENT_TOKEN_EXPIRE_DELTA
default_instance = Instance.get_default_instance(db)
public_key = default_instance.get_public_key()
# todo: implemented request parameter to support different instances
jwt_encode_key = default_instance.get_jwt_encode_key()
exp_time = cur_time + default_instance.get_client_token_expire_delta()
payload = {
"jti": str(uuid4()),
@ -216,7 +267,7 @@ async def _client_token():
"scope_ref_list": [ALLOTMENT_REF],
"fulfillment_class_ref_list": [],
"service_instance_configuration": {
"nls_service_instance_ref": INSTANCE_REF,
"nls_service_instance_ref": default_instance.instance_ref,
"svc_port_set_list": [
{
"idx": 0,
@ -228,10 +279,10 @@ async def _client_token():
},
"service_instance_public_key_configuration": {
"service_instance_public_key_me": {
"mod": hex(INSTANCE_KEY_PUB.public_key().n)[2:],
"exp": int(INSTANCE_KEY_PUB.public_key().e),
"mod": hex(public_key.public_key().n)[2:],
"exp": int(public_key.public_key().e),
},
"service_instance_public_key_pem": INSTANCE_KEY_PUB.export_key().decode('utf-8'),
"service_instance_public_key_pem": public_key.export_key().decode('utf-8'),
"key_retention_mode": "LATEST_ONLY"
},
}
@ -313,13 +364,16 @@ async def auth_v1_code(request: Request):
delta = relativedelta(minutes=15)
expires = cur_time + delta
default_site = Site.get_default_site(db)
jwt_encode_key = Instance.get_default_instance(db).get_jwt_encode_key()
payload = {
'iat': timegm(cur_time.timetuple()),
'exp': timegm(expires.timetuple()),
'challenge': j.get('code_challenge'),
'origin_ref': j.get('origin_ref'),
'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID
'key_ref': default_site.site_key,
'kid': default_site.site_key,
}
auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
@ -339,8 +393,11 @@ async def auth_v1_code(request: Request):
async def auth_v1_token(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
jwt_encode_key, jwt_decode_key = default_instance.get_jwt_encode_key(), default_instance.get_jwt_decode_key()
try:
payload = jwt.decode(token=j.get('auth_code'), key=jwt_decode_key)
payload = jwt.decode(token=j.get('auth_code'), key=jwt_decode_key, algorithms=[ALGORITHMS.RS256])
except JWTError as e:
return JSONr(status_code=400, content={'status': 400, 'title': 'invalid token', 'detail': str(e)})
@ -352,7 +409,7 @@ async def auth_v1_token(request: Request):
if payload.get('challenge') != challenge:
return JSONr(status_code=401, content={'status': 401, 'detail': 'expected challenge did not match verifier'})
access_expires_on = cur_time + TOKEN_EXPIRE_DELTA
access_expires_on = cur_time + default_instance.get_token_expire_delta()
new_payload = {
'iat': timegm(cur_time.timetuple()),
@ -361,8 +418,8 @@ async def auth_v1_token(request: Request):
'aud': 'https://cls.nvidia.org',
'exp': timegm(access_expires_on.timetuple()),
'origin_ref': origin_ref,
'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID,
'key_ref': default_site.site_key,
'kid': default_site.site_key,
}
auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
@ -379,10 +436,13 @@ async def auth_v1_token(request: Request):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.post('/leasing/v1/lessor', description='request multiple leases (borrow) for current origin')
async def leasing_v1_lessor(request: Request):
j, token, cur_time = json_loads((await request.body()).decode('utf-8')), __get_token(request), datetime.utcnow()
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()
default_instance = Instance.get_default_instance(db)
jwt_decode_key = default_instance.get_jwt_decode_key()
try:
token = __get_token(request)
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
@ -396,7 +456,7 @@ async def leasing_v1_lessor(request: Request):
# return JSONr(status_code=500, detail=f'no service instances found for scopes: ["{scope_ref}"]')
lease_ref = str(uuid4())
expires = cur_time + LEASE_EXPIRE_DELTA
expires = cur_time + default_instance.get_lease_expire_delta()
lease_result_list.append({
"ordinal": 0,
# https://docs.nvidia.com/license-system/latest/nvidia-license-system-user-guide/index.html
@ -404,13 +464,13 @@ async def leasing_v1_lessor(request: Request):
"ref": lease_ref,
"created": cur_time.isoformat(),
"expires": expires.isoformat(),
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
"recommended_lease_renewal": default_instance.lease_renewal_period,
"offline_lease": "true",
"license_type": "CONCURRENT_COUNTED_SINGLE"
}
})
data = Lease(origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
data = Lease(instance_ref=default_instance.instance_ref, origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
Lease.create_or_update(db, data)
response = {
@ -427,7 +487,14 @@ async def leasing_v1_lessor(request: Request):
# venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
@app.get('/leasing/v1/lessor/leases', description='get active leases for current origin')
async def leasing_v1_lessor_lease(request: Request):
token, cur_time = __get_token(request), datetime.utcnow()
cur_time = datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref')
@ -447,7 +514,15 @@ async def leasing_v1_lessor_lease(request: Request):
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
@app.put('/leasing/v1/lease/{lease_ref}', description='renew a lease')
async def leasing_v1_lease_renew(request: Request, lease_ref: str):
token, cur_time = __get_token(request), datetime.utcnow()
cur_time = datetime.utcnow()
default_instance = Instance.get_default_instance(db)
jwt_decode_key = default_instance.get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref')
logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}')
@ -456,11 +531,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
if entity is None:
return JSONr(status_code=404, content={'status': 404, 'detail': 'requested lease not available'})
expires = cur_time + LEASE_EXPIRE_DELTA
expires = cur_time + default_instance.get_lease_expire_delta()
response = {
"lease_ref": lease_ref,
"expires": expires.isoformat(),
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
"recommended_lease_renewal": default_instance.lease_renewal_period,
"offline_lease": True,
"prompts": None,
"sync_timestamp": cur_time.isoformat(),
@ -474,7 +549,14 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_single_controller.py
@app.delete('/leasing/v1/lease/{lease_ref}', description='release (return) a lease')
async def leasing_v1_lease_delete(request: Request, lease_ref: str):
token, cur_time = __get_token(request), datetime.utcnow()
cur_time = datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref')
logging.info(f'> [ return ]: {origin_ref}: return {lease_ref}')
@ -500,7 +582,14 @@ async def leasing_v1_lease_delete(request: Request, lease_ref: str):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.delete('/leasing/v1/lessor/leases', description='release all leases')
async def leasing_v1_lessor_lease_remove(request: Request):
token, cur_time = __get_token(request), datetime.utcnow()
cur_time = datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
origin_ref = token.get('origin_ref')
@ -522,6 +611,8 @@ async def leasing_v1_lessor_lease_remove(request: Request):
async def leasing_v1_lessor_shutdown(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.utcnow()
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
token = j.get('token')
token = jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
origin_ref = token.get('origin_ref')
@ -540,18 +631,6 @@ async def leasing_v1_lessor_shutdown(request: Request):
return JSONr(response)
@app.on_event('startup')
async def app_on_startup():
logger.info(f'''
Using timezone: {str(TZ)}. Make sure this is correct and match your clients!
Your clients renew their license every {str(Lease.calculate_renewal(LEASE_RENEWAL_PERIOD, LEASE_RENEWAL_DELTA))}.
If the renewal fails, the license is {str(LEASE_RENEWAL_DELTA)} valid.
Your client-token file (.tok) is valid for {str(CLIENT_TOKEN_EXPIRE_DELTA)}.
''')
if __name__ == '__main__':
import uvicorn

View File

@ -1,18 +1,144 @@
import logging
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text
from dateutil.relativedelta import relativedelta
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text, BLOB, INT, FLOAT
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.orm import sessionmaker, declarative_base, Session, relationship
from util import NV
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Base = declarative_base()
class Site(Base):
__tablename__ = "site"
INITIAL_SITE_KEY_XID = '00000000-0000-0000-0000-000000000000'
INITIAL_SITE_NAME = 'default'
site_key = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4, SITE_KEY_XID
name = Column(VARCHAR(length=256), nullable=False)
def __str__(self):
return f'SITE_KEY_XID: {self.site_key}'
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Site.__table__).compile(engine)
@staticmethod
def get_default_site(engine: Engine) -> "Site":
session = sessionmaker(bind=engine)()
entity = session.query(Site).filter(Site.site_key == Site.INITIAL_SITE_KEY_XID).first()
session.close()
return entity
class Instance(Base):
__tablename__ = "instance"
DEFAULT_INSTANCE_REF = '10000000-0000-0000-0000-000000000001'
DEFAULT_TOKEN_EXPIRE_DELTA = 86_400 # 1 day
DEFAULT_LEASE_EXPIRE_DELTA = 7_776_000 # 90 days
DEFAULT_LEASE_RENEWAL_PERIOD = 0.15
DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA = 378_432_000 # 12 years
# 1 day = 86400 (min. in production setup, max 90 days), 1 hour = 3600
instance_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4, INSTANCE_REF
site_key = Column(CHAR(length=36), ForeignKey(Site.site_key, ondelete='CASCADE'), nullable=False, index=True) # uuid4
private_key = Column(BLOB(length=2048), nullable=False)
public_key = Column(BLOB(length=512), nullable=False)
token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_TOKEN_EXPIRE_DELTA, comment='in seconds')
lease_expire_delta = Column(INT(), nullable=False, default=DEFAULT_LEASE_EXPIRE_DELTA, comment='in seconds')
lease_renewal_period = Column(FLOAT(precision=2), nullable=False, default=DEFAULT_LEASE_RENEWAL_PERIOD)
client_token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA, comment='in seconds')
__origin = relationship(Site, foreign_keys=[site_key])
def __str__(self):
return f'INSTANCE_REF: {self.instance_ref} (SITE_KEY_XID: {self.site_key})'
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Instance.__table__).compile(engine)
@staticmethod
def create_or_update(engine: Engine, instance: "Instance"):
session = sessionmaker(bind=engine)()
entity = session.query(Instance).filter(Instance.instance_ref == instance.instance_ref).first()
if entity is None:
session.add(instance)
else:
x = dict(
site_key=instance.site_key,
private_key=instance.private_key,
public_key=instance.public_key,
token_expire_delta=instance.token_expire_delta,
lease_expire_delta=instance.lease_expire_delta,
lease_renewal_period=instance.lease_renewal_period,
client_token_expire_delta=instance.client_token_expire_delta,
)
session.execute(update(Instance).where(Instance.instance_ref == instance.instance_ref).values(**x))
session.commit()
session.flush()
session.close()
# todo: validate on startup that "lease_expire_delta" is between 1 day and 90 days
@staticmethod
def get_default_instance(engine: Engine) -> "Instance":
session = sessionmaker(bind=engine)()
site = Site.get_default_site(engine)
entity = session.query(Instance).filter(Instance.site_key == site.site_key).first()
session.close()
return entity
def get_token_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
return relativedelta(seconds=self.token_expire_delta)
def get_lease_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
return relativedelta(seconds=self.lease_expire_delta)
def get_lease_renewal_delta(self) -> "datetime.timedelta":
return timedelta(seconds=self.lease_expire_delta)
def get_client_token_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
return relativedelta(seconds=self.client_token_expire_delta)
def __get_private_key(self) -> "RsaKey":
return parse_key(self.private_key)
def get_public_key(self) -> "RsaKey":
return parse_key(self.public_key)
def get_jwt_encode_key(self) -> "jose.jkw":
from jose import jwk
from jose.constants import ALGORITHMS
return jwk.construct(self.__get_private_key().export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
def get_jwt_decode_key(self) -> "jose.jwt":
from jose import jwk
from jose.constants import ALGORITHMS
return jwk.construct(self.get_public_key().export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
def get_private_key_str(self, encoding: str = 'utf-8') -> str:
return self.private_key.decode(encoding)
def get_public_key_str(self, encoding: str = 'utf-8') -> str:
return self.private_key.decode(encoding)
class Origin(Base):
__tablename__ = "origin"
origin_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4
# service_instance_xid = Column(CHAR(length=36), nullable=False, index=True) # uuid4 # not necessary, we only support one service_instance_xid ('INSTANCE_REF')
hostname = Column(VARCHAR(length=256), nullable=True)
guest_driver_version = Column(VARCHAR(length=10), nullable=True)
@ -23,6 +149,8 @@ class Origin(Base):
return f'Origin(origin_ref={self.origin_ref}, hostname={self.hostname})'
def serialize(self) -> dict:
_ = NV().find(self.guest_driver_version)
return {
'origin_ref': self.origin_ref,
# 'service_instance_xid': self.service_instance_xid,
@ -30,6 +158,7 @@ class Origin(Base):
'guest_driver_version': self.guest_driver_version,
'os_platform': self.os_platform,
'os_version': self.os_version,
'$driver': _ if _ is not None else None,
}
@staticmethod
@ -70,18 +199,24 @@ class Origin(Base):
class Lease(Base):
__tablename__ = "lease"
instance_ref = Column(CHAR(length=36), ForeignKey(Instance.instance_ref, ondelete='CASCADE'), nullable=False, index=True) # uuid4
lease_ref = Column(CHAR(length=36), primary_key=True, nullable=False, index=True) # uuid4
origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref, ondelete='CASCADE'), nullable=False, index=True) # uuid4
# scope_ref = Column(CHAR(length=36), nullable=False, index=True) # uuid4 # not necessary, we only support one scope_ref ('ALLOTMENT_REF')
lease_created = Column(DATETIME(), nullable=False)
lease_expires = Column(DATETIME(), nullable=False)
lease_updated = Column(DATETIME(), nullable=False)
__instance = relationship(Instance, foreign_keys=[instance_ref])
__origin = relationship(Origin, foreign_keys=[origin_ref])
def __repr__(self):
return f'Lease(origin_ref={self.origin_ref}, lease_ref={self.lease_ref}, expires={self.lease_expires})'
def serialize(self, renewal_period: float, renewal_delta: timedelta) -> dict:
def serialize(self) -> dict:
renewal_period = self.__instance.lease_renewal_period
renewal_delta = self.__instance.get_lease_renewal_delta
lease_renewal = int(Lease.calculate_renewal(renewal_period, renewal_delta).total_seconds())
lease_renewal = self.lease_updated + relativedelta(seconds=lease_renewal)
@ -191,38 +326,110 @@ class Lease(Base):
return renew
def init_default_site(session: Session):
from app.util import generate_key
private_key = generate_key()
public_key = private_key.public_key()
site = Site(
site_key=Site.INITIAL_SITE_KEY_XID,
name=Site.INITIAL_SITE_NAME
)
session.add(site)
session.commit()
instance = Instance(
instance_ref=Instance.DEFAULT_INSTANCE_REF,
site_key=site.site_key,
private_key=private_key.export_key(),
public_key=public_key.export_key(),
)
session.add(instance)
session.commit()
def init(engine: Engine):
tables = [Origin, Lease]
tables = [Site, Instance, Origin, Lease]
db = inspect(engine)
session = sessionmaker(bind=engine)()
for table in tables:
if not db.dialect.has_table(engine.connect(), table.__tablename__):
exists = db.dialect.has_table(engine.connect(), table.__tablename__)
logger.info(f'> Table "{table.__tablename__:<16}" exists: {exists}')
if not exists:
session.execute(text(str(table.create_statement(engine))))
session.commit()
# create default site
cnt = session.query(Site).count()
if cnt == 0:
init_default_site(session)
session.flush()
session.close()
def migrate(engine: Engine):
from os import getenv as env
from os.path import join, dirname, isfile
from util import load_key
db = inspect(engine)
def upgrade_1_0_to_1_1():
x = db.dialect.get_columns(engine.connect(), Lease.__tablename__)
x = next(_ for _ in x if _['name'] == 'origin_ref')
if x['primary_key'] > 0:
print('Found old database schema with "origin_ref" as primary-key in "lease" table. Dropping table!')
print(' Your leases are recreated on next renewal!')
print(' If an error message appears on the client, you can ignore it.')
Lease.__table__.drop(bind=engine)
init(engine)
# todo: add update guide to use 1.LATEST to 2.0
def upgrade_1_x_to_2_0():
site = Site.get_default_site(engine)
logger.info(site)
instance = Instance.get_default_instance(engine)
logger.info(instance)
# def upgrade_1_2_to_1_3():
# x = db.dialect.get_columns(engine.connect(), Lease.__tablename__)
# x = next((_ for _ in x if _['name'] == 'scope_ref'), None)
# if x is None:
# Lease.scope_ref.compile()
# column_name = Lease.scope_ref.name
# column_type = Lease.scope_ref.type.compile(engine.dialect)
# engine.execute(f'ALTER TABLE "{Lease.__tablename__}" ADD COLUMN "{column_name}" {column_type}')
# SITE_KEY_XID
if site_key := env('SITE_KEY_XID', None) is not None:
site.site_key = str(site_key)
upgrade_1_0_to_1_1()
# upgrade_1_2_to_1_3()
# INSTANCE_REF
if instance_ref := env('INSTANCE_REF', None) is not None:
instance.instance_ref = str(instance_ref)
# ALLOTMENT_REF
if allotment_ref := env('ALLOTMENT_REF', None) is not None:
pass # todo
# INSTANCE_KEY_RSA, INSTANCE_KEY_PUB
default_instance_private_key_path = str(join(dirname(__file__), 'cert/instance.private.pem'))
instance_private_key = env('INSTANCE_KEY_RSA', None)
if instance_private_key is not None:
instance.private_key = load_key(str(instance_private_key))
elif isfile(default_instance_private_key_path):
instance.private_key = load_key(default_instance_private_key_path)
default_instance_public_key_path = str(join(dirname(__file__), 'cert/instance.public.pem'))
instance_public_key = env('INSTANCE_KEY_PUB', None)
if instance_public_key is not None:
instance.public_key = load_key(str(instance_public_key))
elif isfile(default_instance_public_key_path):
instance.public_key = load_key(default_instance_public_key_path)
# TOKEN_EXPIRE_DELTA
token_expire_delta = env('TOKEN_EXPIRE_DAYS', None)
if token_expire_delta not in (None, 0):
instance.token_expire_delta = token_expire_delta * 86_400
token_expire_delta = env('TOKEN_EXPIRE_HOURS', None)
if token_expire_delta not in (None, 0):
instance.token_expire_delta = token_expire_delta * 3_600
# LEASE_EXPIRE_DELTA, LEASE_RENEWAL_DELTA
lease_expire_delta = env('LEASE_EXPIRE_DAYS', None)
if lease_expire_delta not in (None, 0):
instance.lease_expire_delta = lease_expire_delta * 86_400
lease_expire_delta = env('LEASE_EXPIRE_HOURS', None)
if lease_expire_delta not in (None, 0):
instance.lease_expire_delta = lease_expire_delta * 3_600
# LEASE_RENEWAL_PERIOD
lease_renewal_period = env('LEASE_RENEWAL_PERIOD', None)
if lease_renewal_period is not None:
instance.lease_renewal_period = lease_renewal_period
# todo: update site, instance
upgrade_1_x_to_2_0()

View File

@ -1,10 +1,17 @@
def load_file(filename) -> bytes:
import logging
logging.basicConfig()
def load_file(filename: str) -> bytes:
log = logging.getLogger(f'{__name__}')
log.debug(f'Loading contents of file "{filename}')
with open(filename, 'rb') as file:
content = file.read()
return content
def load_key(filename) -> "RsaKey":
def load_key(filename: str) -> "RsaKey":
try:
# Crypto | Cryptodome on Debian
from Crypto.PublicKey import RSA
@ -13,9 +20,23 @@ def load_key(filename) -> "RsaKey":
from Cryptodome.PublicKey import RSA
from Cryptodome.PublicKey.RSA import RsaKey
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
return RSA.import_key(extern_key=load_file(filename), passphrase=None)
def parse_key(content: bytes) -> "RsaKey":
try:
# Crypto | Cryptodome on Debian
from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey
except ModuleNotFoundError:
from Cryptodome.PublicKey import RSA
from Cryptodome.PublicKey.RSA import RsaKey
return RSA.import_key(extern_key=content, passphrase=None)
def generate_key() -> "RsaKey":
try:
# Crypto | Cryptodome on Debian
@ -24,5 +45,50 @@ def generate_key() -> "RsaKey":
except ModuleNotFoundError:
from Cryptodome.PublicKey import RSA
from Cryptodome.PublicKey.RSA import RsaKey
log = logging.getLogger(__name__)
log.debug(f'Generating RSA-Key')
return RSA.generate(bits=2048)
class NV:
__DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json'
__DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions"
def __init__(self):
self.log = logging.getLogger(self.__class__.__name__)
if NV.__DRIVER_MATRIX is None:
from json import load as json_load
try:
file = open(NV.__DRIVER_MATRIX_FILENAME)
NV.__DRIVER_MATRIX = json_load(file)
file.close()
self.log.debug(f'Successfully loaded "{NV.__DRIVER_MATRIX_FILENAME}".')
except Exception as e:
NV.__DRIVER_MATRIX = {} # init empty dict to not try open file everytime, just when restarting app
# self.log.warning(f'Failed to load "{NV.__DRIVER_MATRIX_FILENAME}": {e}')
@staticmethod
def find(version: str) -> dict | None:
if NV.__DRIVER_MATRIX is None:
return None
for idx, (key, branch) in enumerate(NV.__DRIVER_MATRIX.items()):
for release in branch.get('$releases'):
linux_driver = release.get('Linux Driver')
windows_driver = release.get('Windows Driver')
if version == linux_driver or version == windows_driver:
tmp = branch.copy()
tmp.pop('$releases')
is_latest = release.get('vGPU Software') == branch.get('Latest Release in Branch')
return {
'software_branch': branch.get('vGPU Software Branch'),
'branch_version': release.get('vGPU Software'),
'driver_branch': branch.get('Driver Branch'),
'branch_status': branch.get('vGPU Branch Status'),
'release_date': release.get('Release Date'),
'eol': branch.get('EOL Date') if is_latest else None,
'is_latest': is_latest,
}
return None

View File

@ -0,0 +1,137 @@
import logging
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
URL = 'https://docs.nvidia.com/grid/'
BRANCH_STATUS_KEY, SOFTWARE_BRANCH_KEY, = 'vGPU Branch Status', 'vGPU Software Branch'
VGPU_KEY, GRID_KEY, DRIVER_BRANCH_KEY = 'vGPU Software', 'vGPU Software', 'Driver Branch'
LINUX_VGPU_MANAGER_KEY, LINUX_DRIVER_KEY = 'Linux vGPU Manager', 'Linux Driver'
WINDOWS_VGPU_MANAGER_KEY, WINDOWS_DRIVER_KEY = 'Windows vGPU Manager', 'Windows Driver'
ALT_VGPU_MANAGER_KEY = 'vGPU Manager'
RELEASE_DATE_KEY, LATEST_KEY, EOL_KEY = 'Release Date', 'Latest Release in Branch', 'EOL Date'
JSON_RELEASES_KEY = '$releases'
def __driver_versions(html: 'BeautifulSoup'):
def __strip(_: str) -> str:
# removes content after linebreak (e.g. "Hello\n World" to "Hello")
_ = _.strip()
tmp = _.split('\n')
if len(tmp) > 0:
return tmp[0]
return _
# find wrapper for "DriverVersions" and find tables
data = html.find('div', {'id': 'DriverVersions'})
tables = data.findAll('table')
for table in tables:
# parse software-branch (e.g. "vGPU software 17 Releases" and remove " Releases" for "matrix_key")
software_branch = table.parent.find_previous_sibling('button', {'class': 'accordion'}).text.strip()
software_branch = software_branch.replace(' Releases', '')
matrix_key = software_branch.lower()
# driver version info from table-heads (ths) and table-rows (trs)
ths, trs = table.find_all('th'), table.find_all('tr')
headers, releases = [header.text.strip() for header in ths], []
for trs in trs:
tds = trs.find_all('td')
if len(tds) == 0: # skip empty
continue
# create dict with table-heads as key and cell content as value
x = {headers[i]: __strip(cell.text) for i, cell in enumerate(tds)}
releases.append(x)
# add to matrix
MATRIX.update({matrix_key: {JSON_RELEASES_KEY: releases}})
def __release_branches(html: 'BeautifulSoup'):
# find wrapper for "AllReleaseBranches" and find table
data = html.find('div', {'id': 'AllReleaseBranches'})
table = data.find('table')
# branch releases info from table-heads (ths) and table-rows (trs)
ths, trs = table.find_all('th'), table.find_all('tr')
headers = [header.text.strip() for header in ths]
for trs in trs:
tds = trs.find_all('td')
if len(tds) == 0: # skip empty
continue
# create dict with table-heads as key and cell content as value
x = {headers[i]: cell.text.strip() for i, cell in enumerate(tds)}
# get matrix_key
software_branch = x.get(SOFTWARE_BRANCH_KEY)
matrix_key = software_branch.lower()
# add to matrix
MATRIX.update({matrix_key: MATRIX.get(matrix_key) | x})
def __debug():
# print table head
s = f'{SOFTWARE_BRANCH_KEY:^21} | {BRANCH_STATUS_KEY:^21} | {VGPU_KEY:^13} | {LINUX_VGPU_MANAGER_KEY:^21} | {LINUX_DRIVER_KEY:^21} | {WINDOWS_VGPU_MANAGER_KEY:^21} | {WINDOWS_DRIVER_KEY:^21} | {RELEASE_DATE_KEY:>21} | {EOL_KEY:>21}'
print(s)
# iterate over dict & format some variables to not overload table
for idx, (key, branch) in enumerate(MATRIX.items()):
branch_status = branch.get(BRANCH_STATUS_KEY)
branch_status = branch_status.replace('Branch ', '')
branch_status = branch_status.replace('Long-Term Support', 'LTS')
branch_status = branch_status.replace('Production', 'Prod.')
software_branch = branch.get(SOFTWARE_BRANCH_KEY).replace('NVIDIA ', '')
for release in branch.get(JSON_RELEASES_KEY):
version = release.get(VGPU_KEY, release.get(GRID_KEY, ''))
linux_manager = release.get(LINUX_VGPU_MANAGER_KEY, release.get(ALT_VGPU_MANAGER_KEY, ''))
linux_driver = release.get(LINUX_DRIVER_KEY)
windows_manager = release.get(WINDOWS_VGPU_MANAGER_KEY, release.get(ALT_VGPU_MANAGER_KEY, ''))
windows_driver = release.get(WINDOWS_DRIVER_KEY)
release_date = release.get(RELEASE_DATE_KEY)
is_latest = release.get(VGPU_KEY) == branch.get(LATEST_KEY)
version = f'{version} *' if is_latest else version
eol = branch.get(EOL_KEY) if is_latest else ''
s = f'{software_branch:^21} | {branch_status:^21} | {version:<13} | {linux_manager:<21} | {linux_driver:<21} | {windows_manager:<21} | {windows_driver:<21} | {release_date:>21} | {eol:>21}'
print(s)
def __dump(filename: str):
import json
file = open(filename, 'w')
json.dump(MATRIX, file)
file.close()
if __name__ == '__main__':
MATRIX = {}
try:
import httpx
from bs4 import BeautifulSoup
except Exception as e:
logger.error(f'Failed to import module: {e}')
logger.info('Run "pip install beautifulsoup4 httpx"')
exit(1)
r = httpx.get(URL)
if r.status_code != 200:
logger.error(f'Error loading "{URL}" with status code {r.status_code}.')
exit(2)
# parse html
soup = BeautifulSoup(r.text, features='html.parser')
# build matrix
__driver_versions(soup)
__release_branches(soup)
# debug output
__debug()
# dump data to file
__dump('../app/static/driver_matrix.json')

View File

@ -1,14 +1,15 @@
from os import getenv as env
from base64 import b64encode as b64enc
from hashlib import sha256
from calendar import timegm
from datetime import datetime
from os.path import dirname, join
from uuid import uuid4, UUID
from uuid import UUID, uuid4
from dateutil.relativedelta import relativedelta
from jose import jwt, jwk
from jose import jwt
from jose.constants import ALGORITHMS
from starlette.testclient import TestClient
from sqlalchemy import create_engine
import sys
# add relative path to use packages as they were in the app/ dir
@ -16,20 +17,23 @@ sys.path.append('../')
sys.path.append('../app')
from app import main
from app.util import load_key
from app.orm import init as db_init, migrate, Site, Instance
client = TestClient(main.app)
ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-000000000001', 'HelloWorld'
# INSTANCE_KEY_RSA = generate_key()
# INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key()
# fastapi setup
client = TestClient(main.app)
INSTANCE_KEY_RSA = load_key(str(join(dirname(__file__), '../app/cert/instance.private.pem')))
INSTANCE_KEY_PUB = load_key(str(join(dirname(__file__), '../app/cert/instance.public.pem')))
# database setup
db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite')))
db_init(db), migrate(db)
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.export_key().decode('utf-8'), algorithm=ALGORITHMS.RS256)
# test vars
DEFAULT_SITE, DEFAULT_INSTANCE = Site.get_default_site(db), Instance.get_default_instance(db)
SITE_KEY = DEFAULT_SITE.site_key
jwt_encode_key, jwt_decode_key = DEFAULT_INSTANCE.get_jwt_encode_key(), DEFAULT_INSTANCE.get_jwt_decode_key()
def __bearer_token(origin_ref: str) -> str:
@ -38,6 +42,12 @@ def __bearer_token(origin_ref: str) -> str:
return token
def test_initial_default_site_and_instance():
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
assert default_site.site_key == Site.INITIAL_SITE_KEY_XID
assert default_instance.instance_ref == Instance.DEFAULT_INSTANCE_REF
def test_index():
response = client.get('/')
assert response.status_code == 200
@ -153,8 +163,7 @@ def test_auth_v1_token():
"kid": "00000000-0000-0000-0000-000000000000"
}
payload = {
"auth_code": jwt.encode(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')},
algorithm=ALGORITHMS.RS256),
"auth_code": jwt.encode(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256),
"code_verifier": SECRET,
}