Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, UploadFile, File,Request,Depends,status,BackgroundTasks | |
| from fastapi.security import OAuth2PasswordBearer | |
| from pydantic import BaseModel | |
| from typing import Optional, List | |
| from uuid import uuid4 | |
| import os | |
| from dotenv import load_dotenv | |
| from rag import * | |
| from fastapi.responses import StreamingResponse | |
| import json | |
| from prompt import * | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import requests | |
| import pandas as pd | |
# Pull variables from a local .env file into the process environment before
# anything below reads them.
load_dotenv()

## setup authorization
# The only accepted API key comes from FASTAPI_API_KEY. When that variable is
# unset the list holds a single None.
api_keys = [os.getenv("FASTAPI_API_KEY")]

# use token authentication: the key is read from the bearer Authorization header
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Reject the request unless its bearer token is a configured API key.

    Args:
        api_key: Token extracted from the Authorization header by
            ``oauth2_scheme``.

    Raises:
        HTTPException: 401 when the token matches none of ``api_keys``.
    """
    # Local import keeps this block self-contained; the module is cached after
    # the first call.
    import secrets

    presented = api_key.encode("utf-8")
    authorized = False
    for known_key in api_keys:
        # An unset or empty FASTAPI_API_KEY must never authorize anything,
        # otherwise an empty bearer token would be accepted.
        if not known_key:
            continue
        # Constant-time comparison so response timing does not leak how much
        # of a guessed key was correct. Bytes are compared because
        # compare_digest rejects non-ASCII str arguments.
        if secrets.compare_digest(presented, known_key.encode("utf-8")):
            authorized = True

    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden",
            # Tells the client which authentication scheme is expected.
            headers={"WWW-Authenticate": "Bearer"},
        )
# Authentication is skipped only when DEV is exactly the string "True";
# any other value (or no value) attaches api_key_auth to the whole application.
dev_mode = os.environ.get("DEV")
app = (
    FastAPI()
    if dev_mode == "True"
    else FastAPI(dependencies=[Depends(api_key_auth)])
)

# CORS is fully open: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # Pydantic model for the form data | |
class verify_response_model(BaseModel):
    # Request body consumed by verify_sphinx: the user's reply together with
    # the question and the accepted answers it is checked against.
    # NOTE(review): `Field` is not among the explicit imports visible at the
    # top of this file; presumably it arrives through a star import - confirm.

    # The user's reply to the question.
    response: str = Field(description="The response from the user to the question")
    # Candidate answers the reply is compared with.
    answers: list[str] = Field(description="The possible answers to the question to test if the user read the entire book")
    # The question that was put to the user.
    question: str = Field(description="The question asked to the user to test if they read the entire book")
class UserInput(BaseModel):
    # Request body shared by the generate and generate_whatif_chat handlers.

    # The user's query text (required).
    query: str
    # When truthy, the handlers wrap the generator in a StreamingResponse;
    # otherwise they return the helper's result directly.
    stream: Optional[bool] = False
    # Passed as the second argument to the generation helpers.
    # NOTE(review): presumably prior chat turns; the dict schema is not
    # visible here - confirm.
    messages: Optional[list[dict]] = []
class Artwork(BaseModel):
    # Shape of one artwork record. The field names match the keys of the dicts
    # built by get_artworks_by_artist; no handler in this view references the
    # model directly.

    # Artwork title.
    name: str
    # Artist name.
    artist: str
    # Location of the artwork image.
    image_url: str
    # Date, carried as a string.
    date: str
    # Free-text description.
    description: str
class WhatifInput(BaseModel):
    # Request body for the whatif handler.

    # Forwarded to generate_whatif_stream as `question`.
    question: str
    # Forwarded to generate_whatif_stream as `response`.
    answer: str
# Module-level cache of spreadsheet rows; replaced wholesale by load_data().
artworks_data = []


def load_data():
    """Rebuild ``artworks_data`` from the local ``data.xlsx`` workbook.

    Reads sheet ``Sheet1``, replaces every empty cell with ``False``, keeps
    only the rows whose ``Publication`` column equals ``True`` and stores them
    as a list of per-row dicts in the module-level ``artworks_data``.
    """
    global artworks_data

    workbook_path = "data.xlsx"  # path to the local spreadsheet
    sheet = pd.read_excel(workbook_path, sheet_name='Sheet1')  # adjust sheet_name as needed

    # fillna applies to every column, so text cells that were empty hold the
    # boolean False afterwards.
    sheet = sheet.fillna(False)

    published = sheet.loc[sheet['Publication'] == True]  # noqa: E712
    artworks_data = published.to_dict(orient='records')
    print("Data loaded successfully")
# Fill artworks_data once at import time, before any handler reads it.
load_data()
# Endpoints
async def get_artworks_by_artist(artist_name: str):
    """Return every loaded artwork whose artist contains ``artist_name``.

    The match is a case-insensitive substring test against each row's
    ``Artiste`` value in ``artworks_data``.

    Args:
        artist_name: Full or partial artist name to look for.

    Returns:
        A list of dicts with the keys ``name``, ``artist``, ``image_url``,
        ``date`` (always a string) and ``description``.

    Raises:
        HTTPException: 404 when no row matches.
    """
    needle = artist_name.lower()
    matches = []
    for artwork in artworks_data:
        artist = artwork.get('Artiste')
        # load_data() replaces empty cells with False, so a row without an
        # artist holds a bool here. Skip it rather than crash on .lower().
        if not isinstance(artist, str) or needle not in artist.lower():
            continue
        matches.append({
            'name': artwork['Titre français'],
            'artist': artist,
            'image_url': artwork['Image_URL'],
            'date': str(artwork['Date']),  # ensure date is a string
            'description': artwork['Media'],
        })

    if not matches:
        raise HTTPException(status_code=404, detail="Artist not found")
    return matches
async def generate_sphinx():
    """Produce a new sphinx question with its accepted answers.

    Returns:
        A dict with ``question`` and ``answers`` taken from the object
        returned by ``generate_sphinx_response()``.

    Raises:
        HTTPException: 500 carrying the error text when generation fails.
    """
    try:
        riddle: sphinx_output = generate_sphinx_response()
        payload = {"question": riddle.question, "answers": riddle.answers}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    else:
        return payload
async def verify_sphinx(response: verify_response_model):
    """Check the user's reply against the question's accepted answers.

    Args:
        response: The reply, the candidate answers and the question.

    Returns:
        ``{"score": <result of verify_response>}``.

    Raises:
        HTTPException: 500 carrying the error text when verification fails.
    """
    try:
        verdict: bool = verify_response(
            response.response,
            response.answers,
            response.question,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    else:
        return {"score": verdict}
async def generate(user_input: UserInput):
    """Answer a query, streaming the output when the caller asks for it.

    Args:
        user_input: The query, its message list and the ``stream`` flag.

    Returns:
        A ``StreamingResponse`` over ``generate_stream(..., stream=True)``
        when ``stream`` is truthy, otherwise the direct result of
        ``generate_stream(..., stream=False)``. An exception raised here is
        not re-raised; it is returned as ``{"message": <error text>}``.
    """
    try:
        print(user_input.stream,user_input.query)
        if not user_input.stream:
            return generate_stream(user_input.query, user_input.messages, stream=False)
        chunks = generate_stream(user_input.query, user_input.messages, stream=True)
        return StreamingResponse(chunks, media_type="application/json")
    except Exception as e:
        return {"message": str(e)}
async def whatif(whatif_input: WhatifInput):
    """Run the what-if generator for a question and the answer given to it.

    Args:
        whatif_input: The question and the answer to build on.

    Returns:
        The result of ``generate_whatif_stream``. An exception raised here is
        not re-raised; it is returned as ``{"message": <error text>}``.
    """
    try:
        print(whatif_input.question)
        outcome = generate_whatif_stream(
            question=whatif_input.question,
            response=whatif_input.answer,
        )
    except Exception as e:
        return {"message": str(e)}
    return outcome
async def generate_whatif_chat(user_input: UserInput):
    """Continue a what-if chat, streaming the output when requested.

    Args:
        user_input: The query, its message list and the ``stream`` flag.

    Returns:
        A ``StreamingResponse`` over
        ``generate_stream_whatif_chat(..., stream=True)`` when ``stream`` is
        truthy, otherwise the direct result of the ``stream=False`` call. An
        exception raised here is not re-raised; it is returned as
        ``{"message": <error text>}``.
    """
    try:
        if not user_input.stream:
            return generate_stream_whatif_chat(user_input.query, user_input.messages, stream=False)
        chunks = generate_stream_whatif_chat(user_input.query, user_input.messages, stream=True)
        return StreamingResponse(chunks, media_type="application/json")
    except Exception as e:
        return {"message": str(e)}