# Spaces: Running
import datetime
import html
import os
import re

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from scipy.special import softmax
from transformers import AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import TFAutoModelForSequenceClassification
# Sentiment classifier setup: the same Hugging Face model id is used for the
# config, the tokenizer and the sequence-classification model.
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"

config = AutoConfig.from_pretrained(MODEL)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)

# Persist the loaded model under a local path equal to the model id.
model.save_pretrained(MODEL)
class ScorePredictor(nn.Module):
    """Embedding -> LSTM -> linear -> sigmoid scorer over token id sequences.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of values produced per input sequence.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # Token id 0 is registered as the padding index of the embedding.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Score each sequence in the batch.

        Args:
            input_ids: Integer token ids; the LSTM is batch-first, so the
                leading dimension is the batch.
            attention_mask: Accepted for interface compatibility; it is not
                read by this method.

        Returns:
            The sigmoid of the linear layer applied to the LSTM output at the
            final time step, one row per sequence.
        """
        token_vectors = self.embedding(input_ids)
        sequence_out = self.lstm(token_vectors)[0]
        # The LSTM output at the last position is the sequence summary,
        # regardless of padding.
        last_step = sequence_out[:, -1, :]
        return self.sigmoid(self.fc(last_step))
# Load trained score predictor model.
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on
# a CPU-only host; the module is constructed on the CPU here.
score_model = ScorePredictor(tokenizer.vocab_size)
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))
# Inference only: switch off training-time behaviour.
score_model.eval()
# preprocesses text
def preprocess_text(text):
    """Normalise raw text before it is tokenised.

    The text is lower-cased, then three regex substitutions run in order:
    URLs starting with ``http`` are removed, every character other than
    ASCII letters, digits, whitespace and ``.,!?`` is removed, and runs of
    whitespace collapse to a single space. The result is stripped.
    """
    substitutions = (
        (r'http\S+', ''),
        (r'[^a-zA-Z0-9\s.,!?]', ''),
        (r'\s+', ' '),
    )
    cleaned = text.lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
# predicts sentiment
def predict_sentiment(text):
    """Return a 0-100 sentiment score for ``text``.

    The text is cleaned with ``preprocess_text``, classified by the
    transformer ``model``, and the score is ``100 * (1 - P(negative))`` where
    ``P(negative)`` is the softmax probability of the label named "negative"
    in ``config.id2label``.

    Args:
        text: Raw text to score.

    Returns:
        A float in [0, 100]; ``0.0`` when ``text`` is empty or contains
        nothing scorable after cleaning.

    Raises:
        ValueError: If the model config has no label named "negative".
    """
    if not text:
        return 0.0

    text = preprocess_text(text)
    if not text:
        # Cleaning removed everything (e.g. whitespace or symbols only);
        # treat it the same as empty input instead of scoring an empty string.
        return 0.0

    # Truncate so long articles cannot exceed the model's maximum input length.
    encoded_input = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=512,
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)
    scores = softmax(output[0][0].numpy())

    negative_id = next(
        (idx for idx, label in config.id2label.items() if label.lower() == 'negative'),
        None,
    )
    if negative_id is None:
        # Falling back to an arbitrary index would silently report the wrong
        # class probability as "negative".
        raise ValueError("Model config has no 'negative' label in id2label")

    return (1 - float(scores[negative_id])) * 100
# uses Polygon API to fetch article
def fetch_articles(ticker):
    """Fetch the latest news item for ``ticker`` from the Polygon news API.

    Args:
        ticker: Ticker symbol to query. It is sent as a query parameter, so
            it is URL-encoded by ``requests``.

    Returns:
        The article title and description joined by a space, or ``None`` when
        there is no article or the request failed. Failures are logged and
        reported as ``None`` so callers never mistake an error message for
        article text.
    """
    # NOTE(security): the literal fallback keeps existing deployments working,
    # but the key should be rotated and supplied only via POLYGON_API_KEY.
    api_key = os.environ.get("POLYGON_API_KEY") or "cMCv7jipVvV4qLBikgzllNmW_isiODRR"
    url = "https://api.polygon.io/v2/reference/news"
    params = {"ticker": ticker, "limit": 1, "apiKey": api_key}
    # Log the endpoint only; the query string carries the API key.
    print(f"[FETCH] {ticker}: {url}")
    try:
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
    # checks specific HTTP errors
    except requests.exceptions.HTTPError as http_err:
        # str(http_err) embeds the full URL (including the key), so log only
        # the status code.
        status = getattr(http_err.response, "status_code", "unknown")
        print(f"[ERROR] HTTP error for {ticker}: status {status}")
        return None
    # catches any other error
    except Exception as exc:
        # Connection errors can also embed the request URL in their message.
        print(f"[ERROR] Unexpected error for {ticker}: {type(exc).__name__}")
        return None

    results = data.get("results") if isinstance(data, dict) else None
    if not results:
        return None
    article = results[0]
    if not isinstance(article, dict):
        return None

    # `or ""` also covers fields that are present but null in the JSON.
    title = article.get("title") or ""
    description = article.get("description") or ""
    combined = f"{title} {description}".strip()
    return combined or None
# initialize cache
# Maps a ticker to the entry written by analyze_ticker
# ("article", "sentiment", "timestamp").
sentiment_cache = {}


# checks if cache is valid
def is_cache_valid(cached_time, max_age_minutes=10):
    """Return True when ``cached_time`` is younger than ``max_age_minutes``.

    Args:
        cached_time: ``datetime`` at which the entry was stored, or ``None``.
            A naive value is interpreted as UTC; timezone-aware values are
            accepted as well.
        max_age_minutes: Maximum age, in minutes, for an entry to count as
            valid.

    Returns:
        ``False`` when ``cached_time`` is ``None`` or the entry is at least
        ``max_age_minutes`` old, otherwise ``True``.
    """
    if cached_time is None:
        return False
    utc = datetime.timezone.utc
    if cached_time.tzinfo is None:
        # Naive timestamps are treated as UTC.
        cached_time = cached_time.replace(tzinfo=utc)
    age = datetime.datetime.now(utc) - cached_time
    return age.total_seconds() < max_age_minutes * 60
# analyzes the tickers
def analyze_ticker(user_ticker: str):
    """Collect article text and sentiment for ``user_ticker`` and for SPY.

    For each ticker a fresh entry in ``sentiment_cache`` is reused when
    ``is_cache_valid`` accepts its timestamp; otherwise the article is fetched
    with ``fetch_articles``, scored with ``predict_sentiment`` and cached.

    Args:
        user_ticker: Ticker symbol typed by the user; it is upper-cased and
            stripped. A blank value is skipped so only SPY is analysed.

    Returns:
        A list of dicts with keys "article", "sentiment", "timestamp" and
        "ticker", with the user's ticker first and SPY second. "sentiment" is
        ``None`` when no article was found.
    """
    user_ticker = (user_ticker or "").upper().strip()
    # dict.fromkeys de-duplicates while keeping the user's ticker ahead of
    # SPY; blank tickers are dropped rather than sent to the API.
    tickers_to_check = [tk for tk in dict.fromkeys((user_ticker, "SPY")) if tk]

    results = []
    for tk in tickers_to_check:
        cached = sentiment_cache.get(tk, {})
        if cached and is_cache_valid(cached.get("timestamp")):
            print(f"[CACHE] Using cached sentiment for {tk}")
            results.append({**cached, "ticker": tk})
            continue

        print(f"[INFO] Fetching fresh data for {tk}")
        article_text = fetch_articles(tk)
        if article_text is None:
            sentiment_score = None
            article_text = f"No news articles found for {tk}."
        else:
            sentiment_score = predict_sentiment(article_text)

        # Naive UTC timestamp (same representation as before) built without
        # the deprecated datetime.utcnow().
        timestamp = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        cache_entry = {
            "article": article_text,
            "sentiment": sentiment_score,
            "timestamp": timestamp,
        }
        sentiment_cache[tk] = cache_entry
        results.append({**cache_entry, "ticker": tk})
    return results
def display_sentiment(results):
    """Render analysis results as an HTML fragment.

    Args:
        results: Iterable of dicts with keys "ticker", "article", "sentiment"
            (float or ``None``) and "timestamp" (a ``datetime``).

    Returns:
        An HTML string with a heading and one list item per result. The
        ticker and article text are HTML-escaped because they come from user
        input and an external API respectively. A ``None`` sentiment is shown
        as an em dash.
    """
    parts = ["<h2>Sentiment Analysis</h2><ul>"]
    for r in results:
        ts_str = r["timestamp"].strftime("%Y-%m-%d %H:%M:%S UTC")
        score_display = (
            f"{r['sentiment']:.2f}"
            if r['sentiment'] is not None else
            "—"
        )
        ticker = html.escape(str(r['ticker']))
        article = html.escape(str(r['article']))
        parts.append(
            f"<li><b>{ticker}</b> ({ts_str})<br>"
            f"{article}<br>"
            f"<i>Sentiment score:</i> {score_display}</li>"
        )
    parts.append("</ul>")
    return "".join(parts)
# Gradio UI: a ticker text box, an "Analyze" button and an HTML output area.
with gr.Blocks() as demo:
    gr.Markdown("# Ticker vs. SPY Sentiment Tracker")
    input_box = gr.Textbox(label="Enter any ticker symbol (e.g., AAPL)")
    output_html = gr.HTML()
    run_btn = gr.Button("Analyze")

    def _placeholder(t):
        """Return the interim message shown while the analysis runs."""
        # `t` is raw user input rendered as HTML, so escape it.
        return f"<h3>Gathering latest articles for {html.escape(t.upper())} and SPY … please wait.</h3>"

    # First show the placeholder immediately (unqueued), then replace it with
    # the rendered analysis once it is ready.
    run_btn.click(_placeholder, inputs=input_box, outputs=output_html, queue=False).then(
        lambda t: display_sentiment(analyze_ticker(t)),
        inputs=input_box,
        outputs=output_html,
    )

demo.launch()