|
|
import gradio as gr |
|
|
from transformers import BertForSequenceClassification, BertTokenizer |
|
|
import torch |
|
|
|
|
|
|
|
|
# Identifier passed to ``from_pretrained`` for both the classifier weights and
# the tokenizer vocabulary (a model repo id or a local path).
MODEL_NAME = "young476/LyricToGenre0607"

# Load the classifier first, then the tokenizer, once at import time; both are
# module-level globals used by ``predict_genre`` on every request.
model, tokenizer = (
    loader.from_pretrained(MODEL_NAME)
    for loader in (BertForSequenceClassification, BertTokenizer)
)
|
|
|
|
|
|
|
|
# Human-readable genre names, looked up by the argmax logit index in
# ``predict_genre`` (entry i names logit i).
# NOTE(review): this order presumably has to match the label ids used when the
# checkpoint was fine-tuned -- confirm against the model's training setup.
genre_labels = [
    "발라드",
    "댄스",
    "힙합",
    "록",
    "트로트",
    "R&B",
]
|
|
|
|
|
def predict_genre(lyrics):
    """Predict a song's genre from its lyrics.

    Args:
        lyrics: Lyrics text from the Gradio textbox. ``None`` or a string
            containing only whitespace is treated as "no input".

    Returns:
        A ``(pred_label, prob_dict)`` pair:

        * ``pred_label`` -- the entry of ``genre_labels`` whose logit is highest.
        * ``prob_dict`` -- maps each name in ``genre_labels`` to its softmax
          probability as a Python ``float``.

        For empty input both values are ``None``. This clears the output
        components instead of showing a prediction computed from special
        tokens alone.
    """
    # Guard against blank submissions. Without it the tokenizer raises on
    # ``None``, and an empty string yields a meaningless "prediction".
    if lyrics is None or not lyrics.strip():
        return None, None

    # Inputs longer than 256 tokens are truncated, so only the beginning of
    # long lyrics is scored.
    inputs = tokenizer(lyrics, return_tensors="pt", truncation=True, max_length=256)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single string is encoded as a batch of one, so row 0 holds this
    # example's scores.
    scores = logits[0]
    pred_label = genre_labels[int(scores.argmax())]

    probs = torch.softmax(scores, dim=-1).tolist()
    prob_dict = {label: float(p) for label, p in zip(genre_labels, probs)}

    return pred_label, prob_dict
|
|
|
|
|
# Build the UI components up front, in the same order as before, then wire them
# to ``predict_genre``: one multi-line lyrics box in, two label panels out.
_lyrics_input = gr.Textbox(lines=8, label="가사 입력")
_genre_outputs = [
    gr.Label(num_top_classes=1, label="예측 장르"),  # receives pred_label
    gr.Label(label="장르별 확률"),  # receives the per-genre probability dict
]

demo = gr.Interface(
    fn=predict_genre,
    inputs=_lyrics_input,
    outputs=_genre_outputs,
    title="가사 기반 장르 분류기",
    description="한국 노래 가사를 입력하면 장르를 예측합니다.",
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed directly, not when
    # it is imported (e.g. to reuse ``demo`` or ``predict_genre``).
    demo.launch()
|
|
|