|
import os |
|
from fastapi import FastAPI, HTTPException, Request |
|
from pydantic import BaseModel |
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
import torch |
|
import gdown |
|
from fastapi.responses import JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
# --- Runtime configuration -------------------------------------------------

# Port to bind the HTTP server to; hosting platforms (Cloud Run, Heroku, ...)
# inject it via the PORT environment variable.
port = int(os.environ.get("PORT", 8080))

print(f"Starting application on port {port}")

# --- Model weights download ------------------------------------------------

# Google Drive file id of the fine-tuned weights. Overridable through the
# GDRIVE_FILE_ID environment variable; defaults to the original hard-coded
# value for backward compatibility.
file_id = os.environ.get("GDRIVE_FILE_ID", 'id')

model_url = f'https://drive.google.com/uc?id={file_id}'
output = 'bert_model_cpu.bin'

gdown.download(model_url, output, quiet=False)

# Guard against Google Drive returning an HTML error/quota page instead of
# the binary weights (a common failure mode for share links): the file must
# exist and be at least 1 KiB.
if not os.path.exists(output) or os.path.getsize(output) < 1024:
    raise ValueError("El archivo descargado no parece ser un modelo v谩lido. Verifica el enlace de Google Drive.")
|
|
|
|
|
# --- Model initialization --------------------------------------------------

# Tokenizer and architecture come from the Hugging Face hub; the fine-tuned
# classification weights are loaded from the file downloaded above.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Inference is CPU-only (the downloaded file is a CPU state dict).
device = torch.device("cpu")
try:
    model.load_state_dict(torch.load(output, map_location=device))
    model.to(device)
    model.eval()  # disable dropout etc. for deterministic inference
except Exception as e:
    # Chain the original exception (`from e`) so the real cause — corrupt
    # download, incompatible state dict, ... — stays in the traceback
    # instead of being flattened into a message string only.
    raise ValueError(f"Error al cargar el modelo: {e}") from e
|
|
|
|
|
app = FastAPI()

# Allow cross-origin requests from any origin — this is a public demo API.
cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_settings)
|
|
|
|
|
|
|
@app.middleware("http")
async def validate_json(request: Request, call_next):
    """Reject POST requests whose declared-JSON body is not parseable.

    Returns an early 400 JSON response instead of letting the endpoint fail
    with a less descriptive error. Requests that do not declare a JSON
    content type are passed through untouched — previously ANY POST without
    a parseable JSON body (form posts, empty bodies) was rejected with 400
    even when the route never parses JSON.
    """
    if request.method == "POST":
        content_type = request.headers.get("content-type", "")
        if "application/json" in content_type:
            try:
                # NOTE(review): reading the body in HTTP middleware relies on
                # Starlette caching it for the downstream endpoint — confirm
                # the installed Starlette version supports this.
                await request.json()
            except Exception:
                return JSONResponse(
                    status_code=400,
                    content={"message": "Invalid JSON format."},
                )
    return await call_next(request)
|
|
|
|
|
class SentimentRequest(BaseModel):
    """Request body for the /predict/ endpoint."""
    text: str  # raw text (e.g. a tweet) whose sentiment will be predicted
|
|
|
|
|
def preprocess_tweet(tweet, tokenizer, max_length=64):
    """Tokenize *tweet* for BERT inference.

    Returns an (input_ids, attention_mask) pair of PyTorch tensors, padded
    and truncated to *max_length* tokens, with [CLS]/[SEP] specials added.
    """
    encoding = tokenizer.encode_plus(
        tweet,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    ids = encoding['input_ids']
    mask = encoding['attention_mask']
    return ids, mask
|
|
|
|
|
def predict_tweet_sentiment(tweet, model, tokenizer):
    """Run *model* on *tweet* and return "Positivo" (label 1) or "Negativo"."""
    ids, mask = preprocess_tweet(tweet, tokenizer)
    ids, mask = ids.to(device), mask.to(device)
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = model(ids, token_type_ids=None, attention_mask=mask)
    predicted_label = torch.argmax(outputs.logits, dim=1).item()
    return "Positivo" if predicted_label == 1 else "Negativo"
|
|
|
|
|
@app.post("/predict/")
async def predict(request: SentimentRequest):
    """Classify the sentiment of *request.text*; any failure becomes a 500."""
    try:
        print(f"Prediciendo sentimiento para: {request.text}")
        result = predict_tweet_sentiment(request.text, model, tokenizer)
        return {"sentiment": result}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
if __name__ == "__main__":
    # Local/dev entry point; in production a process manager runs uvicorn.
    import uvicorn

    host = "0.0.0.0"
    print(f"Running Uvicorn on host {host} and port {port}")
    uvicorn.run(app, host=host, port=port)
|
|