APJ23 committed
Commit 3c380c2
Parent: 9e3035c

Update app.py

Files changed (1):
  1. app.py +20 -12
app.py CHANGED
@@ -1,15 +1,12 @@
 import streamlit as st
 import pandas as pd
 import torch
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-import random as r
 import asyncio
 import gradio as gr
+from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
 
 gr.Interface.load("models/APJ23/MultiHeaded_Sentiment_Analysis_Model").launch()
 
-
-
 tokenizer = AutoTokenizer.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model", local_files_only=True)
 model = AutoModelForSequenceClassification.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model")
 
@@ -22,14 +19,15 @@ classes = {
     5: 'Insult',
     6: 'Identity Hate'
 }
+
 @st.cache(allow_output_mutation=True)
-async def async_prediction(tweet,model,tokenizer):
+def predict_toxicity(tweet, model, tokenizer):
     inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
     outputs = model(**inputs)
     predicted_class = torch.argmax(outputs.logits, dim=1)
     predicted_prob = torch.softmax(outputs.logits, dim=1)[0][predicted_class].item()
     return classes[predicted_class], predicted_prob
-
+
 def create_table(predictions):
     data = {'Tweet': [], 'Highest Toxicity Class': [], 'Probability': []}
     for tweet, prediction in predictions.items():
@@ -39,15 +37,25 @@ def create_table(predictions):
     df = pd.DataFrame(data)
     return df
 
+async def run_async_prediction(tweet, model, tokenizer):
+    loop = asyncio.get_event_loop()
+    prediction = await loop.run_in_executor(None, predict_toxicity, tweet, model, tokenizer)
+    return prediction
+
 st.title('Toxicity Prediction App')
-tweet=st.text_input('Enter a tweet to check for toxicity')
-async def run_async_function():
-    result = await async_prediction(tweet, model, tokenizer)
-    return result
+tweet_input = st.text_input('Enter a tweet to check for toxicity')
+
 if st.button('Predict'):
-    predicted_class_label, predicted_prob = asyncio.run(run_async_function())
+    predictions = {tweet_input: None}
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    prediction = loop.run_until_complete(run_async_prediction(tweet_input, model, tokenizer))
+    predictions[tweet_input] = prediction
+    loop.close()
+
+    predicted_class_label, predicted_prob = prediction
     prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
     st.write(prediction_text)
-    predictions = {tweet: (predicted_class_label, predicted_prob)}
+
     table = create_table(predictions)
     st.table(table)
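
For context, a minimal standalone sketch of the pattern this commit adopts: a plain synchronous forward pass (predict_toxicity) handed to loop.run_in_executor so the blocking transformers call does not stall the event loop. This is an illustration, not part of the commit: it assumes the APJ23/MultiHeaded_Sentiment_Analysis_Model checkpoint can be fetched from the Hub, reads label names from model.config.id2label instead of the app's hard-coded classes dict, converts the argmax tensor to a plain int before the lookup, and the __main__ entry point exists only for this sketch.

import asyncio

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Checkpoint name taken from the commit; assumes it is reachable on the Hub.
MODEL_ID = "APJ23/MultiHeaded_Sentiment_Analysis_Model"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

def predict_toxicity(tweet):
    # Blocking inference: tokenize, forward pass, pick the highest-scoring class.
    inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    idx = int(torch.argmax(logits, dim=1))            # plain int so it can index a dict
    prob = torch.softmax(logits, dim=1)[0, idx].item()
    return model.config.id2label[idx], prob           # label names from the model config

async def run_async_prediction(tweet):
    # Offload the blocking call to the default thread pool, the same
    # run_in_executor pattern the new run_async_prediction uses.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, predict_toxicity, tweet)

if __name__ == "__main__":
    label, prob = asyncio.run(run_async_prediction("example tweet text"))
    print(f"Prediction: {label} ({prob:.2f})")

Inside Streamlit the new code builds and closes its own loop with asyncio.new_event_loop() and run_until_complete() because the script re-runs on every interaction; in a plain script, asyncio.run() above does the same job.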