import asyncio
import os
import secrets

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from dotenv import load_dotenv
from fastapi import FastAPI, Depends
from fastapi import Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse

from common import get_db
from home import build_home_page
from modules.models import SheamiUser
from ui import get_app_theme, get_app_title, get_css, render_about_markdowns, render_logo, render_logo_small, render_selected_patient_actions
|
|
# Load variables from a local .env file (if present) into os.environ before the
# configuration below is read.
load_dotenv()
# Root ASGI application: the auth routes below are registered on it and the
# Gradio UIs are mounted onto it further down.
app = FastAPI()
|
|
| |
# Google OAuth client credentials (None when the variables are not set).
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID")
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET")

# Key used by SessionMiddleware to sign the session cookie. Starlette stringifies
# the key it is given, so a missing variable would sign every cookie with the
# guessable key "None" and let anyone forge a logged-in session. Fall back to a
# random key instead: secure, but sessions do not survive a restart and are not
# shared between worker processes.
SECRET_KEY = os.environ.get("AUTH_SECRET_KEY")
if not SECRET_KEY:
    print("AUTH_SECRET_KEY is not set; using a random per-process session key")
    SECRET_KEY = secrets.token_urlsafe(32)
|
|
| |
# Authlib resolves the credentials of the client registered as "google" from the
# GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET entries of this config mapping.
config_data = dict(
    GOOGLE_CLIENT_ID=GOOGLE_OAUTH_CLIENT_ID,
    GOOGLE_CLIENT_SECRET=GOOGLE_OAUTH_CLIENT_SECRET,
)
starlette_config = Config(environ=config_data)

oauth = OAuth(starlette_config)
# Endpoints are discovered from Google's OpenID Connect metadata document.
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Signed cookie sessions; the /auth route stores the Google profile under "user".
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
|
|
| |
def get_user(request: Request):
    """Return an identifier for the signed-in user, or None for anonymous requests.

    Used as the FastAPI dependency of ``/`` and as Gradio's ``auth_dependency``
    for the ``/home`` app. Prefers the profile ``name`` stored in the session and
    falls back to ``email`` when the name claim is missing, instead of raising
    ``KeyError`` (an HTTP 500 on every request from that user).
    """
    user = request.session.get('user')
    if not user:
        return None
    return user.get('name') or user.get('email')
|
|
@app.get('/')
def public(request: Request, user=Depends(get_user)):
    """Site entry point: send signed-in users to /home/ and everyone else to /main/."""
    root_url = gr.route_utils.get_root_url(request, "/", None)
    destination = "home" if user else "main"
    return RedirectResponse(url=f'{root_url}/{destination}/')
|
|
@app.route('/logout')
async def logout(request: Request):
    """Forget the signed-in user and return to the site root."""
    session = request.session
    # pop with a default so logging out twice is harmless
    session.pop('user', None)
    return RedirectResponse(url='/')
|
|
@app.route('/login')
async def login(request: Request):
    """Start the Google OAuth flow, asking Google to call back at ``<root>/auth``."""
    callback_url = "{}/auth".format(
        gr.route_utils.get_root_url(request, "/login", None)
    )
    print("Redirecting to", callback_url)
    return await oauth.google.authorize_redirect(request, callback_url)
|
|
@app.route('/auth')
async def auth(request: Request):
    """Google OAuth callback: finish the login and start a session.

    Exchanges the authorization response for tokens via Authlib. On success the
    ``userinfo`` claims are stored in ``request.session['user']`` and the browser
    is redirected to ``/home``. If the exchange fails, or the token response has
    no ``userinfo``, the browser is sent back to ``/``.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as exc:
        # Log the raised error itself; printing the OAuthError class said nothing
        # about what actually went wrong.
        print("Error getting access token", str(exc))
        return RedirectResponse(url='/')
    # Guard against a missing "userinfo" entry instead of failing with KeyError (500).
    userinfo = dict(token).get("userinfo")
    if not userinfo:
        print("No userinfo in access token response; redirecting to /")
        return RedirectResponse(url='/')
    request.session['user'] = userinfo
    print("Redirecting to /home")
    return RedirectResponse(url='/home')
|
|
# Public landing page mounted at /main: small logo, "about" text and a Proceed button.
with gr.Blocks() as login_demo:
    render_logo_small()
    with gr.Row():
        # Empty columns on both sides act as spacers around the content.
        gr.Column()
        render_about_markdowns()
        gr.Column()
    with gr.Row():
        gr.Column()
        btn = gr.Button("Proceed", variant="huggingface", scale=0)
        gr.Column()
    # Client-side click handler: open /login (forwarding the current query string)
    # in a new tab. `const` keeps `url` local instead of creating an implicit global
    # (which is a ReferenceError in strict-mode JavaScript).
    _js_redirect = """
    () => {
        const url = '/login' + window.location.search;
        window.open(url, '_blank');
    }
    """
    btn.click(None, js=_js_redirect)


app = gr.mount_gradio_app(app, login_demo, path="/main")
|
|
async def register_user(logged_in_user: SheamiUser):
    """Create a database record for ``logged_in_user`` unless its email is already known."""
    existing = await get_db().get_user_by_email(email=logged_in_user.email)
    if existing:
        return
    await get_db().add_user(email=logged_in_user.email, name=logged_in_user.name)
|
|
def get_sheami_user(request: gr.Request):
    """Build a ``SheamiUser`` for the signed-in visitor of a Gradio request.

    Email and picture come from the profile stored in the session under
    ``"user"`` (written by the ``/auth`` route). The name is taken from
    ``request.username``, which Gradio presumably fills from the mount's
    ``auth_dependency`` (``get_user``) — confirm.

    Returns:
        A ``SheamiUser``, or ``None`` when Gradio passes no request.

    Raises:
        KeyError: if the session has no ``"user"`` entry or it has no ``"email"``.
    """
    if request is None:
        return None
    session_user = request.request.session["user"]
    # The profile picture is optional; fall back to the bundled placeholder avatar.
    # (Replaces a bare ``except:`` that would also have hidden unrelated errors.)
    picture = session_user.get("picture") or "assets/user.png"
    # Subscripts use single quotes: nesting the same quote type inside an f-string
    # is a SyntaxError before Python 3.12.
    return SheamiUser(
        email=str(session_user['email']),
        name=str(request.username),
        picture_url=str(picture),
    )
|
|
def get_loggedin_user_name(request: gr.Request):
    """Return the signed-in user's display name, or None when there is no request."""
    sheami_user = get_sheami_user(request)
    return None if sheami_user is None else sheami_user.name
|
|
def get_loggedin_user_email(request: gr.Request):
    """Return the signed-in user's email address, or None when there is no request."""
    sheami_user = get_sheami_user(request)
    return None if sheami_user is None else sheami_user.email
|
|
async def build_securely():
    """Build the authenticated Sheami Gradio app (mounted at /home by the caller).

    Lays out a header (logo, logout button, signed-in user's name and email),
    the selected-patient action card, and a ``gr.render`` area that builds the
    home page once the signed-in ``SheamiUser`` is known. Page-load events fill
    in the user details from the incoming request.

    Declared ``async`` but awaits nothing; the module drives it with
    ``asyncio.run(build_securely())``.

    Returns:
        The ``gr.Blocks`` app (the caller enables the queue on it).
    """
    # NOTE(review): indentation was lost in the source this block was recovered
    # from; the nesting of the "---" separator, the State and the patient card
    # relative to the header row is a best guess — confirm the intended layout.
    with gr.Blocks(
        title=get_app_title(), theme=get_app_theme(), css=get_css(), fill_height=True
    ) as demo:

        # Header: logo on the left, logout button and signed-in identity on the right.
        with gr.Row():
            with gr.Column(scale=4):
                render_logo()
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Button("Logout", link="/logout", variant="huggingface")
                    # Filled on page load by get_loggedin_user_name / get_loggedin_user_email.
                    logged_in_user_name = gr.Markdown(elem_classes="text-center")
                    logged_in_user_email = gr.Markdown(elem_classes="text-center")
        gr.Markdown("---")
        # Holds the SheamiUser returned by get_sheami_user on page load.
        logged_in_sheami_user = gr.State()
        with gr.Column(elem_id="patient-card") as patient_card:

            # Action widgets for the selected patient; passed into build_home_page below.
            (
                selected_patient_info,
                delete_patient_btn,
                edit_patient_btn,
                upload_reports_btn,
                add_vitals_btn,
            ) = render_selected_patient_actions()


        # gr.render re-runs this function when logged_in_sheami_user changes.
        @gr.render(inputs=logged_in_sheami_user)
        def render_home_page(user: SheamiUser | None):
            """Register ``user`` if new, then build the home page; render nothing without a user."""
            if user:
                # NOTE(review): asyncio.run starts a fresh event loop on every render and
                # raises RuntimeError if this thread already runs a loop; it presumably
                # works because Gradio runs sync handlers in a worker thread. If get_db()
                # holds an async client bound to another loop this may fail — confirm.
                asyncio.run(register_user(logged_in_user=user))
                build_home_page(
                    logged_in_user=user,
                    selected_patient_info=selected_patient_info,
                    delete_patient_btn=delete_patient_btn,
                    edit_patient_btn=edit_patient_btn,
                    upload_reports_btn=upload_reports_btn,
                    add_vitals_btn=add_vitals_btn,
                )
            else:
                pass


        # On page load, populate the header and the user State from the request.
        demo.load(get_loggedin_user_name, inputs=None, outputs=logged_in_user_name)
        demo.load(get_sheami_user, inputs=None, outputs=logged_in_sheami_user)
        demo.load(get_loggedin_user_email, inputs=None, outputs=logged_in_user_email)



    return demo
| |
# Build the signed-in UI once at import time, then enable Gradio's request queue.
# NOTE(review): asyncio.run raises RuntimeError if an event loop is already running
# in this thread, e.g. when uvicorn imports this module itself ("uvicorn module:app").
# Running this file directly (see the __main__ guard) is unaffected — confirm deployment.
sheami_app = asyncio.run(build_securely())
sheami_app = sheami_app.queue()

# /home is gated by get_user: requests for which it returns None are not let in.
app = gr.mount_gradio_app(app, sheami_app, path="/home", auth_dependency=get_user)


if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=7860)