Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from transformers import BertModel, BertTokenizer | |
import gradio as gr | |
# --- Configuration ---------------------------------------------------------
# Hugging Face checkpoint shared by the tokenizer and both BERT encoders.
MODEL_NAME = "yiyanghkust/finbert-tone"
# Fine-tuned state dicts, one per classification task.
FACTOR_MODEL_PATH = "finbert_factors.pth"
FRAMING_MODEL_PATH = "finbert_framing.pth"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# --- Tokenizer -------------------------------------------------------------
# A single tokenizer instance feeds both classifiers.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
# --- Label maps ------------------------------------------------------------
# Predicted class index -> human-readable label, one mapping per task.
label_map_factors = dict(enumerate(("Internal Factor", "External Factor", "No Factor")))
label_map_framing = dict(enumerate(("Internal Framing", "External Framing", "No Framing")))
# --- Model -----------------------------------------------------------------
class SingleTaskClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The same architecture serves both the factor and the framing task; only
    the fine-tuned weights loaded into each instance differ.

    Args:
        num_labels: Number of output classes. Defaults to 3, the size of the
            label maps used by this app.
        model_name: Hugging Face checkpoint for the encoder. ``None`` (the
            default) means the module-level ``MODEL_NAME``.
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, num_labels: int = 3, model_name=None, dropout: float = 0.1):
        super().__init__()
        # Resolve the default at instantiation time, as the original did when
        # it read MODEL_NAME inside __init__.
        if model_name is None:
            model_name = MODEL_NAME
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits for a tokenized batch.

        Args:
            input_ids: Token ids from the tokenizer (presumably
                ``(batch, seq_len)`` — confirm against callers).
            attention_mask: Mask aligned with ``input_ids``.

        Returns:
            Logits whose last dimension has ``num_labels`` entries.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output is BERT's pooled [CLS] representation.
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)
# --- Model loading ---------------------------------------------------------
def _load_classifier(checkpoint_path):
    """Build a classifier on ``device``, load its fine-tuned weights, and set eval mode."""
    model = SingleTaskClassifier().to(device)
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    # Eval mode turns dropout off so inference is deterministic.
    model.eval()
    return model


factor_model = _load_classifier(FACTOR_MODEL_PATH)
framing_model = _load_classifier(FRAMING_MODEL_PATH)
# --- Prediction ------------------------------------------------------------
def predict(text):
    """Classify ``text`` with both fine-tuned models.

    Args:
        text: The sentence entered in the Gradio textbox.

    Returns:
        A ``(factor_label, framing_label)`` tuple of human-readable strings.
    """
    batch = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    ids = batch["input_ids"].to(device)
    mask = batch["attention_mask"].to(device)

    # Each task: (model, index -> label mapping), in the order returned.
    tasks = (
        (factor_model, label_map_factors),
        (framing_model, label_map_framing),
    )
    labels = []
    with torch.no_grad():
        for model, label_map in tasks:
            best_index = torch.argmax(model(ids, mask), dim=1).item()
            labels.append(label_map[best_index])
    return tuple(labels)
# --- Gradio interface ------------------------------------------------------
# Each output box is labeled so users can tell the two predictions apart;
# the order matches predict()'s (factor, framing) return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence to analyze..."),
    outputs=[gr.Textbox(label="Factor"), gr.Textbox(label="Framing")],
    title="FinBERT Dual Classifier",
    description="This demo independently predicts both Factors and Framing using two fine-tuned FinBERT models.",
)
# NOTE(review): share=True opens a public gradio.live tunnel when run locally;
# hosted Hugging Face Spaces do not support it — confirm it is still wanted.
demo.launch(share=True)