unknown commited on
Commit
be549be
1 Parent(s): d0aa01e

Add application file

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import onnxruntime as rt
3
  from transformers import AutoTokenizer
4
- import torch
5
- import json
6
 
7
  tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
8
 
@@ -15,17 +14,14 @@ inf_session = rt.InferenceSession('movie-classifier.onnx')
15
  input_name = inf_session.get_inputs()[0].name
16
  output_name = inf_session.get_outputs()[0].name
17
 
18
- def classify_movie_genre(Overview, num_top_classes=5):
19
  input_ids = tokenizer(Overview)['input_ids'][:512]
20
  logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
21
  logits = torch.FloatTensor(logits)
22
  probs = torch.sigmoid(logits)[0]
23
-
24
- # Get the top N predicted genres
25
- top_genres = [genres[i] for i in probs.argsort(descending=True)[:num_top_classes]]
26
-
27
- return top_genres
28
 
29
- label = gr.outputs.Label()
30
  iface = gr.Interface(fn=classify_movie_genre, inputs="text", outputs=label)
31
  iface.launch(inline=False)
 
 
1
  import gradio as gr
2
  import onnxruntime as rt
3
  from transformers import AutoTokenizer
4
+ import torch, json
 
5
 
6
  tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
7
 
 
14
  input_name = inf_session.get_inputs()[0].name
15
  output_name = inf_session.get_outputs()[0].name
16
 
17
+ def classify_movie_genre(Overview):
18
  input_ids = tokenizer(Overview)['input_ids'][:512]
19
  logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
20
  logits = torch.FloatTensor(logits)
21
  probs = torch.sigmoid(logits)[0]
22
+ return dict(zip(genres, map(float, probs)))
 
 
 
 
23
 
24
+ label = gr.Interface.Label(num_top_classes=5) # Use Label() instead of gr.outputs.Label()
25
  iface = gr.Interface(fn=classify_movie_genre, inputs="text", outputs=label)
26
  iface.launch(inline=False)
27
+