File size: 1,158 Bytes
9181b11
6870fe5
 
adfd51e
 
9181b11
50c8092
9181b11
6870fe5
adfd51e
6870fe5
 
 
3e94719
8495c81
6870fe5
 
 
 
adfd51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6870fe5
220e363
6870fe5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import json
import logging

import gradio as gr
import onnxruntime as rt
import torch
from transformers import AutoTokenizer

# Tokenizer for the distilroberta-base checkpoint. Presumably this is the one the
# ONNX classifier was trained/exported with — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")

# Read explicitly as UTF-8: JSON text is UTF-8, and relying on the platform
# default encoding (e.g. cp1252 on Windows) would garble non-ASCII genre names.
with open("genre_types_encoded.json", "r", encoding="utf-8") as fp:
    encode_genre_types = json.load(fp)

# Label names, in the JSON object's key order.
# NOTE(review): this assumes the key order matches the model's output-column
# order (the dict values may be the label encodings — verify against the
# training code before relying on it).
genres = list(encode_genre_types.keys())

# Pin the CPU execution provider. Since onnxruntime 1.9, builds with more than
# one provider enabled (e.g. GPU builds) raise ValueError unless `providers`
# is passed explicitly. CPUExecutionProvider is present in every build.
inf_session = rt.InferenceSession(
    'genre-classifier-quantized.onnx', providers=["CPUExecutionProvider"]
)

# Names of the graph's first input and first output. classify_movie_genre
# uses them to build the feed dict and select the returned tensor.
input_name = inf_session.get_inputs()[0].name
output_name = inf_session.get_outputs()[0].name

def classify_movie_genre(summary):
    """Predict independent genre probabilities for a movie summary.

    The summary is tokenized with the module-level ``tokenizer`` (padded to
    max length and truncated), run through the ONNX ``inf_session``, and each
    output logit is passed through a sigmoid. Every genre therefore gets its
    own probability; they are not normalized to sum to 1.

    Args:
        summary: Movie summary text from the Gradio textbox.

    Returns:
        A dict mapping each name in ``genres`` to its probability as a plain
        Python float, the format ``gr.Label`` consumes.

    Raises:
        ValueError: If the model produces a different number of scores than
            there are genre names (e.g. the JSON label file and the ONNX
            model are out of sync). Without this check ``zip`` would silently
            drop the unmatched entries.
    """
    logger = logging.getLogger(__name__)

    tokens = tokenizer(summary, padding='max_length', truncation=True, return_tensors="pt")
    # Keep the batch axis and cap the sequence axis at 512 tokens. Converting
    # the int64 torch tensor directly keeps the ONNX feed int64 on every
    # platform. A nested Python list would let numpy choose its platform-default
    # integer type (int32 on Windows with NumPy < 2).
    input_ids = tokens['input_ids'][:, :512].numpy()

    # Debug-level so raw user text and the full padded token list are not
    # written to stdout on every request unless logging is configured for it.
    logger.debug("Input summary: %s", summary)
    logger.debug("Tokenized input: %s", input_ids[0].tolist())

    # NOTE(review): only the first graph input (input_ids) is fed. If the
    # exported model also accepts an optional attention_mask, padding tokens
    # are not masked — confirm against the export script.
    logits = inf_session.run([output_name], {input_name: input_ids})[0]

    logits = torch.as_tensor(logits, dtype=torch.float32)
    # Multi-label: sigmoid per logit, then drop the batch dimension.
    probs = torch.sigmoid(logits)[0]

    logger.debug("Logits: %s", logits)
    logger.debug("Probabilities: %s", probs)

    if len(probs) != len(genres):
        raise ValueError(
            f"model returned {len(probs)} scores but "
            f"{len(genres)} genre names are configured"
        )
    return dict(zip(genres, map(float, probs)))

# Output widget: display only the five most probable genres.
label = gr.Label(num_top_classes=5)

# Free-text summary in, genre-probability label out.
iface = gr.Interface(
    fn=classify_movie_genre,
    outputs=label,
    inputs="text",
)

# Start the app. inline=False presumably prevents the UI from being embedded
# inline (e.g. in a notebook cell) — it is served on its own page instead.
iface.launch(inline=False)