# Shreyas94's picture
# Update app.py
# a9bb295 verified
import feedparser
import urllib.parse
import newspaper
import functools
from transformers import pipeline, BartForConditionalGeneration, BartTokenizer
from sentence_transformers import SentenceTransformer, util
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
import time
import sys
import gradio as gr
# Module-level model setup: everything below is instantiated once, when the
# module is imported, and shared by the functions in this file.

# Text-classification pipeline backed by the FinBERT checkpoint; used by
# analyze_sentiment() to label news titles.
sentiment_analysis = pipeline("sentiment-analysis", model="ProsusAI/finbert")
# Sentence-embedding model; used by calculate_similarity() to compare a
# company reference sentence with a news title.
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
# BART model and matching tokenizer; used by news_detailed() to summarise
# article text. Both are loaded from the same checkpoint name.
bart_model_name = "facebook/bart-large-cnn"
bart_model = BartForConditionalGeneration.from_pretrained(bart_model_name)
bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name)
# In-memory cache mapping article URL -> extracted article text, filled by
# fetch_article(). Nothing in the visible code ever evicts entries.
article_cache = {}
def fetch_article(url):
    """Return the extracted text of the article at ``url``.

    The text is looked up in the module-level ``article_cache`` first; on a
    miss the page is downloaded and parsed with ``newspaper`` and the result
    is stored in the cache before being returned. Any exception raised while
    downloading or parsing propagates to the caller, and nothing is cached
    in that case.
    """
    # Cache hit: no network access needed.
    if url in article_cache:
        return article_cache[url]

    page = newspaper.Article(url)
    page.download()
    page.parse()

    text = page.text
    article_cache[url] = text
    return text
def _format_publishing_date(published_parsed):
    """Format a feed time tuple as ``YYYY-MM-DD HH:MM:SS``, or ``"Unknown"``."""
    if not published_parsed:
        return "Unknown"
    try:
        # Build the datetime straight from the tuple's calendar fields.
        # feedparser documents its parsed dates as UTC; going through
        # time.mktime()/fromtimestamp() would reinterpret them as local
        # standard time and can shift the result by an hour during DST.
        return datetime(*published_parsed[:6]).strftime("%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError):
        return "Unknown"


def fetch_and_analyze_news_entry(entry, company_name, company_ticker, location):
    """Fetch and analyze a single news feed entry.

    Args:
        entry: Feed entry exposing ``title`` and ``link`` attributes and,
            optionally, ``published_parsed`` (a time tuple).
        company_name: Company name used for the title similarity score.
        company_ticker: Forwarded to ``calculate_similarity``.
        location: Copied unchanged into the result.

    Returns:
        A dict with the keys ``title``, ``url``, ``domain``, ``location``,
        ``publishing_date``, ``sentiment``, ``detailed_summary`` and
        ``similarity_score`` (in that order). ``publishing_date`` is a
        formatted string, or ``"Unknown"`` when the entry carries no usable
        date. ``sentiment`` is ``"Unknown"`` when sentiment analysis fails.
        ``detailed_summary`` is ``"Paywall Detected"`` when the article
        could not be fetched.
    """
    title = entry.title
    url = entry.link
    domain = urllib.parse.urlparse(url).netloc  # Host part of the article URL
    # Entries without a publication date must not abort the whole batch.
    publishing_date = _format_publishing_date(getattr(entry, "published_parsed", None))

    # Sentiment is computed from the title, so it is available even when the
    # article body cannot be retrieved.
    try:
        label, _ = analyze_sentiment(title)
        sentiment_label = {"positive": "Positive", "negative": "Negative"}.get(label, "Neutral")
    except Exception as e:
        print(f"Error analyzing sentiment for title: {title}. Error: {e}")
        sentiment_label = "Unknown"

    try:
        # Fetch article text (cached per URL by fetch_article)
        article_text = fetch_article(url)
    except Exception:
        # Best-effort: any fetch failure is reported as a paywall.
        print(f"Error fetching article at URL: {url}. Skipping article.")
        detailed_summary = "Paywall Detected"
    else:
        # Summarise the article body with the BART model
        detailed_summary = news_detailed(article_text)

    # Similarity is always based on the title, never on the article body.
    similarity_score = calculate_similarity(company_name, company_ticker, title)

    return {
        "title": title,
        "url": url,
        "domain": domain,
        "location": location,
        "publishing_date": publishing_date,
        "sentiment": sentiment_label,
        "detailed_summary": detailed_summary,
        "similarity_score": similarity_score,
    }
def _split_domains(domains):
    """Normalise a domain filter (comma-separated string or iterable) to a list."""
    if not domains:
        return []
    if isinstance(domains, str):
        # A plain string must be split on commas; iterating it directly
        # would yield one "domain" per character.
        domains = domains.split(",")
    return [domain.strip() for domain in domains if domain and domain.strip()]


def fetch_and_analyze_news(company_name, company_ticker, event_name, start_date=None, end_date=None, location=None, num_news=5, include_domains=None, exclude_domains=None):
    """Query the Google News RSS search feed and analyze the returned entries.

    Args:
        company_name: Company name; part of the search query.
        company_ticker: Forwarded to the per-entry analysis.
        event_name: Event keyword(s); part of the search query.
        start_date: Optional lower date bound, added as ``after:<start_date>``.
        end_date: Optional upper date bound, added as ``before:<end_date>``.
            The date range is only applied when both bounds are given.
        location: Optional location; part of the search query when non-empty.
        num_news: Maximum number of entries to analyze. Integral floats are
            accepted and converted to ``int``; ``None`` means no limit.
        include_domains: Domains to search, either a comma-separated string
            or an iterable of strings. Combined with ``OR`` as ``site:`` terms.
        exclude_domains: Domains to exclude, same accepted forms; each is
            added as a ``-site:`` term.

    Returns:
        A list with one result dict per analyzed entry, in feed order.
    """
    # Build the free-text part of the query from the non-empty pieces only,
    # so an omitted location does not end up as the literal word "None".
    query_parts = [
        str(part).strip()
        for part in (company_name, event_name, location)
        if part is not None and str(part).strip()
    ]
    query_name = " ".join(query_parts)

    # Add date range to the query if start_date and end_date are provided
    if start_date and end_date:
        query_name += f" after:{start_date} before:{end_date}"

    # Add domain suggestions and exclusions to the query
    included = _split_domains(include_domains)
    if included:
        query_name += " " + " OR ".join(f"site:{domain}" for domain in included)
    excluded = _split_domains(exclude_domains)
    if excluded:
        query_name += " " + " ".join(f"-site:{domain}" for domain in excluded)

    encoded_query_name = urllib.parse.quote(query_name)
    rss_url_name = f"https://news.google.com/rss/search?q={encoded_query_name}"

    # Slice bounds must be integers; UI number widgets may hand over a float.
    limit = None if num_news is None else int(num_news)
    feed_name = feedparser.parse(rss_url_name)
    news_entries_name = feed_name.entries[:limit]

    # Fetch and analyze the entries in parallel; map() preserves feed order.
    analyze_news_entry_func = functools.partial(
        fetch_and_analyze_news_entry,
        company_name=company_name,
        company_ticker=company_ticker,
        location=location,
    )
    with ThreadPoolExecutor() as executor:
        return list(executor.map(analyze_news_entry_func, news_entries_name))
def news_detailed(article_text, max_length=250, max_input_length=1024):
    """Generate a detailed news summary using the BART model.

    Args:
        article_text: Full article text to summarise.
        max_length: Maximum length, in tokens, of the generated summary.
        max_input_length: Maximum number of article tokens fed to the model;
            longer articles are truncated. This is kept separate from
            ``max_length`` so the article is not cut down to the size of the
            summary before summarising. The default of 1024 is assumed to
            match the input limit of the loaded BART checkpoint — confirm if
            the checkpoint is changed.

    Returns:
        The summary text, or an empty string when ``article_text`` is empty
        or contains only whitespace.
    """
    # Nothing to summarise: skip the model rather than let it generate text
    # from an empty input.
    if not article_text or not article_text.strip():
        return ""

    inputs = bart_tokenizer(
        [article_text],
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    summary_ids = bart_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        max_length=max_length,
        length_penalty=2.0,
        early_stopping=True,
    )
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def calculate_similarity(company_name, company_ticker, title, threshold=0.4):
    """Score how closely a news title matches a company.

    The sentence ``"News Regarding <company_name>"`` and ``title`` are each
    embedded with the module-level sentence model, and the cosine similarity
    of the two embeddings is returned as a plain Python number.

    ``company_ticker`` and ``threshold`` are accepted but not used in the
    computation.
    """
    reference_sentence = f"News Regarding {company_name}"
    reference_embedding = sentence_model.encode([reference_sentence], convert_to_tensor=True)
    title_embedding = sentence_model.encode([title], convert_to_tensor=True)
    # Both inputs hold a single sentence, so the similarity matrix has one
    # element and .item() extracts it.
    return util.pytorch_cos_sim(reference_embedding, title_embedding).item()
def analyze_sentiment(title):
    """Run the sentiment pipeline on ``title``.

    Returns a ``(label, score)`` tuple taken from the first element of the
    pipeline output. Progress messages are printed before and after the
    analysis.
    """
    print("Analyzing sentiment...")
    first_result = sentiment_analysis(title)[0]
    label, score = first_result['label'], first_result['score']
    print("Sentiment analyzed successfully.")
    return label, score
def calculate_title_similarity(news_list, company_name, company_ticker):
    """Select the news items whose title closely matches the company.

    Each item's ``'title'`` is scored against the company with
    ``calculate_similarity``; items scoring strictly above 0.7 are returned
    as a new list, in their original order.
    """
    return [
        item
        for item in news_list
        if calculate_similarity(company_name, company_ticker, item['title']) > 0.7
    ]
def _safe_file_stem(name):
    """Replace path separators and other filename-hostile characters with ``_``."""
    forbidden = '\\/:*?"<>|'
    return "".join(
        "_" if ch in forbidden or ord(ch) < 32 else ch
        for ch in str(name)
    )


def fetch_news(company_name, company_ticker, event_name, start_date, end_date, location, num_news, include_domains, exclude_domains):
    """Fetch and analyze news, then export the results to an Excel workbook.

    The analyzed entries are split by their ``similarity_score``:

    * ``Above_Threshold`` sheet: entries scoring ``>= 0.3``
    * ``Below_Threshold`` sheet: entries scoring ``< 0.3``
    * ``Similar_News`` sheet: the above-threshold entries that
      ``calculate_title_similarity`` keeps

    Returns:
        The name of the written workbook,
        ``"<company_name>_News_Data.xlsx"``, created in the current working
        directory. Characters in ``company_name`` that are path separators
        or otherwise invalid in file names are replaced with ``_`` so the
        file always lands in the working directory.
    """
    analyzed_news_name = fetch_and_analyze_news(company_name, company_ticker, event_name, start_date, end_date, location, num_news, include_domains, exclude_domains)

    valid_news = [news for news in analyzed_news_name if news is not None]
    above_threshold_news = [news for news in valid_news if news['similarity_score'] >= 0.3]
    below_threshold_news = [news for news in valid_news if news['similarity_score'] < 0.3]
    similar_news = calculate_title_similarity(above_threshold_news, company_name, company_ticker)

    # company_name is user input: strip anything that could change the
    # target directory or be rejected by the file system.
    file_name = f"{_safe_file_stem(company_name)}_News_Data.xlsx"
    with pd.ExcelWriter(file_name) as writer:
        pd.DataFrame(above_threshold_news).to_excel(writer, sheet_name='Above_Threshold', index=False)
        pd.DataFrame(below_threshold_news).to_excel(writer, sheet_name='Below_Threshold', index=False)
        pd.DataFrame(similar_news).to_excel(writer, sheet_name='Similar_News', index=False)
    return file_name
# Gradio Interface
def gradio_fetch_news(company_name, company_ticker, event_name, start_date, end_date, location, num_news, include_domains, exclude_domains):
    """Gradio callback: forward all form values to ``fetch_news``.

    Returns whatever ``fetch_news`` returns (the name of the generated
    Excel file), which the interface exposes as a download.
    """
    return fetch_news(
        company_name,
        company_ticker,
        event_name,
        start_date,
        end_date,
        location,
        num_news,
        include_domains,
        exclude_domains,
    )
# Form fields, in the same order as the parameters of gradio_fetch_news.
inputs = [
    gr.Textbox(label="Company Name"),
    gr.Textbox(label="Company Ticker"),
    gr.Textbox(label="Event Name"),
    gr.Textbox(label="Start Date (optional)"),
    gr.Textbox(label="End Date (optional)"),
    gr.Textbox(label="Location (optional)"),
    # precision=0 makes the component deliver an int: the value is used as a
    # slice bound downstream, and depending on the Gradio version a Number
    # without precision may deliver a float. The default of 5 keeps an
    # untouched field from submitting an empty value.
    gr.Number(label="Number of News to Fetch", value=5, precision=0),
    gr.Textbox(label="Include Domains (comma-separated)", placeholder="e.g., example.com,example.org"),
    gr.Textbox(label="Exclude Domains (comma-separated)", placeholder="e.g., example.net,example.info"),
]

# The callback returns a file name, offered to the user as a download.
outputs = gr.File(label="Download Excel File")

interface = gr.Interface(
    fn=gradio_fetch_news,
    inputs=inputs,
    outputs=outputs,
    title="News Fetcher",
    description="Fetch and analyze news articles based on company name, event, and other criteria, and download the results as an Excel file.",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch()