Compare commits

..

13 Commits

Author SHA1 Message Date
e7102c4de6 fixed updates 2022-12-23 08:16:58 +01:00
d1db441df4 removed Auth 2022-12-23 08:16:34 +01:00
d5b51bd83c Merge branch 'dev' into sqlalchemy
# Conflicts:
#	app/main.py
2022-12-23 08:08:35 +01:00
3f71c88d48 added some test 2022-12-23 07:48:47 +01:00
a58549a162 .gitlab-ci.yml - fixed test cert path 2022-12-23 07:43:02 +01:00
2c1c9b63b4 .gitignore 2022-12-23 07:41:23 +01:00
3367977652 .gitlab-ci.yml - fixed cd into test 2022-12-23 07:41:18 +01:00
67ed6108a2 .gitlab-ci.yml - changed test image to bullseye 2022-12-23 07:40:27 +01:00
d5d156e70e .gitlab-ci.yml - create test certificates 2022-12-23 07:38:53 +01:00
906af9430a .gitlab-ci.yml - fixed installing dependencies 2022-12-23 07:36:33 +01:00
3f5e3b16c5 added api tests 2022-12-23 07:35:37 +01:00
9809bbdbd1 bump version to 0.6 2022-12-23 07:16:41 +01:00
a0b9eae15b main.py - fixed wrong "origin_ref" in CodeResponse
- fixed issue
- removed the now unnecessary table "auth"
2022-12-23 06:56:29 +01:00
6 changed files with 107 additions and 109 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ venv/
.idea/ .idea/
app/*.sqlite* app/*.sqlite*
app/cert/*.* app/cert/*.*
.pytest_cache

View File

@ -16,9 +16,17 @@ build:
- docker push ${CI_REGISTRY}/${CI_PROJECT_PATH}/${CI_BUILD_REF_NAME}:${CI_BUILD_REF} - docker push ${CI_REGISTRY}/${CI_PROJECT_PATH}/${CI_BUILD_REF_NAME}:${CI_BUILD_REF}
test: test:
image: python:3.10-slim-bullseye
stage: test stage: test
before_script:
- pip install -r requirements.txt
- pip install pytest httpx
- mkdir -p app/cert
- openssl genrsa -out app/cert/instance.private.pem 2048
- openssl rsa -in app/cert/instance.private.pem -outform PEM -pubout -out app/cert/instance.public.pem
- cd test
script: script:
- echo "Nothing to do ..." - pytest main.py
deploy: deploy:
stage: deploy stage: deploy

View File

@ -22,7 +22,7 @@ from sqlalchemy.orm import sessionmaker
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import RsaKey from Crypto.PublicKey.RSA import RsaKey
from orm import Origin, Lease, Auth from orm import Origin, Lease
logger = logging.getLogger() logger = logging.getLogger()
load_dotenv('../version.env') load_dotenv('../version.env')
@ -207,17 +207,13 @@ async def auth_v1_code(request: Request):
'iat': timegm(cur_time.timetuple()), 'iat': timegm(cur_time.timetuple()),
'exp': timegm(expires.timetuple()), 'exp': timegm(expires.timetuple()),
'challenge': j['code_challenge'], 'challenge': j['code_challenge'],
'origin_ref': j['code_challenge'], 'origin_ref': j['origin_ref'],
'key_ref': SITE_KEY_XID, 'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID 'kid': SITE_KEY_XID
} }
auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256) auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
Auth.cleanup(db, origin_ref, cur_time - delta)
data = Auth(origin_ref=origin_ref, code_challenge=j['code_challenge'], expires=expires)
Auth.create(db, data)
response = { response = {
"auth_code": auth_code, "auth_code": auth_code,
"sync_timestamp": cur_time.isoformat(), "sync_timestamp": cur_time.isoformat(),
@ -235,13 +231,8 @@ async def auth_v1_token(request: Request):
j = json.loads((await request.body()).decode('utf-8')) j = json.loads((await request.body()).decode('utf-8'))
payload = jwt.decode(token=j['auth_code'], key=jwt_decode_key) payload = jwt.decode(token=j['auth_code'], key=jwt_decode_key)
code_challenge = payload['origin_ref'] origin_ref = payload['origin_ref']
logging.info(f'> [ auth ]: {origin_ref}: {j}')
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
logging.info(f'> [ auth ]: {origin_ref} ({code_challenge}): {j}')
# validate the code challenge # validate the code challenge
if payload['challenge'] != b64enc(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'): if payload['challenge'] != b64enc(sha256(j['code_verifier'].encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'):
@ -256,7 +247,7 @@ async def auth_v1_token(request: Request):
'iss': 'https://cls.nvidia.org', 'iss': 'https://cls.nvidia.org',
'aud': 'https://cls.nvidia.org', 'aud': 'https://cls.nvidia.org',
'exp': timegm(access_expires_on.timetuple()), 'exp': timegm(access_expires_on.timetuple()),
'origin_ref': payload['origin_ref'], 'origin_ref': origin_ref,
'key_ref': SITE_KEY_XID, 'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID, 'kid': SITE_KEY_XID,
} }
@ -277,14 +268,9 @@ async def auth_v1_token(request: Request):
async def leasing_v1_lessor(request: Request): async def leasing_v1_lessor(request: Request):
j, token = json.loads((await request.body()).decode('utf-8')), get_token(request) j, token = json.loads((await request.body()).decode('utf-8')), get_token(request)
code_challenge = token['origin_ref'] origin_ref = token['origin_ref']
scope_ref_list = j['scope_ref_list'] scope_ref_list = j['scope_ref_list']
logging.info(f'> [ create ]: {origin_ref}: create leases for scope_ref_list {scope_ref_list}')
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
logging.info(f'> [ create ]: {origin_ref} ({code_challenge}): create leases for scope_ref_list {scope_ref_list}')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
lease_result_list = [] lease_result_list = []
@ -323,14 +309,10 @@ async def leasing_v1_lessor(request: Request):
async def leasing_v1_lessor_lease(request: Request): async def leasing_v1_lessor_lease(request: Request):
token = get_token(request) token = get_token(request)
code_challenge = token['origin_ref'] origin_ref = token['origin_ref']
entity = Auth.find_by_code_challenge(db, code_challenge) active_lease_list = list(map(lambda x: x['lease_ref'], db['lease'].find(origin_ref=origin_ref)))
if entity is None: logging.info(f'> [ leases ]: {origin_ref}: found {len(active_lease_list)} active leases')
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
active_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref)))
logging.info(f'> [ leases ]: {origin_ref} ({code_challenge}): found {len(active_lease_list)} active leases')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {
@ -347,13 +329,8 @@ async def leasing_v1_lessor_lease(request: Request):
async def leasing_v1_lease_renew(request: Request, lease_ref: str): async def leasing_v1_lease_renew(request: Request, lease_ref: str):
token = get_token(request) token = get_token(request)
code_challenge = token['origin_ref'] origin_ref = token['origin_ref']
logging.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}')
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
logging.info(f'> [ renew ]: {origin_ref} ({code_challenge}): renew {lease_ref}')
entity = Lease.find_by_origin_ref_and_lease_ref(db, origin_ref, lease_ref) entity = Lease.find_by_origin_ref_and_lease_ref(db, origin_ref, lease_ref)
if entity is None: if entity is None:
@ -379,16 +356,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
async def leasing_v1_lessor_lease_remove(request: Request): async def leasing_v1_lessor_lease_remove(request: Request):
token = get_token(request) token = get_token(request)
code_challenge = token['origin_ref'] origin_ref = token['origin_ref']
entity = Auth.find_by_code_challenge(db, code_challenge)
if entity is None:
raise HTTPException(status_code=400, detail='code challenge not found')
origin_ref = entity.origin_ref
released_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref))) released_lease_list = list(map(lambda x: x.lease_ref, Lease.find_by_origin_ref(db, origin_ref)))
deletions = Lease.ceanup(db, origin_ref) deletions = Lease.ceanup(db, origin_ref)
logging.info(f'> [ remove ]: {origin_ref} ({code_challenge}): removed {deletions} leases') logging.info(f'> [ remove ]: {origin_ref}: removed {deletions} leases')
cur_time = datetime.utcnow() cur_time = datetime.utcnow()
response = { response = {

View File

@ -55,71 +55,17 @@ class Origin(Base):
if entity is None: if entity is None:
session.add(origin) session.add(origin)
else: else:
session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(**origin.values())) values = dict(
hostname=origin.hostname,
guest_driver_version=origin.guest_driver_version,
os_platform=origin.os_platform,
os_version=origin.os_version,
)
session.execute(update(Origin).where(Origin.origin_ref == origin.origin_ref).values(values))
session.flush() session.flush()
session.close() session.close()
class Auth(Base):
__tablename__ = "auth"
"""
CREATE TABLE auth (
id INTEGER NOT NULL,
origin_ref TEXT,
code_challenge TEXT,
expires DATETIME,
PRIMARY KEY (id)
);
"""
"""
20|B210CF72-FEC7-4440-9499-1156D1ACD13A|p8oeBJPumrosywezCQ6VQvI/J2LZbMRK0s+OfsqzAiI|2022-12-21 05:00:57.467359
61|723EA079-7B0C-4E25-A8D4-DD3E89F9D177|9Nnv5FMtV9nF8qYRtCKG5lfF23HGvvNCQvpCh3FUITo|2022-12-22 05:08:40.713022
65|230b0000-a356-4000-8a2b-0000564c0000|9PivDr3PYRfcdUgODBR5+gi2ZdAbmPb07yTO05uui4A|2022-12-22 06:22:27.409642
66|41720000-FA43-4000-9472-0000E8660000|VnyasehSayRX/2OD3YyP8Xn9nsIBVefZpscnIIj2Rpk|2022-12-22 08:58:04.279664
67|41720000-FA43-4000-9472-0000E8660000|uisrxDFKB8KuD+JvtgT1ol5pNm/GKKlhO69u2ntg7z0|2022-12-22 08:59:37.509520
68|908B202D-CC43-420F-A2EF-FC092AAE8D38|VtWk7It+k33FxiGjm9rlSgAg1ZigfreFJd/0tt30FgQ|2022-12-22 09:43:56.680163
"""
_table_args__ = (
# this can be db.PrimaryKeyConstraint if you want it to be a primary key
UniqueConstraint('origin_ref', 'code_challenge'),
)
origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref), primary_key=True, nullable=False, index=True)
code_challenge = Column(VARCHAR(length=43), primary_key=True, nullable=False)
expires = Column(DATETIME(), nullable=False)
def __repr__(self):
return f'Auth(origin_ref={self.origin_ref}, code_challenge={self.code_challenge}, expires={self.expires})'
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Auth.__table__).compile(engine)
@staticmethod
def create(engine: Engine, auth: "Auth"):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
session.add(auth)
session.flush()
session.close()
@staticmethod
def cleanup(engine: Engine, origin_ref: str, older_than: datetime.datetime):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
session.execute(delete(Auth).where(and_(Auth.origin_ref == origin_ref, Auth.expires <= older_than)))
session.close()
@staticmethod
def find_by_code_challenge(engine: Engine, code_challenge: str) -> "Auth":
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
entity = session.query(Auth).filter(Auth.code_challenge == code_challenge).first()
session.close()
return entity
class Lease(Base): class Lease(Base):
__tablename__ = "lease" __tablename__ = "lease"
@ -164,7 +110,8 @@ class Lease(Base):
if entity is None: if entity is None:
session.add(lease) session.add(lease)
else: else:
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**lease.values())) values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated)
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(values))
session.flush() session.flush()
session.close() session.close()
@ -185,9 +132,8 @@ class Lease(Base):
@staticmethod @staticmethod
def renew(engine: Engine, lease: "Lease", lease_expires: datetime.datetime, lease_updated: datetime.datetime): def renew(engine: Engine, lease: "Lease", lease_expires: datetime.datetime, lease_updated: datetime.datetime):
session = sessionmaker(autocommit=True, autoflush=True, bind=engine)() session = sessionmaker(autocommit=True, autoflush=True, bind=engine)()
lease.lease_expires = lease_expires values = dict(lease_expires=lease.lease_expires, lease_updated=lease.lease_updated)
lease.lease_updated = lease_updated session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(values))
session.execute(update(Lease).where(and_(Lease.origin_ref == lease.origin_ref, Lease.lease_ref == lease.lease_ref)).values(**lease.values()))
session.close() session.close()
@staticmethod @staticmethod

71
test/main.py Normal file
View File

@ -0,0 +1,71 @@
from starlette.testclient import TestClient
import importlib.util
import sys
MODULE, PATH = 'main.app', '../app/main.py'
spec = importlib.util.spec_from_file_location(MODULE, PATH)
main = importlib.util.module_from_spec(spec)
sys.modules[MODULE] = main
spec.loader.exec_module(main)
client = TestClient(main.app)
def test_index():
response = client.get('/')
assert response.status_code == 200
def test_status():
response = client.get('/status')
assert response.status_code == 200
assert response.json()['status'] == 'up'
def test_client_token():
response = client.get('/client-token')
assert response.status_code == 200
def test_auth_v1_origin():
payload = {
"registration_pending": False,
"environment": {
"guest_driver_version": "guest_driver_version",
"hostname": "myhost",
"ip_address_list": ["192.168.1.123"],
"os_version": "os_version",
"os_platform": "os_platform",
"fingerprint": {"mac_address_list": ["ff:ff:ff:ff:ff:ff"]},
"host_driver_version": "host_driver_version"
},
"update_pending": False,
"candidate_origin_ref": "00112233-4455-6677-8899-aabbccddeeff"
}
response = client.post('/auth/v1/origin', json=payload)
assert response.status_code == 200
def test_auth_v1_code():
pass
def test_auth_v1_token():
pass
def test_leasing_v1_lessor():
pass
def test_leasing_v1_lessor_lease():
pass
def test_leasing_v1_lease_renew():
pass
def test_leasing_v1_lessor_lease_remove():
pass

View File

@ -1 +1 @@
VERSION=0.5 VERSION=0.6