APJ23 commited on
Commit
caab640
1 Parent(s): c6bdd32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -6,8 +6,6 @@ import random as r
6
  import asyncio
7
  import gradio as gr
8
 
9
- gr.Interface.load("models/APJ23/MultiHeaded_Sentiment_Analysis_Model").launch()
10
-
11
  tokenizer = AutoTokenizer.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model")
12
  model = AutoModelForSequenceClassification.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model")
13
 
@@ -20,8 +18,9 @@ classes = {
20
  5: 'Insult',
21
  6: 'Identity Hate'
22
  }
 
23
  @st.cache(allow_output_mutation=True)
24
- def prediction(tweet,model,tokenizer):
25
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
26
  outputs = model(**inputs)
27
  predicted_class = torch.argmax(outputs.logits, dim=1)
@@ -38,11 +37,10 @@ def create_table(predictions):
38
  return df
39
 
40
  st.title('Toxicity Prediction App')
41
- tweet=st.text_input('Enter a tweet to check for toxicity')
42
-
43
 
44
  if st.button('Predict'):
45
- predicted_class_label, predicted_prob = prediction(tweet, model, tokenizer))
46
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
47
  st.write(prediction_text)
48
  predictions = {tweet: (predicted_class_label, predicted_prob)}
 
6
  import asyncio
7
  import gradio as gr
8
 
 
 
9
  tokenizer = AutoTokenizer.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model")
10
  model = AutoModelForSequenceClassification.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model")
11
 
 
18
  5: 'Insult',
19
  6: 'Identity Hate'
20
  }
21
+
22
  @st.cache(allow_output_mutation=True)
23
+ def prediction(tweet, model, tokenizer):
24
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
25
  outputs = model(**inputs)
26
  predicted_class = torch.argmax(outputs.logits, dim=1)
 
37
  return df
38
 
39
  st.title('Toxicity Prediction App')
40
+ tweet = st.text_input('Enter a tweet to check for toxicity')
 
41
 
42
  if st.button('Predict'):
43
+ predicted_class_label, predicted_prob = prediction(tweet, model, tokenizer)
44
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
45
  st.write(prediction_text)
46
  predictions = {tweet: (predicted_class_label, predicted_prob)}