|
|
|
|
|
""" |
|
|
Hugging Face Space App for Financial Sentiment Analysis Ensemble |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import numpy as np |
|
|
from datetime import datetime |
|
|
import json |
|
|
|
|
|
class FinancialSentimentEnsemble:
    """Soft-voting ensemble over several sequence-classification checkpoints.

    Each successfully loaded model produces a softmax distribution for the
    input text; the ensemble prediction is the arg-max of the element-wise
    mean of those distributions.
    """

    def __init__(self):
        """Set up the model registry and eagerly load every checkpoint."""
        # Both registries are keyed by the model's index in ``model_names``.
        # A model that fails to load simply has no entry.
        self.models = {}
        self.tokenizers = {}
        self.model_names = [
            "codealchemist01/financial-sentiment-distilbert",
            "codealchemist01/financial-sentiment-bert-large",
            "codealchemist01/financial-sentiment-improved",
        ]
        # NOTE(review): assumes every checkpoint emits its logits in exactly
        # this order (0=bearish, 1=neutral, 2=bullish) -- confirm against each
        # model's ``id2label`` config.
        self.labels = ["Bearish 📉", "Neutral ➡️", "Bullish 📈"]
        self.load_models()

    def load_models(self):
        """Load all models and tokenizers.

        Loading is best-effort: a checkpoint that fails to load is reported
        on stdout and skipped so the remaining models can still serve.
        """
        print("🚀 Loading Financial Sentiment Analysis Ensemble...")

        for i, model_name in enumerate(self.model_names):
            try:
                print(f"📥 Loading {model_name}...")
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForSequenceClassification.from_pretrained(model_name)
                model.eval()
            except Exception as e:
                # Deliberately broad: any failure for one checkpoint must not
                # prevent the others from loading.
                print(f"❌ Error loading {model_name}: {e}")
                continue

            # Register the pair only once both halves loaded, so the two
            # registries never disagree about which indices are usable.
            self.tokenizers[i] = tokenizer
            self.models[i] = model
            print(f"✅ {model_name} loaded successfully!")

        print(f"🎉 Ensemble ready with {len(self.models)} models!")

    def predict_single_model(self, text, model_idx):
        """Predict sentiment using a single model.

        Args:
            text: The text to classify.
            model_idx: Index of the model in ``model_names``.

        Returns:
            A 1-D NumPy array of class probabilities for ``text``, or ``None``
            when the model is not loaded or inference raised.
        """
        if model_idx not in self.models:
            return None

        try:
            inputs = self.tokenizers[model_idx](
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )

            with torch.no_grad():
                outputs = self.models[model_idx](**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Single input text -> take row 0 of the batch dimension.
            return probabilities[0].detach().cpu().numpy()
        except Exception as e:
            # Best-effort: a failing model is dropped from this prediction.
            print(f"Error in model {model_idx}: {e}")
            return None

    def predict_ensemble(self, text):
        """Predict sentiment using ensemble of all models.

        Args:
            text: The text to classify. ``None``, empty and whitespace-only
                input are all treated as "no input".

        Returns:
            A 3-tuple ``(markdown_report, prob_dict, individual_predictions)``:
            the Markdown summary, a mapping of label -> averaged probability,
            and a mapping of short model name -> its prediction/confidence.
            Both mappings are empty when there is no input or no model
            produced a usable result.
        """
        if not text or not text.strip():
            return "Please enter some text to analyze.", {}, {}

        individual_predictions = {}
        all_probabilities = []
        num_labels = len(self.labels)

        for i, model_name in enumerate(self.model_names):
            probs = self.predict_single_model(text, i)
            if probs is None:
                continue
            if len(probs) != num_labels:
                # A mismatched head cannot be averaged with the others nor
                # mapped onto ``self.labels``; skip it instead of crashing.
                print(
                    f"Skipping {model_name}: expected {num_labels} class "
                    f"probabilities, got {len(probs)}"
                )
                continue

            all_probabilities.append(probs)

            predicted_class = int(np.argmax(probs))
            confidence = probs[predicted_class]

            # e.g. ".../financial-sentiment-bert-large" -> "Bert-Large"
            model_short_name = (
                model_name.split("/")[-1].replace("financial-sentiment-", "").title()
            )
            individual_predictions[model_short_name] = {
                "Prediction": self.labels[predicted_class],
                "Confidence": f"{confidence:.1%}",
            }

        if not all_probabilities:
            return "Error: No models available for prediction.", {}, {}

        # Soft voting: unweighted mean of the per-model distributions.
        ensemble_probs = np.mean(all_probabilities, axis=0)
        ensemble_prediction = int(np.argmax(ensemble_probs))
        ensemble_confidence = ensemble_probs[ensemble_prediction]

        prob_dict = {
            label: float(ensemble_probs[i]) for i, label in enumerate(self.labels)
        }

        # Built line by line so the Markdown starts at column 0 regardless of
        # how this method is indented (indented text would render as code).
        report_lines = [
            "",
            f"## 🎯 Ensemble Prediction: **{self.labels[ensemble_prediction]}**",
            f"**Confidence:** {ensemble_confidence:.1%}",
            "",
            "### 📊 Probability Distribution:",
            f"- 📉 Bearish: {ensemble_probs[0]:.1%}",
            f"- ➡️ Neutral: {ensemble_probs[1]:.1%}",
            f"- 📈 Bullish: {ensemble_probs[2]:.1%}",
            "",
            "### 🤖 Individual Model Results:",
            "",
        ]
        result_text = "\n".join(report_lines)

        for model_name, result in individual_predictions.items():
            result_text += (
                f"- **{model_name}**: {result['Prediction']} ({result['Confidence']})\n"
            )

        return result_text, prob_dict, individual_predictions
|
|
|
|
|
|
|
|
# Module-level singleton shared by every request. The constructor calls
# ``load_models()``, so importing this module loads all checkpoints up front.
ensemble = FinancialSentimentEnsemble()
|
|
|
|
|
def analyze_sentiment(text):
    """Main function for Gradio interface.

    Runs the ensemble on ``text`` and adapts its result to the three output
    components wired to this callback (Markdown, BarPlot, JSON).

    Args:
        text: Raw contents of the input textbox.

    Returns:
        A 3-tuple ``(markdown_report, probability_frame, individual_predictions)``.
        ``probability_frame`` is a ``pandas.DataFrame`` with the columns
        ``"Sentiment"`` and ``"Probability"`` -- the names the bar plot is
        configured with as ``x`` and ``y``. It has no rows when the ensemble
        returned no probabilities (empty input or no model available).
    """
    # Local import keeps the fix self-contained; pandas is a dependency of
    # gradio, so it is available wherever this app can run.
    import pandas as pd

    result_text, prob_dict, individual_predictions = ensemble.predict_ensemble(text)

    # gr.BarPlot takes tabular data, not a plain {label: probability} mapping.
    probability_frame = pd.DataFrame(
        {
            "Sentiment": list(prob_dict.keys()),
            "Probability": list(prob_dict.values()),
        }
    )
    return result_text, probability_frame, individual_predictions
|
|
|
|
|
|
|
|
# Sample inputs passed to the ``gr.Examples`` widget below so users can try
# the app without typing their own text.
examples = [
    "The stock market is showing strong bullish momentum with record highs across major indices.",
    "Company earnings fell short of expectations, leading to a significant drop in share price.",
    "The Federal Reserve maintained interest rates, keeping market conditions stable.",
    "Tesla's innovative battery technology could revolutionize the automotive industry.",
    "Rising inflation concerns are creating uncertainty in the financial markets.",
    "The merger announcement sent both companies' stock prices soaring.",
    "Quarterly results were mixed, with some sectors outperforming while others lagged."
]
|
|
|
|
|
|
|
|
# Gradio UI: input column (textbox, button, examples) on the left, results
# (Markdown report, probability bar plot, per-model JSON) on the right.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="Financial Sentiment Analysis Ensemble",
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .main-header {
        text-align: center;
        margin-bottom: 2rem;
    }
    .model-info {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 1rem;
        border-radius: 10px;
        margin: 1rem 0;
    }
    """
) as demo:

    gr.HTML("""
    <div class="main-header">
        <h1>📈 Financial Sentiment Analysis Ensemble</h1>
        <p>Advanced AI-powered sentiment analysis for financial texts using an ensemble of 3 fine-tuned models</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="📝 Enter Financial Text",
                placeholder="Type or paste financial news, social media posts, or market commentary here...",
                lines=4,
                max_lines=10
            )

            analyze_btn = gr.Button("🔍 Analyze Sentiment", variant="primary", size="lg")

            gr.Examples(
                examples=examples,
                inputs=text_input,
                label="💡 Try these examples:"
            )

        with gr.Column(scale=3):
            result_output = gr.Markdown(label="📊 Analysis Results")

            # NOTE(review): the nesting of the JSON panel inside this Row is
            # reconstructed -- confirm it should sit beside the plot rather
            # than below it.
            with gr.Row():
                prob_plot = gr.BarPlot(
                    x="Sentiment",
                    y="Probability",
                    title="Ensemble Probability Distribution",
                    x_title="Sentiment Categories",
                    y_title="Probability",
                    width=400,
                    height=300
                )

                individual_results = gr.JSON(
                    label="🤖 Individual Model Predictions",
                    visible=True
                )

    gr.HTML("""
    <div class="model-info">
        <h3>🧠 Ensemble Models:</h3>
        <ul>
            <li><strong>DistilBERT Model:</strong> Fast and efficient, optimized for real-time analysis</li>
            <li><strong>BERT-Large Model:</strong> High accuracy with deep contextual understanding</li>
            <li><strong>Improved Model:</strong> Enhanced with advanced training techniques</li>
        </ul>
        <p><strong>Ensemble Accuracy:</strong> 79.7% | <strong>Categories:</strong> Bearish 📉, Neutral ➡️, Bullish 📈</p>
    </div>
    """)

    # The button click and pressing Enter in the textbox run the same
    # analysis, so both events share one wiring.
    analysis_outputs = [result_output, prob_plot, individual_results]
    for register_event in (analyze_btn.click, text_input.submit):
        register_event(
            fn=analyze_sentiment,
            inputs=text_input,
            outputs=analysis_outputs
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch settings when the file is run as a script: listen on every
    # network interface at port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)