MYTE commited on
Commit
d96b1a8
1 Parent(s): 925f7dd

entertainment genres app

Browse files
Files changed (2) hide show
  1. app.py +60 -4
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,9 +1,65 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
 
 
6
 
 
7
 
8
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
9
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import json
3
+ import torch
4
+ from transformers import AutoTokenizer
5
+ import onnxruntime as rt
6
+ import platform
7
 
8
 
9
+ if platform.system() == "Windows":
10
+ import pathlib
11
+ temp = pathlib.PosixPath
12
+ pathlib.PosixPath = pathlib.WindowsPath
13
 
14
+ model_path = "entertainment-genre-quantized.onnx"
15
 
16
+ with open("genre_types_encoded.json", "r") as file:
17
+ categories = json.load(file)
18
+
19
+ inf_session = rt.InferenceSession(model_path)
20
+ input_name = inf_session.get_inputs()[0].name
21
+ output_name = inf_session.get_outputs()[0].name
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
24
+
25
+
26
+ def get_top_label(cat_dict, idx):
27
+ for key, value in cat_dict.items():
28
+ if idx == value:
29
+ return key
30
+
31
+
32
+ def get_top_probs(cat_probs, idx):
33
+ return cat_probs[idx]
34
+
35
+
36
+ def entertainment_genres(description):
37
+ input_ids = tokenizer(description)['input_ids'][:512]
38
+ probs = inf_session.run([output_name], {input_name: [input_ids]})[0]
39
+ top_3_indices = sorted(range(len(probs[0])), key=lambda idx: probs[0][idx], reverse=True)[:3]
40
+ cat_prob = torch.sigmoid(torch.FloatTensor(probs))[0]
41
+ print(cat_prob)
42
+
43
+ top_labels = []
44
+ for i in top_3_indices:
45
+ top_labels.append(get_top_label(categories, i))
46
+
47
+ top_probs = []
48
+ for i in top_3_indices:
49
+ top_probs.append(get_top_probs(cat_prob, i))
50
+
51
+ return dict(zip(top_labels, map(float, top_probs)))
52
+
53
+
54
+ example = [
55
+ ["March Of Soldiers is a real time strategy single player , It is a military game based on the player's skill and "
56
+ "the strength of his financial economy"],
57
+ ["When the menace known as the Joker wreaks havoc and chaos on the people of Gotham, Batman must accept one of "
58
+ "the greatest psychological and physical tests of his ability to fight injustice."]
59
+ ]
60
+
61
+
62
+ label = gr.outputs.Label(num_top_classes=3)
63
+
64
+ iface = gr.Interface(fn=entertainment_genres, inputs="text", outputs=label, examples=example)
65
+ iface.launch(inline=False)
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- gradio
 
 
 
 
1
+ gradio==3.44.0
2
+ torch==2.0.1
3
+ transformers==4.33.1
4
+ onnxruntime==1.15.1