bharat-raghunathan's picture
Revert back to v0.91
32441b9
raw history blame
No virus
1.05 kB
import functools

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
SPOTIFY_MODEL = "spotify/autonlp-huggingface-demo-song-lyrics-18923587"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the lyrics classifier and its tokenizer once, then reuse them.

    ``from_pretrained`` reads (and on first use fetches) the model files, which
    is far too slow to repeat for every request; ``lru_cache`` memoises the
    pair for the lifetime of the process.
    """
    model = AutoModelForSequenceClassification.from_pretrained(SPOTIFY_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(SPOTIFY_MODEL)
    return model, tokenizer


def lyrics_categories(input_text):
    """Classify song lyrics and return a label -> probability mapping.

    Args:
        input_text: The lyrics to classify, as entered in the Gradio textbox.

    Returns:
        dict mapping every label name in ``model.config.id2label`` to its
        softmax probability as a Python ``float``, inserted in descending
        order of probability.
    """
    model, tokenizer = _load_model_and_tokenizer()
    labels = model.config.id2label

    # truncation=True keeps long lyrics from exceeding the model's maximum
    # sequence length instead of failing inside the forward pass.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)

    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # A single string was tokenized, so take row 0 of the batch.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
    index_sorted = np.argsort(probabilities)[::-1]

    # Return floats, not strings: Gradio's Label output expects numeric
    # confidences, and stringified values compare lexicographically
    # (e.g. "9.5e-05" > "0.97"), which can surface a near-zero class as the top
    # prediction.
    clean_outputs = {
        labels[int(idx)]: float(probabilities[idx]) for idx in index_sorted
    }
    print(clean_outputs)
    return clean_outputs
# Web UI: a large free-text box for the lyrics goes in; a label widget showing
# up to five of the categories returned by lyrics_categories comes out.
iface = gr.Interface(
    fn=lyrics_categories,
    inputs=gr.inputs.Textbox(
        label="Song Lyrics",
        lines=20,
        placeholder="Enter song lyrics here...",
    ),
    outputs=gr.outputs.Label(
        label="Lyrics Categories",
        num_top_classes=5,
    ),
)

# Start the Gradio server for the interface.
iface.launch()