import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import gradio as gr
# Config
MODEL_NAME = "yiyanghkust/finbert-tone"
FACTOR_MODEL_PATH = "finbert_factors.pth"
FRAMING_MODEL_PATH = "finbert_framing.pth"
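# The two .pth files are assumed to be fine-tuned checkpoints (state_dicts of the
# SingleTaskClassifier defined below) stored alongside this app.py.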
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# Label maps
label_map_factors = {0: 'Internal Factor', 1: 'External Factor', 2: 'No Factor'}
label_map_framing = {0: 'Internal Framing', 1: 'External Framing', 2: 'No Framing'}
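# The index -> label order here is assumed to match the label encoding used during fine-tuning.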
# Unified model class (used for both)
class SingleTaskClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 3)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)
# Load models
factor_model = SingleTaskClassifier().to(device)
framing_model = SingleTaskClassifier().to(device)
factor_model.load_state_dict(torch.load(FACTOR_MODEL_PATH, map_location=device))
framing_model.load_state_dict(torch.load(FRAMING_MODEL_PATH, map_location=device))
factor_model.eval()
framing_model.eval()
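# eval() disables dropout so inference is deterministic.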
# Prediction function
def predict(text):
    encoding = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    with torch.no_grad():
        logits_factors = factor_model(input_ids, attention_mask)
        logits_framing = framing_model(input_ids, attention_mask)
    pred_factors = torch.argmax(logits_factors, dim=1).item()
    pred_framing = torch.argmax(logits_framing, dim=1).item()
    return label_map_factors[pred_factors], label_map_framing[pred_framing]
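# Illustrative call (actual labels depend entirely on the fine-tuned weights):
#   predict("Margins declined because of rising input costs")
#   -> e.g. ("External Factor", "External Framing")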
# Gradio interface
gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence to analyze..."),
    outputs=["text", "text"],
    title="FinBERT Dual Classifier",
    description="This demo independently predicts both Factors and Framing using two fine-tuned FinBERT models."
).launch(share=True)