# Spaces: Sleeping / Sleeping  (page-status text captured with this listing; not program code)
import csv
import io
import json
from typing import Annotated, Union

import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from modeles import bert, deberta, squeezebert
# Application object that the handlers in this module belong to.
app = FastAPI()

# Cross-origin configuration: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive, and the FastAPI CORS documentation says origins, methods and
# headers should be listed explicitly when credentials are allowed — confirm
# this combination is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def startup_event():
    """Print ``"start"`` to standard output.

    NOTE(review): nothing visible in this file registers this coroutine with
    ``app`` (for example as a startup hook) — confirm how it is wired up.
    """
    print("start")
async def root():
    """Return a fixed greeting payload.

    Returns:
        The dictionary ``{"message": "Hello World"}``.

    NOTE(review): no route decorator is visible on this handler in this
    file — confirm which path it is meant to serve.
    """
    return {"message": "Hello World"}
async def create_upload_file(file: UploadFile, texte: str, model: str):
    """Echo the request back: model name, text, and the uploaded file's name.

    Args:
        file: Uploaded file; only its ``filename`` attribute is read.
        texte: Text supplied by the caller, returned unchanged.
        model: Model identifier supplied by the caller, returned unchanged.

    Returns:
        A dictionary with the keys ``"model"``, ``"texte"`` and ``"filename"``.

    NOTE(review): no route decorator is visible on this handler, and other
    functions in this file are also named ``create_upload_file``; as plain
    module-level definitions the last one replaces the earlier ones — confirm
    how each variant is registered.
    """
    return {"model": model, "texte": texte, "filename": file.filename}
async def create_upload_file(context: str, texte: str, model: str):
    """Echo the request back: model name, text, and context.

    Args:
        context: Context string supplied by the caller, returned unchanged.
        texte: Text supplied by the caller, returned unchanged.
        model: Model identifier supplied by the caller, returned unchanged.

    Returns:
        A dictionary with the keys ``"model"``, ``"texte"`` and ``"context"``.

    NOTE(review): no route decorator is visible on this handler, and other
    functions in this file are also named ``create_upload_file``; as plain
    module-level definitions the last one replaces the earlier ones — confirm
    how each variant is registered.
    """
    return {"model": model, "texte": texte, "context": context}
async def create_upload_file(texte: str, model: str):
    """Echo the request back: model name and text.

    Args:
        texte: Text supplied by the caller, returned unchanged.
        model: Model identifier supplied by the caller, returned unchanged.

    Returns:
        A dictionary with the keys ``"model"`` and ``"texte"``.

    NOTE(review): no route decorator is visible on this handler, and other
    functions in this file are also named ``create_upload_file``; as plain
    module-level definitions the last one replaces the earlier ones — confirm
    how each variant is registered.
    """
    return {"model": model, "texte": texte}
# Pydantic model for SqueezeBERT requests (currently disabled):
# class SqueezeBERTRequest(BaseModel):
#     context: str
#     question: str
async def qasqueezebert(context: str, question: str):
    """Answer ``question`` from ``context`` using the ``squeezebert`` helper.

    Args:
        context: Text passed to ``squeezebert`` as its first argument.
        question: Question passed to ``squeezebert`` as its second argument.

    Returns:
        Whatever ``squeezebert`` returns, when that value is truthy.

    Raises:
        HTTPException: 404 when the helper returns a falsy value; 500 when
            the helper itself raises.
    """
    # Only the model call is guarded. HTTPException is itself an Exception,
    # so raising the 404 inside the ``try`` would be re-caught by the handler
    # below and reported to the client as a 500.
    try:
        squeezebert_answer = squeezebert(context, question)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An error occurred: {e}"
        ) from e

    if not squeezebert_answer:
        raise HTTPException(status_code=404, detail="No answer found")
    return squeezebert_answer
# Pydantic model for BERT requests (currently disabled):
# class BERTRequest(BaseModel):
#     context: str
#     question: str
async def qabert(context: str, question: str):
    """Answer ``question`` from ``context`` using the ``bert`` helper.

    Args:
        context: Text passed to ``bert`` as its first argument.
        question: Question passed to ``bert`` as its second argument.

    Returns:
        Whatever ``bert`` returns, when that value is truthy.

    Raises:
        HTTPException: 404 when the helper returns a falsy value; 500 when
            the helper itself raises.
    """
    # Only the model call is guarded. HTTPException is itself an Exception,
    # so raising the 404 inside the ``try`` would be re-caught by the handler
    # below and reported to the client as a 500.
    try:
        bert_answer = bert(context, question)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An error occurred: {e}"
        ) from e

    if not bert_answer:
        raise HTTPException(status_code=404, detail="No answer found")
    return bert_answer
# Pydantic model for DeBERTa requests (currently disabled):
# class DeBERTaRequest(BaseModel):
#     context: str
#     question: str
async def qadeberta(context: str, question: str):
    """Answer ``question`` from ``context`` using the ``deberta`` helper.

    Args:
        context: Text passed to ``deberta`` as its first argument.
        question: Question passed to ``deberta`` as its second argument.

    Returns:
        Whatever ``deberta`` returns, when that value is truthy.

    Raises:
        HTTPException: 404 when the helper returns a falsy value; 500 when
            the helper itself raises.
    """
    # Only the model call is guarded. HTTPException is itself an Exception,
    # so raising the 404 inside the ``try`` would be re-caught by the handler
    # below and reported to the client as a 500.
    try:
        deberta_answer = deberta(context, question)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An error occurred: {e}"
        ) from e

    if not deberta_answer:
        raise HTTPException(status_code=404, detail="No answer found")
    return deberta_answer
def extract_data(file: UploadFile) -> Union[str, dict, list]:
    """Read an uploaded ``.txt``, ``.csv`` or ``.json`` file into Python data.

    The format is chosen from the file name's extension, compared
    case-insensitively. The payload is decoded as UTF-8; a leading byte-order
    mark, if present, is dropped so it cannot corrupt the first CSV header or
    make JSON parsing fail.

    Args:
        file: Uploaded file. ``file.filename`` selects the parser and
            ``file.file.read()`` supplies the raw bytes.

    Returns:
        * ``.txt``  - the decoded text.
        * ``.csv``  - a list with one ``dict`` per data row, keyed by the
          header row.
        * ``.json`` - the parsed JSON value.
        * anything else (including a missing file name) - the string
          ``"Invalid file format"``; the stream is not read in that case.

    Raises:
        UnicodeDecodeError: If the payload is not valid UTF-8.
        json.JSONDecodeError: If a ``.json`` payload is not valid JSON.
    """
    # ``filename`` may be None; treat that as an unsupported format instead
    # of failing with AttributeError.
    filename = (file.filename or "").lower()
    if not filename.endswith((".txt", ".csv", ".json")):
        return "Invalid file format"

    text = file.file.read().decode("utf-8-sig")

    if filename.endswith(".txt"):
        return text

    if filename.endswith(".csv"):
        # newline="" hands line endings to the csv module untouched, so
        # newlines inside quoted fields survive (splitting the text on "\n"
        # beforehand silently drops them).
        reader = csv.DictReader(io.StringIO(text, newline=""))
        return [dict(row) for row in reader]

    return json.loads(text)