thomasgauthier commited on
Commit
62b6778
1 Parent(s): d119bcb
Files changed (1) hide show
  1. main.py +17 -15
main.py CHANGED
@@ -386,34 +386,36 @@ async def get_random_sample(dataset_name: str = Query(..., alias="dataset-name")
386
 
387
  @app.get("/login/callback")
388
  async def oauth_callback(code: str, state: str):
389
- # Verify the state value here
390
- print(client_id)
 
 
 
391
 
392
- access_token = ""
393
  try:
394
  token_response = requests.post(
395
  'https://huggingface.co/oauth/token',
 
396
  data={
397
  'grant_type': 'authorization_code',
398
  'code': code,
399
  'redirect_uri': f'https://{space_host}/login/callback',
400
- 'client_id': client_id,
401
- 'client_secret': client_secret,
402
- 'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer'
403
  }
404
  )
405
  print(token_response.status_code, token_response.text)
406
- except Exception:
407
- traceback.print_exc()
408
 
409
- # # Fetch user information using access token
410
- # user_response = requests.get(
411
- # 'https://huggingface.co/api/user',
412
- # headers={'Authorization': f'Bearer {access_token}'}
413
- # )
414
- # user_data = user_response.json()
415
- # username = user_data['username']
416
 
 
 
 
417
 
418
  return {"access_token": access_token}
419
 
 
386
 
387
  @app.get("/login/callback")
388
  async def oauth_callback(code: str, state: str):
389
+ # Prepare the authorization header
390
+ credentials = f"{client_id}:{client_secret}"
391
+ credentials_bytes = credentials.encode("ascii")
392
+ base64_credentials = base64.b64encode(credentials_bytes)
393
+ auth_header = f"Basic {base64_credentials.decode('ascii')}"
394
 
 
395
  try:
396
  token_response = requests.post(
397
  'https://huggingface.co/oauth/token',
398
+ headers={'Authorization': auth_header},
399
  data={
400
  'grant_type': 'authorization_code',
401
  'code': code,
402
  'redirect_uri': f'https://{space_host}/login/callback',
403
+ 'client_id': client_id
 
 
404
  }
405
  )
406
  print(token_response.status_code, token_response.text)
 
 
407
 
408
+ if token_response.status_code == 200:
409
+ tokens = token_response.json()
410
+ access_token = tokens.get('access_token')
411
+ # ID Token can be extracted here if needed
412
+ # id_token = tokens.get('id_token')
413
+ else:
414
+ access_token = ""
415
 
416
+ except Exception:
417
+ traceback.print_exc()
418
+ access_token = ""
419
 
420
  return {"access_token": access_token}
421