thomasgauthier commited on
Commit
b6d296b
β€’
1 Parent(s): 7a94c8b

oauth test flow

Browse files
main.py CHANGED
@@ -5,6 +5,8 @@ from fastapi import FastAPI, BackgroundTasks, HTTPException, Query
5
  from fastapi.responses import StreamingResponse
6
  from starlette.concurrency import run_in_threadpool
7
  from datasets import load_dataset
 
 
8
  import random
9
  import json
10
  from genson import SchemaBuilder
@@ -37,6 +39,8 @@ client = OpenAI(
37
  api_key=os.environ.get('OPENROUTER_KEY')
38
  )
39
 
 
 
40
  state_queue_map = {}
41
 
42
  def is_sharegpt(sample):
@@ -412,6 +416,34 @@ async def get_oauth_config(request: Request):
412
  }
413
 
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  @app.get("/")
416
  def index() -> FileResponse:
417
  return FileResponse(path="static/index.html", media_type="text/html")
 
5
  from fastapi.responses import StreamingResponse
6
  from starlette.concurrency import run_in_threadpool
7
  from datasets import load_dataset
8
+ from fastapi import FastAPI, Depends, HTTPException, status
9
+ from fastapi.security import OAuth2PasswordBearer
10
  import random
11
  import json
12
  from genson import SchemaBuilder
 
39
  api_key=os.environ.get('OPENROUTER_KEY')
40
  )
41
 
42
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
43
+
44
  state_queue_map = {}
45
 
46
  def is_sharegpt(sample):
 
416
  }
417
 
418
 
419
+ async def get_current_user(token: str = Depends(oauth2_scheme)):
420
+ if not token:
421
+ raise HTTPException(
422
+ status_code=status.HTTP_401_UNAUTHORIZED,
423
+ detail="Missing token",
424
+ headers={"WWW-Authenticate": "Bearer"},
425
+ )
426
+
427
+ url = "https://huggingface.co/oauth/userinfo"
428
+ headers = {"Authorization": f"Bearer {token}"}
429
+ response = requests.get(url, headers=headers)
430
+
431
+ if response.status_code != 200:
432
+ raise HTTPException(
433
+ status_code=status.HTTP_401_UNAUTHORIZED,
434
+ detail="Invalid token",
435
+ headers={"WWW-Authenticate": "Bearer"},
436
+ )
437
+
438
+ user_info = response.json()
439
+ return user_info
440
+
441
+ @app.get("/gated_route")
442
+ async def gated_route(current_user: str = Depends(get_current_user)):
443
+ # Your logic here. The endpoint will only be accessible if the token is valid
444
+ return {"message": "You are authorized to access this route"}
445
+
446
+
447
  @app.get("/")
448
  def index() -> FileResponse:
449
  return FileResponse(path="static/index.html", media_type="text/html")
static/assets/{index-7974ca0c.js β†’ index-025bf825.js} RENAMED
The diff for this file is too large to render. See raw diff
 
static/index.html CHANGED
@@ -5,7 +5,7 @@
5
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
  <title>Vite + Preact</title>
8
- <script type="module" crossorigin src="/assets/index-7974ca0c.js"></script>
9
  <link rel="stylesheet" href="/assets/index-abe6d7fb.css">
10
  </head>
11
  <body ondrop="event.preventDefault()" >
 
5
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
  <title>Vite + Preact</title>
8
+ <script type="module" crossorigin src="/assets/index-025bf825.js"></script>
9
  <link rel="stylesheet" href="/assets/index-abe6d7fb.css">
10
  </head>
11
  <body ondrop="event.preventDefault()" >