# thiruvanth's picture
# Update app.py
# 3853379 verified
def get_predictions(input_text: str) -> dict:
    """Score *input_text* against every model label, highest probability first.

    Uses the globals ``model``, ``tokenizer`` and ``device`` (defined elsewhere
    in this file). Each label gets an independent sigmoid probability
    (multi-label style), so the scores need not sum to 1.

    Args:
        input_text: Raw text to classify; truncated by the tokenizer if too long.

    Returns:
        Dict mapping label name -> probability (Python ``float``), ordered by
        descending probability.
    """
    inputs = tokenizer(input_text, return_tensors='pt', truncation=True).to(device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single text is tokenized, so take batch row 0. Unlike squeeze(), this
    # stays 1-D even when the model has only one label.
    # NOTE(review): assumes logits is (1, num_labels) - TODO confirm.
    probs = torch.sigmoid(logits[0]).cpu().numpy()

    # Map output index -> label via id2label. Build a NEW dict rather than
    # writing scores into model.config.label2id, which would corrupt the
    # shared model config.
    id2label = model.config.id2label
    scores = {id2label[i]: float(p) for i, p in enumerate(probs)}

    predictions = dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))
    print(predictions)
    return predictions
import gradio as gr

# Text-in / label-out UI around get_predictions; the Label output displays the
# three highest-scoring labels (num_top_classes=3). Bound to `demo` so the
# interface object can be referenced after construction.
demo = gr.Interface(
    fn=get_predictions,
    inputs=gr.components.Textbox(label='Input'),
    # NOTE(review): string theme names such as "darkdefault" are only
    # understood by some Gradio releases - confirm against the pinned version.
    theme="darkdefault",
    outputs=gr.components.Label(label='Predictions', num_top_classes=3),
    allow_flagging='never',
)
# `debug` expects a bool; a string would be truthy for any non-empty value.
demo.launch(debug=True)