waddaheaven's picture
Update app.py
50c8092 verified
raw
history blame contribute delete
No virus
1.16 kB
import gradio as gr
import onnxruntime as rt
from transformers import AutoTokenizer
import torch
import json
# Tokenizer for the "distilroberta-base" checkpoint. Created once at import
# time so every request reuses the same instance.
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")

# Genre lookup table. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale default encoding.
with open("genre_types_encoded.json", "r", encoding="utf-8") as fp:
    encode_genre_types = json.load(fp)

# Only the keys (genre names) are used, in the order they appear in the file.
# NOTE(review): assumes this key order matches the order of the model's
# output columns -- confirm against the code that produced the JSON.
genres = list(encode_genre_types)

# ONNX Runtime session for the classifier, plus the names of its first input
# and first output, which are the only tensors this app feeds and fetches.
inf_session = rt.InferenceSession('genre-classifier-quantized.onnx')
input_name = inf_session.get_inputs()[0].name
output_name = inf_session.get_outputs()[0].name
def classify_movie_genre(summary):
    """Score a movie summary against every known genre.

    The summary is tokenized (padded to the tokenizer's maximum length and
    truncated if longer), the token ids are fed to the ONNX session as a
    batch of one, and a sigmoid is applied to the returned logits.

    Args:
        summary: Free-text plot summary to classify.

    Returns:
        A dict mapping each name in ``genres`` to a float produced by the
        sigmoid, paired positionally with the model's outputs.
    """
    encoded = tokenizer(
        summary,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )
    # First (only) row of the batch, capped at 512 token ids.
    first_row = encoded['input_ids'][0]
    input_ids = first_row.tolist()[:512]

    print("Input summary:", summary)
    print("Tokenized input:", input_ids)

    # Request a single output and feed a batch containing one sequence.
    feed = {input_name: [input_ids]}
    fetched = inf_session.run([output_name], feed)
    logits = torch.FloatTensor(fetched[0])

    # Sigmoid is applied element-wise, so each genre gets its own score;
    # index 0 selects the single item in the batch.
    probs = torch.sigmoid(logits)[0]

    print("Logits:", logits)
    print("Probabilities:", probs)

    return {genre: float(prob) for genre, prob in zip(genres, probs)}
# Output widget configured to show the five top classes.
label = gr.Label(num_top_classes=5)

# Plain text box in, label widget out, with the classifier as the callback.
iface = gr.Interface(
    fn=classify_movie_genre,
    inputs="text",
    outputs=label,
)

# Start serving the interface.
iface.launch(inline=False)