unknown commited on
Commit
d0aa01e
1 Parent(s): df3fbfd

Add application file

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  import onnxruntime as rt
3
  from transformers import AutoTokenizer
4
- import torch, json
 
5
 
6
  tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
7
 
@@ -14,14 +15,17 @@ inf_session = rt.InferenceSession('movie-classifier.onnx')
14
  input_name = inf_session.get_inputs()[0].name
15
  output_name = inf_session.get_outputs()[0].name
16
 
17
- def classify_movie_genre(Overview):
18
  input_ids = tokenizer(Overview)['input_ids'][:512]
19
  logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
20
  logits = torch.FloatTensor(logits)
21
  probs = torch.sigmoid(logits)[0]
22
- return dict(zip(genres, map(float, probs)))
23
 
 
 
 
 
24
 
25
- label = gr.outputs.Label(num_top_classes=3)
26
  iface = gr.Interface(fn=classify_movie_genre, inputs="text", outputs=label)
27
- iface.launch(inline=False)
 
1
  import gradio as gr
2
  import onnxruntime as rt
3
  from transformers import AutoTokenizer
4
+ import torch
5
+ import json
6
 
7
  tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
8
 
 
15
  input_name = inf_session.get_inputs()[0].name
16
  output_name = inf_session.get_outputs()[0].name
17
 
18
+ def classify_movie_genre(Overview, num_top_classes=5):
19
  input_ids = tokenizer(Overview)['input_ids'][:512]
20
  logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
21
  logits = torch.FloatTensor(logits)
22
  probs = torch.sigmoid(logits)[0]
 
23
 
24
+ # Get the top N predicted genres
25
+ top_genres = [genres[i] for i in probs.argsort(descending=True)[:num_top_classes]]
26
+
27
+ return top_genres
28
 
29
+ label = gr.outputs.Label()
30
  iface = gr.Interface(fn=classify_movie_genre, inputs="text", outputs=label)
31
+ iface.launch(inline=False)