File size: 1,067 Bytes
8af7698
 
 
 
 
6a79179
8af7698
 
 
 
 
 
420c089
 
8af7698
 
 
 
 
 
 
 
 
 
 
420c089
 
 
8af7698
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import numpy as np
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face Hub checkpoint used for classification.
repo_name = 'juliensimon/autonlp-song-lyrics-18753417'

# Tokenizer and classifier come from the same checkpoint so the token ids the
# tokenizer produces match the vocabulary the model was trained on. Both are
# fetched at import time (from_pretrained may hit the network or a local cache).
tokenizer, model = (
    AutoTokenizer.from_pretrained(repo_name),
    AutoModelForSequenceClassification.from_pretrained(repo_name),
)

# Class-index -> label-name mapping; predict() indexes it with argsort positions.
labels = model.config.id2label
print(labels)

def predict(lyrics, top_k=3):
    """Classify *lyrics* and describe the most likely labels as percentages.

    Relies on the module-level ``tokenizer``, ``model`` and ``labels``.

    Args:
        lyrics: Raw lyrics text (the Gradio textbox value).
        top_k: Number of labels to mention, most likely first. Clamped to the
            range ``[1, number of labels]`` so models with fewer than three
            classes no longer raise ``IndexError``.

    Returns:
        A sentence such as ``"These lyrics are 71.23% A, 20.00% B and 8.77% C."``
        (the default ``top_k=3`` yields exactly the original wording).
    """
    inputs = tokenizer(lyrics, padding=True, truncation=True, return_tensors="pt")
    # Pure inference: don't build an autograd graph for every request.
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # A single text was tokenized, so keep row 0 and scale to percentages.
    predictions = probabilities.numpy()[0] * 100
    print(predictions)

    top_k = max(1, min(top_k, len(predictions)))
    # argsort is ascending; reversing yields the same order as the original
    # sorted_indexes[-1], [-2], [-3] lookups.
    top_indexes = np.argsort(predictions)[::-1][:top_k]
    parts = ["{:.2f}% {}".format(predictions[i], labels[i]) for i in top_indexes]
    if len(parts) == 1:
        summary = parts[0]
    else:
        summary = ", ".join(parts[:-1]) + " and " + parts[-1]
    return "These lyrics are {}.".format(summary)

# Bound to a descriptive name instead of ``input`` so the builtin ``input()``
# is not shadowed at module level.
# ``gr.inputs`` was deprecated in Gradio 3 and removed in Gradio 4: use the
# top-level component when available, falling back for older Gradio releases.
lyrics_input = (getattr(gr, "Textbox", None) or gr.inputs.Textbox)(lines=20)

iface = gr.Interface(fn=predict, inputs=lyrics_input, outputs="text")
# Starts the web server as soon as this module is executed.
iface.launch()