Compare commits

..

No commits in common. "e7102c4de655756102d7b76d84de9aebc973a9dc" and "394180652ea8c5e7215307fd78a1525d6ca19319" have entirely different histories.

6 changed files with 109 additions and 107 deletions

1
.gitignore vendored
View File

@ -3,4 +3,3 @@ venv/
.idea/ .idea/
app/*.sqlite* app/*.sqlite*
app/cert/*.* app/cert/*.*
.pytest_cache

View File

@ -16,17 +16,9 @@ build:
- docker push ${CI_REGISTRY}/${CI_PROJECT_PATH}/${CI_BUILD_REF_NAME}:${CI_BUILD_REF} - docker push ${CI_REGISTRY}/${CI_PROJECT_PATH}/${CI_BUILD_REF_NAME}:${CI_BUILD_REF}
test: test:
image: python:3.10-slim-bullseye
stage: test stage: test
before_script:
- pip install -r requirements.txt
- pip install pytest httpx
- mkdir -p app/cert
- openssl genrsa -out app/cert/instance.private.pem 2048
- openssl rsa -in app/cert/instance.private.pem -outform PEM -pubout -out app/cert/instance.public.pem
- cd test
script: script:
- pytest main.py - echo "Nothing to do ..."
deploy: deploy:
stage: deploy stage: deploy

View File

@ -22,7 +22,7 @@ from sqlalchemy.orm import sessionmaker
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey from Crypto.PublicKey.RSA import RsaKey
from orm import Origin, Lease from orm import Origin, Lease, Auth
logger = logging.getLogger() logger = logging.getLogger()
load_dotenv('../version.env') load_dotenv('../version.env')
@ -207,13 +207,17 @@ async def auth_v1_code(request: Request):
'iat': timegm(cur_time.timetuple()), 'iat': timegm(cur_time.timetuple()),
'exp': timegm(expires.timetuple()), 'exp': timegm(expires.timetuple()),
'challenge': j['code_challenge'], 'challenge': j['code_challenge'],
'origin_ref': j['origin_ref'], 'origin_ref': j['code_challenge'],
'key_ref': SITE_KEY_XID, 'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID 'kid': SITE_KEY_XID
} }
auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256) auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
Auth.cleanup(db, origin_ref, cur_time - delta)
data = Auth(origin_ref=origin_ref, code_challenge=j['code_challenge'], expires=expires)
Auth.create(db, data)
response = { response = {
"auth_code": auth_code, "auth_code": auth_code,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
@ -231,8 +235,13 @@ async def auth_v1_token(request: Request):
j = json.loads((await request.body()).decode('utf-8')) j = json.loads((await request.body()).decode('utf-8'))
payload = jwt.decode(token=j['auth_code'], key=jwt_decode_key) payload = jwt.decode(token=j['auth_code'], key=jwt_decode_key)
origin_ref = payload['origin_ref'] code_challenge = payload['origin_ref']
logging.info(f'> [ auth ]: {origin_ref}: {j}')
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
logging.info(f'> [ auth ]: {origin_ref} ({code_challenge}): {j}')
# validate the code challenge # validate the code challenge
if payload['challenge'] != b64enc(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'): if payload['challenge'] != b64enc(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'):
@ -247,7 +256,7 @@ async def auth_v1_token(request: Request):
'iss': 'https://cls.nvidia.org', 'iss': 'https://cls.nvidia.org',
'aud': 'https://cls.nvidia.org', 'aud': 'https://cls.nvidia.org',
'exp': timegm(access_expires_on.timetuple()), 'exp': timegm(access_expires_on.timetuple()),
'origin_ref': origin_ref, 'origin_ref': payload['origin_ref'],
'key_ref': SITE_KEY_XID, 'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID, 'kid': SITE_KEY_XID,
} }
@ -268,9 +277,14 @@ async def auth_v1_token(request: Request):
async def leasing_v1_lessor(request: Request): async def leasing_v1_lessor(request: Request):
j, token = json.loads((await request.body()).decode('utf-8')), get_token(request) j, token = json.loads((await request.body()).decode('utf-8')), get_token(request)
origin_ref = token['origin_ref'] code_challenge = token['origin_ref']
scope_ref_list = j['scope_ref_list'] scope_ref_list = j['scope_ref_list']
logging.info(f'> [ create ]: {origin_ref}: create leases for scope_ref_list {scope_ref_list}')
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
logging.info(f'> [ create ]: {origin_ref} ({code_challenge}): create leases for scope_ref_list {scope_ref_list}')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
lease_result_list = [] lease_result_list = []
@ -309,10 +323,14 @@ async def leasing_v1_lessor(request: Request):
async def leasing_v1_lessor_lease(request: Request): async def leasing_v1_lessor_lease(request: Request):
token = get_token(request) token = get_token(request)
origin_ref = token['origin_ref'] code_challenge = token['origin_ref']
active_lease_list = list(map(lambda x: x['lease_ref'], db['lease'].find(origin_ref=origin_ref))) entity = Auth.find_by_code_challenge(db, code_challenge)
logging.info(f'> [ leases ]: {origin_ref}: found {len(active_lease_list)} active leases') if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
active_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref)))
logging.info(f'> [ leases ]: {origin_ref} ({code_challenge}): found {len(active_lease_list)} active leases')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {
@ -329,8 +347,13 @@ async def leasing_v1_lessor_lease(request: Request):
async def leasing_v1_lease_renew(request: Request, lease_ref: str): async def leasing_v1_lease_renew(request: Request, lease_ref: str):
token = get_token(request) token = get_token(request)
origin_ref = token['origin_ref'] code_challenge = token['origin_ref']
logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}')
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
logging.info(f'> [ renew ]: {origin_ref} ({code_challenge}): renew {lease_ref}')
entity = Lease.find_by_origin_ref_and_lease_ref(db, origin_ref, lease_ref) entity = Lease.find_by_origin_ref_and_lease_ref(db, origin_ref, lease_ref)
if entity is None: if entity is None:
@ -356,11 +379,16 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
async def leasing_v1_lessor_lease_remove(request: Request): async def leasing_v1_lessor_lease_remove(request: Request):
token = get_token(request) token = get_token(request)
origin_ref = token['origin_ref'] code_challenge = token['origin_ref']
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
released_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref))) released_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref)))
deletions = Lease.ceanup(db, origin_ref) deletions = Lease.ceanup(db, origin_ref)
logging.info(f'> [ remove ]: {origin_ref}: removed {deletions} leases') logging.info(f'> [ remove ]: {origin_ref} ({code_challenge}): removed {deletions} leases')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {

View File

@ -55,17 +55,71 @@ class Origin(Base):
if entity is None: if entity is None:
session.add(origin) session.add(origin)
else: else:
values = dict( session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(**origin.values()))
hostname=origin.hostname,
guest_driver_version=origin.guest_driver_version,
os_platform=origin.os_platform,
os_version=origin.os_version,
)
session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(values))
session.flush() session.flush()
session.close() session.close()
class Auth(Base):
__tablename__ = "auth"
"""
CREATE TABLE auth (
id INTEGER NOT NULL,
origin_ref TEXT,
code_challenge TEXT,
expires DATETIME,
PRIMARY KEY (id)
);
"""
"""
20|B210CF72-FEC7-4440-9499-1156D1ACD13A|p8oeBJPumrosywezCQ6VQvI/J2LZbMRK0s+OfsqzAiI|2022-12-21 05:00:57.467359
61|723EA079-7B0C-4E25-A8D4-DD3E89F9D177|9Nnv5FMtV9nF8qYRtCKG5lfF23HGvvNCQvpCh3FUITo|2022-12-22 05:08:40.713022
65|230b0000-a356-4000-8a2b-0000564c0000|9PivDr3PYRfcdUgODBR5+gi2ZdAbmPb07yTO05uui4A|2022-12-22 06:22:27.409642
66|41720000-FA43-4000-9472-0000E8660000|VnyasehSayRX/2OD3YyP8Xn9nsIBVefZpscnIIj2Rpk|2022-12-22 08:58:04.279664
67|41720000-FA43-4000-9472-0000E8660000|uisrxDFKB8KuD+JvtgT1ol5pNm/GKKlhO69u2ntg7z0|2022-12-22 08:59:37.509520
68|908B202D-CC43-420F-A2EF-FC092AAE8D38|VtWk7It+k33FxiGjm9rlSgAg1ZigfreFJd/0tt30FgQ|2022-12-22 09:43:56.680163
"""
_table_args__ = (
# this can be db.PrimaryKeyConstraint if you want it to be a primary key
UniqueConstraint('origin_ref', 'code_challenge'),
)
origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref), primary_key=True, nullable=False, index=True)
code_challenge = Column(VARCHAR(length=43), primary_key=True, nullable=False)
expires = Column(DATETIME(), nullable=False)
def __repr__(self):
return f'Auth(origin_ref={self.origin_ref}, code_challenge={self.code_challenge}, expires={self.expires})'
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Auth.__table__).compile(engine)
@staticmethod
def create(engine: Engine, auth: "Auth"):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
session.add(auth)
session.flush()
session.close()
@staticmethod
def cleanup(engine: Engine, origin_ref: str, older_than: datetime.datetime):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
session.execute(delete(Auth).where(and_(Auth.origin_ref == origin_ref, Auth.expires <= older_than)))
session.close()
@staticmethod
def find_by_code_challenge(engine: Engine, code_challenge: str) -> "Auth":
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
entity = session.query(Auth).filter(Auth.code_challenge == code_challenge).first()
session.close()
return entity
class Lease(Base): class Lease(Base):
__tablename__ = "lease" __tablename__ = "lease"
@ -110,8 +164,7 @@ class Lease(Base):
if entity is None: if entity is None:
session.add(lease) session.add(lease)
else: else:
values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**lease.values()))
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(values))
session.flush() session.flush()
session.close() session.close()
@ -132,8 +185,9 @@ class Lease(Base):
@staticmethod @staticmethod
def renew(engine: Engine, lease: "Lease", lease_expires: datetime.datetime, lease_updated: datetime.datetime): def renew(engine: Engine, lease: "Lease", lease_expires: datetime.datetime, lease_updated: datetime.datetime):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated) lease.lease_expires = lease_expires
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(values)) lease.lease_updated = lease_updated
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**lease.values()))
session.close() session.close()
@staticmethod @staticmethod

View File

@ -1,71 +0,0 @@
from starlette.testclient import TestClient
import importlib.util
import sys
MODULE, PATH = 'main.app', '../app/main.py'
spec = importlib.util.spec_from_file_location(MODULE, PATH)
main = importlib.util.module_from_spec(spec)
sys.modules[MODULE] = main
spec.loader.exec_module(main)
client = TestClient(main.app)
def test_index():
response = client.get('/')
assert response.status_code == 200
def test_status():
response = client.get('/status')
assert response.status_code == 200
assert response.json()['status'] == 'up'
def test_client_token():
response = client.get('/client-token')
assert response.status_code == 200
def test_auth_v1_origin():
payload = {
"registration_pending": False,
"environment": {
"guest_driver_version": "guest_driver_version",
"hostname": "myhost",
"ip_address_list": ["192.168.1.123"],
"os_version": "os_version",
"os_platform": "os_platform",
"fingerprint": {"mac_address_list": ["ff:ff:ff:ff:ff:ff"]},
"host_driver_version": "host_driver_version"
},
"update_pending": False,
"candidate_origin_ref": "00112233-4455-6677-8899-aabbccddeeff"
}
response = client.post('/auth/v1/origin', json=payload)
assert response.status_code == 200
def test_auth_v1_code():
pass
def test_auth_v1_token():
pass
def test_leasing_v1_lessor():
pass
def test_leasing_v1_lessor_lease():
pass
def test_leasing_v1_lease_renew():
pass
def test_leasing_v1_lessor_lease_remove():
pass

View File

@ -1 +1 @@
VERSION=0.6 VERSION=0.5