Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch.nn.functional as F | |
import numpy as np | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
repo_name = "adrianmoses/autonlp-auto-nlp-lyrics-classification-19333717"


def _load_tokenizer_and_model(name):
    """Load the tokenizer and sequence-classification model for repo ``name``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSequenceClassification.from_pretrained(name),
    )


# Streamlit re-executes this script on every widget interaction; without
# caching, the tokenizer and model would be loaded again on each rerun.
# Older Streamlit releases have no st.cache_resource, so fall back to the
# legacy st.cache there (allow_output_mutation avoids hashing the model).
if hasattr(st, "cache_resource"):
    _load_tokenizer_and_model = st.cache_resource(_load_tokenizer_and_model)
else:
    _load_tokenizer_and_model = st.cache(allow_output_mutation=True)(
        _load_tokenizer_and_model
    )

tokenizer, model = _load_tokenizer_and_model(repo_name)
# Lookup from class index to label name; predict() indexes it by class id.
labels = model.config.id2label
def predict(lyrics, top_k=3):
    """Classify lyrics and describe the most likely labels as percentages.

    Args:
        lyrics: Text to classify; it is passed to the module-level
            ``tokenizer`` and ``model``.
        top_k: Maximum number of labels to report (default 3). The value is
            clamped to the range 1..number of labels the model outputs, so
            models with fewer than three labels no longer raise IndexError.

    Returns:
        A sentence such as
        "These lyrics are 61.20% A, 30.10% B and 8.70% C." listing labels in
        descending score order.
    """
    import torch  # local import: the file only imports torch.nn.functional

    inputs = tokenizer(lyrics, padding=True, truncation=True, return_tensors="pt")
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the last (label) axis; [0] takes the row of the first input.
    percentages = F.softmax(logits, dim=-1).numpy()[0] * 100

    # argsort is ascending, so reverse it and keep the best-scoring entries.
    count = max(1, min(top_k, len(percentages)))
    best = np.argsort(percentages)[::-1][:count]
    # int() turns numpy integers into plain ints before the label lookup.
    parts = ["{:.2f}% {}".format(percentages[i], labels[int(i)]) for i in best]

    if len(parts) == 1:
        summary = parts[0]
    else:
        summary = "{} and {}".format(", ".join(parts[:-1]), parts[-1])
    return "These lyrics are {}.".format(summary)
# Two-column layout: input controls on the left, result on the right.
col1, col2 = st.columns(2)
lyrics = col1.text_area("Lyrics")
clicked = col1.button("Submit")

output = ""
if clicked:
    if lyrics.strip():
        output = predict(lyrics)
    else:
        # Nothing to classify: avoid reporting percentages for empty text.
        output = "Please enter some lyrics first."
col2.write(output)