|
import os |
|
import praw |
|
import gradio as gr |
|
from transformers import TextClassificationPipeline, AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
|
|
|
# Reddit API credentials come from the environment. A missing variable raises
# KeyError (lookups happen in the order client_id, client_secret, user_agent).
client_id, client_secret, user_agent = (
    os.environ[variable]
    for variable in ("client_id", "client_secret", "user_agent")
)

# Shared Reddit client used by reddit_analysis().
reddit = praw.Reddit(
    client_id=client_id,
    client_secret=client_secret,
    user_agent=user_agent,
)
|
|
|
|
|
# Sentiment model shared by both analysis tabs. The rest of this file only
# inspects the "positive" and "negative" labels it produces.
model_name = "ProsusAI/finbert"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3,
)

# max_length / truncation / padding are presumably forwarded to the tokenizer
# by the pipeline, so inputs are cut/padded to 64 tokens — verify against the
# installed transformers version.
pipe = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    max_length=64,
    truncation=True,
    padding="max_length",
)
|
|
|
|
|
# Largest number of posts a single request may analyse. The previous check
# accepted up to 16 posts, so that limit is kept; the error message now states
# it correctly instead of claiming "less than 15".
MAX_REDDIT_POSTS = 16


def reddit_analysis(subreddit_name, num_posts):
    """Score the sentiment of the newest post titles in a subreddit.

    Each title is classified individually with ``pipe``. A ``"positive"``
    label contributes ``+score``, a ``"negative"`` label ``-score``, and any
    other label contributes nothing.

    Args:
        subreddit_name: Subreddit to read from; surrounding whitespace is
            ignored.
        num_posts: How many of the newest posts to analyse. Usually a string
            (it comes from a Gradio textbox); it must parse as an ``int``
            between 1 and ``MAX_REDDIT_POSTS`` inclusive.

    Returns:
        A ``(score, titles)`` pair, one value per Gradio output. ``score`` is
        the mean signed score over the fetched titles (``0.0`` if none were
        fetched). ``titles`` is the titles joined with newlines. On invalid
        input, the first item is an error message and the second is ``""``.
        This ensures the UI always receives the two values it is wired for.
    """
    limit_message = (
        f"Number of posts should be a whole number between 1 and {MAX_REDDIT_POSTS}"
    )

    name = (subreddit_name or "").strip()
    if not name:
        return "Subreddit name must not be empty", ""

    try:
        limit = int(num_posts)
    except (TypeError, ValueError):
        return limit_message, ""
    if not 1 <= limit <= MAX_REDDIT_POSTS:
        return limit_message, ""

    titles = []
    total = 0.0
    for post in reddit.subreddit(name).new(limit=limit):
        prediction = pipe(post.title)[0]
        titles.append(post.title)
        if prediction["label"] == "negative":
            total -= prediction["score"]
        elif prediction["label"] == "positive":
            total += prediction["score"]

    # The UI labels this value "Average Score", so report the mean rather
    # than the raw sum; guard against a subreddit with no posts.
    average = total / len(titles) if titles else 0.0
    return average, "\n".join(titles)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Running statistics for the manual-analysis tab. They are module globals, so
# they accumulate across every call (and, in a shared Gradio deployment,
# across every user) until the process restarts.
total_score = 0  # sum of signed scores over every text in text_list
text_list = []   # every text analysed so far, in call order


def manual_analysis(text):
    """Classify one piece of text and update the running sentiment stats.

    A ``"positive"`` label adds ``+score`` to ``total_score``, a
    ``"negative"`` label adds ``-score``, and any other label adds nothing.

    Args:
        text: The text to classify (from the Gradio textbox).

    Returns:
        ``(prediction, average)``. ``prediction`` is the raw pipeline output
        for ``text``; its first element holds the ``label``/``score`` dict.
        ``average`` is ``total_score`` divided by the number of texts analysed
        so far. For blank input, nothing is classified or recorded. Instead, a
        hint replaces ``prediction``, and the current average is returned
        unchanged (``0.0`` if nothing has been analysed yet).
    """
    global total_score

    # Classifying an empty string is meaningless and would skew the average.
    if not text or not text.strip():
        current = total_score / len(text_list) if text_list else 0.0
        return "Please enter some text to analyse", current

    prediction = pipe(text)
    text_list.append(text)

    label = prediction[0]["label"]
    confidence = prediction[0]["score"]
    if label == "negative":
        total_score -= confidence
    elif label == "positive":
        total_score += confidence

    # The UI shows this under "Average Score": report the mean, not the sum.
    return prediction, total_score / len(text_list)
|
|
|
|
|
# Two-tab Gradio UI: free-text analysis (manual_analysis) on the first tab,
# newest-posts analysis for a subreddit (reddit_analysis) on the second.
with gr.Blocks() as demo:
    with gr.Tab("Separate Analysis"):
        first_title = """<p><h1 align="center" style="font-size: 24px;">Analyse texts manually</h1></p>"""
        gr.HTML(first_title)
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="text")
                manual_button = gr.Button("Analyse")
            with gr.Column():
                label_score = gr.Textbox(label="Label/Score")
                manual_average = gr.Textbox(label="Average Score")
        # api_name names the exposed API endpoint; keep it stable since API
        # clients presumably depend on it.
        manual_button.click(
            fn=manual_analysis,
            inputs=text,
            outputs=[label_score, manual_average],
            api_name="Calc1",
        )

    with gr.Tab("Mass Analysis"):
        second_title = """<p><h1 align="center" style="font-size: 24px;">Analyse latest posts from Reddit</h1></p>"""
        gr.HTML(second_title)
        with gr.Row():
            with gr.Column():
                subreddit_name = gr.Textbox(label="Subreddit Name")
                num_post = gr.Textbox(label="Number of Posts")
                reddit_button = gr.Button("Analyse")
            with gr.Column():
                reddit_average = gr.Textbox(label="Average Score")
                # Titles come from whichever subreddit was requested, not
                # only r/tifu, hence the generic label.
                post_titles = gr.Textbox(label="Post Titles")
        reddit_button.click(
            fn=reddit_analysis,
            inputs=[subreddit_name, num_post],
            outputs=[reddit_average, post_titles],
            api_name="Calc2",
        )


if __name__ == "__main__":
    demo.launch()