created PrivateKey / PublicKey wrapper classes

This commit is contained in:
Oscar Krause 2025-03-18 09:43:44 +01:00
parent 958f23f79d
commit fd46eecfb3
3 changed files with 65 additions and 60 deletions

View File

@ -21,7 +21,7 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
from orm import Origin, Lease, init as db_init, migrate from orm import Origin, Lease, init as db_init, migrate
from util import load_private_key, load_public_key, get_pem, load_file from util import PrivateKey, PublicKey, load_file
# Load variables # Load variables
load_dotenv('../version.env') load_dotenv('../version.env')
@ -42,8 +42,8 @@ DLS_PORT = int(env('DLS_PORT', '443'))
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000')) SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001')) INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001')) ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
INSTANCE_KEY_RSA = load_private_key(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem')))) INSTANCE_KEY_RSA = PrivateKey(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
INSTANCE_KEY_PUB = load_public_key(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem')))) INSTANCE_KEY_PUB = PublicKey(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0))) TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0)))
LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0))) LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15)) LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
@ -51,8 +51,8 @@ LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=in
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12) CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}'] CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256) jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256) jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
# Logging # Logging
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
@ -264,10 +264,10 @@ async def _client_token():
}, },
"service_instance_public_key_configuration": { "service_instance_public_key_configuration": {
"service_instance_public_key_me": { "service_instance_public_key_me": {
"mod": hex(INSTANCE_KEY_PUB.public_numbers().n)[2:], "mod": hex(INSTANCE_KEY_PUB.raw().public_numbers().n)[2:],
"exp": int(INSTANCE_KEY_PUB.public_numbers().e), "exp": int(INSTANCE_KEY_PUB.raw().public_numbers().e),
}, },
"service_instance_public_key_pem": get_pem(INSTANCE_KEY_PUB).decode('utf-8'), "service_instance_public_key_pem": INSTANCE_KEY_PUB.pem().decode('utf-8'),
"key_retention_mode": "LATEST_ONLY" "key_retention_mode": "LATEST_ONLY"
}, },
} }

View File

@ -1,8 +1,60 @@
import logging import logging
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey, generate_private_key
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
logging.basicConfig() logging.basicConfig()
class PrivateKey:
def __init__(self, filename: str):
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
self.key = load_pem_private_key(data.strip(), password=None)
def raw(self) -> RSAPrivateKey:
return self.key
def pem(self) -> bytes:
return self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
@staticmethod
def generate(public_exponent: int = 65537, key_size: int = 2048) -> RSAPrivateKey:
log = logging.getLogger(__name__)
log.debug(f'Generating RSA-Key')
return generate_private_key(public_exponent=public_exponent, key_size=key_size)
class PublicKey:
def __init__(self, filename: str):
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
self.key = load_pem_public_key(data.strip())
def raw(self) -> RSAPublicKey:
return self.key
def pem(self) -> bytes:
return self.key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def load_file(filename: str) -> bytes: def load_file(filename: str) -> bytes:
log = logging.getLogger(f'{__name__}') log = logging.getLogger(f'{__name__}')
log.debug(f'Loading contents of file "{filename}') log.debug(f'Loading contents of file "{filename}')
@ -11,53 +63,6 @@ def load_file(filename: str) -> bytes:
return content return content
def load_private_key(filename: str) -> "RSAPrivateKey":
from cryptography.hazmat.primitives.serialization import load_pem_private_key
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
return load_pem_private_key(data.strip(), password=None)
def load_public_key(filename: str) -> "RSAPublicKey":
from cryptography.hazmat.primitives.serialization import load_pem_public_key
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
return load_pem_public_key(data.strip())
def get_pem(key) -> bytes | None:
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives import serialization
if isinstance(key, RSAPrivateKey):
return key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
if isinstance(key, RSAPublicKey):
return key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def generate_private_key() -> "RSAPrivateKey":
from cryptography.hazmat.primitives.asymmetric import rsa
log = logging.getLogger(__name__)
log.debug(f'Generating RSA-Key')
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
class NV: class NV:
__DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json' __DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json'
__DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions" __DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions"

View File

@ -16,7 +16,7 @@ sys.path.append('../')
sys.path.append('../app') sys.path.append('../app')
from app import main from app import main
from util import load_private_key, load_public_key, get_pem from util import PrivateKey, PublicKey
client = TestClient(main.app) client = TestClient(main.app)
@ -25,11 +25,11 @@ ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-00000
# INSTANCE_KEY_RSA = generate_key() # INSTANCE_KEY_RSA = generate_key()
# INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key() # INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key()
INSTANCE_KEY_RSA = load_private_key(str(join(dirname(__file__), '../app/cert/instance.private.pem'))) INSTANCE_KEY_RSA = PrivateKey(str(join(dirname(__file__), '../app/cert/instance.private.pem')))
INSTANCE_KEY_PUB = load_public_key(str(join(dirname(__file__), '../app/cert/instance.public.pem'))) INSTANCE_KEY_PUB = PublicKey(str(join(dirname(__file__), '../app/cert/instance.public.pem')))
jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256) jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256) jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
def __bearer_token(origin_ref: str) -> str: def __bearer_token(origin_ref: str) -> str: