Spaces:
Running
Running
feat: melhora a chamada do openrouter para conseguir dar fallback em outros modelos
f8e49c0
verified
| from flask import Flask, request, jsonify | |
| import torch | |
| import torch.nn as nn | |
| import pandas as pd | |
| from transformers import BertModel, AutoTokenizer | |
| import re | |
| from flask_cors import CORS | |
| import logging | |
| import os | |
| import requests | |
| from dotenv import load_dotenv | |
| # Carregar variáveis de ambiente | |
| load_dotenv() | |
| app = Flask(__name__) | |
| CORS(app) | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| MAX_LEN = 200 | |
| # --- OpenRouter API --- | |
| OPENROUTER_API_KEY = os.getenv('OPENROUTER_API_KEY') | |
| OPENROUTER_URL = 'https://openrouter.ai/api/v1/chat/completions' | |
| # --- Emoções (mesma ordem do treino) --- | |
| EMOTION_LABELS = [ | |
| 'Neutro', 'Alegria', 'Tristeza', 'Raiva', 'Medo', | |
| 'Nojo', 'Surpresa', 'Confiança', 'Antecipação' | |
| ] | |
| # Modelo usado no TREINO | |
| MODEL_NAME = "neuralmind/bert-base-portuguese-cased" | |
| SAVE_DIR = "models" | |
| MODEL_PATH = f"{SAVE_DIR}/best_model.pth" | |
| # ====================================================================== | |
| # BERT CLASSIFIER IGUAL AO DO TREINO | |
| # ====================================================================== | |
| class BERTClassifier(nn.Module): | |
| def __init__(self, model_name="neuralmind/bert-base-portuguese-cased", num_classes=9, dropout=0.3): | |
| super().__init__() | |
| self.bert = BertModel.from_pretrained(model_name) | |
| self.dropout = nn.Dropout(dropout) | |
| self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes) | |
| def forward(self, input_ids, attention_mask, token_type_ids=None): | |
| outputs = self.bert( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| token_type_ids=token_type_ids, | |
| return_dict=True | |
| ) | |
| cls = outputs.last_hidden_state[:, 0] | |
| x = self.dropout(cls) | |
| return self.classifier(x) | |
| # ====================================================================== | |
| # CARREGAR O MESMO MODELO E TOKENIZER DO TREINO | |
| # ====================================================================== | |
| def load_trained_model(): | |
| logger.info(f"Carregando modelo e tokenizer de '{SAVE_DIR}'...") | |
| if not os.path.exists(MODEL_PATH): | |
| raise FileNotFoundError(f"Não encontrei o modelo treinado em {MODEL_PATH}") | |
| tokenizer = AutoTokenizer.from_pretrained(SAVE_DIR) | |
| model = BERTClassifier(model_name=MODEL_NAME, num_classes=len(EMOTION_LABELS)) | |
| state_dict = torch.load(MODEL_PATH, map_location=device) | |
| model.load_state_dict(state_dict) | |
| model.to(device) | |
| model.eval() | |
| logger.info("Modelo treinado carregado com sucesso!") | |
| return model, tokenizer | |
| try: | |
| bert_model, tokenizer = load_trained_model() | |
| loaded = True | |
| except Exception as e: | |
| logger.error(f"ERRO ao carregar modelo: {e}") | |
| loaded = False | |
| # ====================================================================== | |
| # PRÉ-PROCESSAMENTO (mesmo estilo do treino) | |
| # ====================================================================== | |
| def preprocess_text(text: str): | |
| if not isinstance(text, str): | |
| return "" | |
| text = re.sub(r'http\S+', '', text) | |
| return text.strip() | |
| def tokenize_text(text: str): | |
| text = preprocess_text(text) | |
| encoding = tokenizer.encode_plus( | |
| text, | |
| add_special_tokens=True, | |
| max_length=MAX_LEN, | |
| padding='max_length', | |
| truncation=True, | |
| return_attention_mask=True, | |
| return_token_type_ids=True, | |
| return_tensors='pt' | |
| ) | |
| return { | |
| 'input_ids': encoding['input_ids'].to(device), | |
| 'attention_mask': encoding['attention_mask'].to(device), | |
| 'token_type_ids': encoding['token_type_ids'].to(device) | |
| } | |
| # ====================================================================== | |
| # FUNÇÃO PARA CHAMAR OPENROUTER | |
| # ====================================================================== | |
| import time | |
| def call_openrouter(frase): | |
| if not OPENROUTER_API_KEY: | |
| raise ValueError("OPENROUTER_API_KEY não configurada") | |
| prompt = f""" | |
| Analise a frase: "{frase}". | |
| Escolha UMA emoção principal de Plutchik: | |
| 'Neutro', 'Alegria', 'Tristeza', 'Raiva', 'Medo', 'Nojo', 'Surpresa', 'Confiança', 'Antecipação'. | |
| Responda apenas com a emoção, sem explicação. | |
| """ | |
| headers = { | |
| "Authorization": f"Bearer {OPENROUTER_API_KEY}", | |
| "HTTP-Referer": "http://localhost:3030", | |
| "Content-Type": "application/json" | |
| } | |
| # MODELOS A SEREM TENTADOS (ordem de fallback) | |
| modelos_fallback = [ | |
| "google/gemma-3-12b-it:free", | |
| "google/gemma-3-4b-it:free", | |
| "google/gemma-3-27b-it:free", | |
| "nvidia/nemotron-nano-12b-v2-vl:free" | |
| ] | |
| erros = [] | |
| # 🔁 Tentar cada modelo até um funcionar | |
| for modelo in modelos_fallback: | |
| payload = { | |
| "model": modelo, | |
| "messages": [{"role": "user", "content": prompt}], | |
| "temperature": 0.0 | |
| } | |
| for tentativa in range(2): # duas tentativas por modelo | |
| try: | |
| response = requests.post(OPENROUTER_URL, json=payload, headers=headers) | |
| result = response.json() | |
| # Erro da OpenRouter | |
| if "error" in result: | |
| msg = str(result["error"]).lower() | |
| # ⚠ Rate limit → tentar outra vez ou outro modelo | |
| if "rate" in msg or "429" in msg: | |
| time.sleep(1.2) | |
| continue | |
| # Outro erro → pular para próximo modelo | |
| erros.append((modelo, result["error"])) | |
| break | |
| # Resposta válida | |
| if "choices" in result: | |
| emocao = result["choices"][0]["message"]["content"].strip().title() | |
| emocao = emocao.replace("ç", "c").replace("ã", "a") | |
| return emocao, modelo | |
| # Formato inesperado | |
| erros.append((modelo, result)) | |
| break | |
| except Exception as e: | |
| erros.append((modelo, str(e))) | |
| time.sleep(1) | |
| # Se chegou aqui, tenta o próximo modelo | |
| # Se nenhum modelo funcionou | |
| raise ValueError({ | |
| "mensagem": "Nenhum modelo conseguiu responder.", | |
| "tentativas": erros | |
| }) | |
| # ====================================================================== | |
| # ENDPOINTS | |
| # ====================================================================== | |
| def home(): | |
| return jsonify({ | |
| "status": "API de Emoções BERT", | |
| "modelo_carregado": loaded, | |
| "device": str(device), | |
| "emocoes": EMOTION_LABELS | |
| }) | |
| def predict_emotion(): | |
| try: | |
| if not loaded: | |
| return jsonify({"erro": "Modelo BERT não carregado!"}), 500 | |
| data = request.get_json() | |
| texto = data.get("texto", "").strip() | |
| if not texto: | |
| return jsonify({"erro": "Campo 'texto' é obrigatório"}), 400 | |
| inputs = tokenize_text(texto) | |
| with torch.no_grad(): | |
| outputs = bert_model( | |
| inputs['input_ids'], | |
| inputs['attention_mask'], | |
| inputs['token_type_ids'] | |
| ) | |
| probs = torch.softmax(outputs, dim=1)[0].cpu().numpy() | |
| pred_idx = int(probs.argmax()) | |
| return jsonify({ | |
| "texto": texto, | |
| "emocao": EMOTION_LABELS[pred_idx], | |
| "confianca": float(probs[pred_idx]), | |
| "todas_emocoes": { | |
| EMOTION_LABELS[i]: float(probs[i]) for i in range(9) | |
| } | |
| }) | |
| except Exception as e: | |
| logger.exception("Erro na predição:") | |
| return jsonify({"erro": str(e)}), 500 | |
| def predict_llm(): | |
| try: | |
| data = request.get_json() | |
| frase = data.get("frase", "") | |
| if not frase: | |
| return jsonify({"erro": "Campo 'frase' é obrigatório"}), 400 | |
| emocao, modelo_usado = call_openrouter(frase) | |
| return jsonify({ | |
| "frase": frase, | |
| "emocao": emocao, | |
| "modelo": modelo_usado | |
| }) | |
| except Exception as e: | |
| return jsonify({"erro": str(e)}), 500 | |
| # ====================================================================== | |
| # EXECUTAR SERVIDOR | |
| # ====================================================================== | |
| if __name__ == '__main__': | |
| app.run(host='0.0.0.0', port=7860) | |