# Import necessary libraries
import gradio as gr
from datasets import load_dataset, Dataset  # Explicitly import Dataset class
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
import shutil

# Load dataset once at startup
ds = load_dataset("ashraq/financial-news-articles")
df = pd.DataFrame(ds['train'])

# Simulate labels (replace with real labels in practice)
np.random.seed(42)
df['label'] = np.random.randint(0, 3, size=len(df))  # 0=neg, 1=neu, 2=pos
df['input_text'] = df['title'] + " " + df['text']

# Global variables for model and tokenizer
model = None
tokenizer = None
sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}


# Function to tokenize dataset
def tokenize_function(examples, tokenizer):
    return tokenizer(examples['input_text'], padding="max_length", truncation=True, max_length=512)
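
# For reference (illustrative only, not executed by the app): with the BERT tokenizer used in
# train_model(), the call above returns input_ids, token_type_ids, and attention_mask per example:
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   tokenize_function({'input_text': ["Markets rally on earnings"]}, tok).keys()
#   # dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
# set_format() in train_model() then keeps only the columns the model actually needs.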


# Function to compute metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}
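
# Optional sanity check (illustrative only, not run by the app): the Trainer calls compute_metrics
# with an EvalPrediction carrying .predictions (logits) and .label_ids, which can be simulated:
#   from transformers import EvalPrediction
#   fake = EvalPrediction(predictions=np.random.randn(8, 3), label_ids=np.random.randint(0, 3, size=8))
#   print(compute_metrics(fake))  # {'accuracy': ..., 'f1': ..., 'precision': ..., 'recall': ...}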


# Train the model with user-defined parameters
def train_model(learning_rate, epochs, batch_size, save_path):
    global model, tokenizer

    # Gradio sliders may deliver floats, so cast the integer-valued parameters
    epochs = int(epochs)
    batch_size = int(batch_size)

    # Split dataset
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

    # Load tokenizer and model
    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

    # Prepare datasets
    train_dataset = Dataset.from_pandas(train_df[['input_text', 'label']])
    test_dataset = Dataset.from_pandas(test_df[['input_text', 'label']])
    train_dataset = train_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
    test_dataset = test_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
    train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
    test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./temp_model",
        evaluation_strategy="epoch",  # recent transformers releases rename this to eval_strategy
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        compute_metrics=compute_metrics,
    )

    # Train and evaluate
    trainer.train()
    eval_results = trainer.evaluate()

    # Save the model if a path was provided
    if save_path:
        trainer.save_model(save_path)
        tokenizer.save_pretrained(save_path)
        output = f"Model saved to {save_path}\nEvaluation results: {eval_results}"
    else:
        output = f"Model trained but not saved.\nEvaluation results: {eval_results}"

    # Clean up temporary directories
    if os.path.exists("./temp_model"):
        shutil.rmtree("./temp_model")
    if os.path.exists("./logs"):
        shutil.rmtree("./logs")

    return output


# Load a pre-trained model for inference
def load_model(model_path):
    global model, tokenizer
    if not os.path.exists(model_path):
        return "Error: Model path does not exist."
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    return "Model loaded successfully from " + model_path


# Predict sentiment for new input
def predict_sentiment(title, text):
    global model, tokenizer
    if model is None or tokenizer is None:
        return "Error: Please train or load a model first."
    input_text = title + " " + text
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Keep the inputs on the same device as the model (it may sit on a GPU after training)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    pred_label = int(torch.argmax(outputs.logits, dim=1).item())
    return f"Predicted Sentiment: {sentiment_map[pred_label]}"


# Gradio interface
with gr.Blocks(title="Financial News Sentiment Analyzer") as demo:
    gr.Markdown("# Financial News Sentiment Analyzer")
    gr.Markdown("Train a sentiment model on financial news articles, save it, and predict sentiments.")

    with gr.Tab("Train Model"):
        gr.Markdown("### Train a New Sentiment Model")
        learning_rate = gr.Slider(1e-5, 5e-5, value=2e-5, label="Learning Rate")
        epochs = gr.Slider(1, 5, value=3, step=1, label="Number of Epochs")
        batch_size = gr.Slider(4, 16, value=8, step=4, label="Batch Size")
        save_path = gr.Textbox(label="Save Model Path (optional)", placeholder="e.g., ./my_sentiment_model")
        train_button = gr.Button("Train Model")
        output = gr.Textbox(label="Training Output")
        train_button.click(
            fn=train_model,
            inputs=[learning_rate, epochs, batch_size, save_path],
            outputs=output
        )

    with gr.Tab("Load Model"):
        gr.Markdown("### Load an Existing Model")
        model_path = gr.Textbox(label="Model Path", placeholder="e.g., ./my_sentiment_model")
        load_button = gr.Button("Load Model")
        load_output = gr.Textbox(label="Load Status")
        load_button.click(
            fn=load_model,
            inputs=model_path,
            outputs=load_output
        )

    with gr.Tab("Predict Sentiment"):
        gr.Markdown("### Predict Sentiment for New Input")
        title_input = gr.Textbox(label="Article Title")
        text_input = gr.Textbox(label="Article Text", lines=5)
        predict_button = gr.Button("Predict")
        pred_output = gr.Textbox(label="Prediction")
        predict_button.click(
            fn=predict_sentiment,
            inputs=[title_input, text_input],
            outputs=pred_output
        )

# Launch the app
demo.launch()
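
# Dependencies: a requirements.txt for this Space would plausibly list the packages imported above
# (unpinned here; recent transformers versions also need accelerate installed to use the Trainer):
#   gradio
#   datasets
#   pandas
#   numpy
#   transformers
#   torch
#   scikit-learn
#   accelerate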