APJ23 commited on
Commit
13112c4
1 Parent(s): 3c380c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
4
- import asyncio
 
5
  import gradio as gr
6
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
7
 
8
  gr.Interface.load("models/APJ23/MultiHeaded_Sentiment_Analysis_Model").launch()
9
 
@@ -19,15 +20,14 @@ classes = {
19
  5: 'Insult',
20
  6: 'Identity Hate'
21
  }
22
-
23
  @st.cache(allow_output_mutation=True)
24
- def predict_toxicity(tweet, model, tokenizer):
25
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
26
  outputs = model(**inputs)
27
  predicted_class = torch.argmax(outputs.logits, dim=1)
28
  predicted_prob = torch.softmax(outputs.logits, dim=1)[0][predicted_class].item()
29
  return classes[predicted_class], predicted_prob
30
-
31
  def create_table(predictions):
32
  data = {'Tweet': [], 'Highest Toxicity Class': [], 'Probability': []}
33
  for tweet, prediction in predictions.items():
@@ -37,25 +37,23 @@ def create_table(predictions):
37
  df = pd.DataFrame(data)
38
  return df
39
 
40
- async def run_async_prediction(tweet, model, tokenizer):
41
- loop = asyncio.get_event_loop()
42
- prediction = await loop.run_in_executor(None, predict_toxicity, tweet, model, tokenizer)
43
- return prediction
44
-
45
  st.title('Toxicity Prediction App')
46
- tweet_input = st.text_input('Enter a tweet to check for toxicity')
47
 
48
- if st.button('Predict'):
49
- predictions = {tweet_input: None}
50
  loop = asyncio.new_event_loop()
51
  asyncio.set_event_loop(loop)
52
- prediction = loop.run_until_complete(run_async_prediction(tweet_input, model, tokenizer))
53
- predictions[tweet_input] = prediction
54
  loop.close()
 
 
 
 
55
 
56
- predicted_class_label, predicted_prob = prediction
 
57
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
58
  st.write(prediction_text)
59
-
60
  table = create_table(predictions)
61
  st.table(table)
 
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import random as r
6
  import gradio as gr
7
+ import asyncio
8
 
9
  gr.Interface.load("models/APJ23/MultiHeaded_Sentiment_Analysis_Model").launch()
10
 
 
20
  5: 'Insult',
21
  6: 'Identity Hate'
22
  }
 
23
  @st.cache(allow_output_mutation=True)
24
+ def predict_toxicity(tweet,model,tokenizer):
25
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
26
  outputs = model(**inputs)
27
  predicted_class = torch.argmax(outputs.logits, dim=1)
28
  predicted_prob = torch.softmax(outputs.logits, dim=1)[0][predicted_class].item()
29
  return classes[predicted_class], predicted_prob
30
+
31
  def create_table(predictions):
32
  data = {'Tweet': [], 'Highest Toxicity Class': [], 'Probability': []}
33
  for tweet, prediction in predictions.items():
 
37
  df = pd.DataFrame(data)
38
  return df
39
 
 
 
 
 
 
40
  st.title('Toxicity Prediction App')
41
+ tweet=st.text_input('Enter a tweet to check for toxicity')
42
 
43
+ async def predict_toxicity_async(tweet, model, tokenizer):
 
44
  loop = asyncio.new_event_loop()
45
  asyncio.set_event_loop(loop)
46
+ result = await loop.run_until_complete(predict_toxicity(tweet, model, tokenizer))
 
47
  loop.close()
48
+ return result
49
+
50
+ def predict_toxicity_sync(tweet, model, tokenizer):
51
+ return asyncio.run(predict_toxicity_async(tweet, model, tokenizer))
52
 
53
+ if st.button('Predict'):
54
+ predicted_class_label, predicted_prob = predict_toxicity_sync(tweet, model, tokenizer)
55
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
56
  st.write(prediction_text)
57
+ predictions = {tweet: (predicted_class_label, predicted_prob)}
58
  table = create_table(predictions)
59
  st.table(table)