APJ23 commited on
Commit
c3bf700
1 Parent(s): 13112c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -6,8 +6,6 @@ import random as r
6
  import gradio as gr
7
  import asyncio
8
 
9
- gr.Interface.load("models/APJ23/MultiHeaded_Sentiment_Analysis_Model").launch()
10
-
11
  tokenizer = AutoTokenizer.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model", local_files_only=True)
12
  model = AutoModelForSequenceClassification.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model")
13
 
@@ -21,7 +19,7 @@ classes = {
21
  6: 'Identity Hate'
22
  }
23
  @st.cache(allow_output_mutation=True)
24
- def predict_toxicity(tweet,model,tokenizer):
25
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
26
  outputs = model(**inputs)
27
  predicted_class = torch.argmax(outputs.logits, dim=1)
@@ -40,18 +38,15 @@ def create_table(predictions):
40
  st.title('Toxicity Prediction App')
41
  tweet=st.text_input('Enter a tweet to check for toxicity')
42
 
43
- async def predict_toxicity_async(tweet, model, tokenizer):
44
  loop = asyncio.new_event_loop()
45
  asyncio.set_event_loop(loop)
46
- result = await loop.run_until_complete(predict_toxicity(tweet, model, tokenizer))
47
  loop.close()
48
  return result
49
 
50
- def predict_toxicity_sync(tweet, model, tokenizer):
51
- return asyncio.run(predict_toxicity_async(tweet, model, tokenizer))
52
-
53
  if st.button('Predict'):
54
- predicted_class_label, predicted_prob = predict_toxicity_sync(tweet, model, tokenizer)
55
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
56
  st.write(prediction_text)
57
  predictions = {tweet: (predicted_class_label, predicted_prob)}
 
6
  import gradio as gr
7
  import asyncio
8
 
 
 
9
  tokenizer = AutoTokenizer.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model", local_files_only=True)
10
  model = AutoModelForSequenceClassification.from_pretrained("APJ23/MultiHeaded_Sentiment_Analysis_Model")
11
 
 
19
  6: 'Identity Hate'
20
  }
21
  @st.cache(allow_output_mutation=True)
22
+ def prediction(tweet,model,tokenizer):
23
  inputs = tokenizer(tweet, return_tensors="pt", padding=True, truncation=True)
24
  outputs = model(**inputs)
25
  predicted_class = torch.argmax(outputs.logits, dim=1)
 
38
  st.title('Toxicity Prediction App')
39
  tweet=st.text_input('Enter a tweet to check for toxicity')
40
 
41
+ async def run_async_function(tweet, model, tokenizer):
42
  loop = asyncio.new_event_loop()
43
  asyncio.set_event_loop(loop)
44
+ result = await loop.run_in_executor(None, prediction, tweet, model, tokenizer)
45
  loop.close()
46
  return result
47
 
 
 
 
48
  if st.button('Predict'):
49
+ predicted_class_label, predicted_prob = asyncio.run(run_async_function(tweet, model, tokenizer))
50
  prediction_text = f'Prediction: {predicted_class_label} ({predicted_prob:.2f})'
51
  st.write(prediction_text)
52
  predictions = {tweet: (predicted_class_label, predicted_prob)}