# Music-Generator / app.py
# Source-page header (Hugging Face file viewer): uploaded by nullHawk,
# commit "v0" (e8ca4ee, verified), file size 1.64 kB.
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from model import MusicLSTM
from train import DataLoader, Config, generate_song as generate_ABC_notation
from utils import load_vocab
from convert import abc_to_audio
class GradioApp:
    """Gradio front end that generates a song with a trained ``MusicLSTM``.

    Construction builds the configuration and data loader, restores the model
    weights from ``checkpoint/model.pth`` and creates a ``gr.Interface`` whose
    callback is :meth:`generate_music`.
    """

    # Literal marker strings removed from the generated text before it is
    # converted to audio.
    START_TOKEN = "<start>"
    END_TOKEN = "<end>"

    def __init__(self):
        """Load configuration, data, model weights and build the interface."""
        # Set up configuration and data
        self.config = Config()
        self.CHECKPOINT_FILE = "checkpoint/model.pth"
        self.data_loader = DataLoader(self.config.INPUT_FILE, self.config)

        # map_location="cpu" lets a checkpoint that was saved on a GPU be
        # restored on a CPU-only host; the model below is created on the CPU.
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from a trusted source. The loaded object is
        # passed straight to load_state_dict, so weights_only=True may be
        # sufficient -- TODO confirm against the training script.
        self.checkpoint = torch.load(
            self.CHECKPOINT_FILE,
            map_location="cpu",
            weights_only=False,
        )

        # Only the first value returned by load_vocab() is needed here; its
        # length is used for both the input and the output size of the model.
        char_idx, _ = load_vocab()
        vocab_size = len(char_idx)
        self.model = MusicLSTM(
            input_size=vocab_size,
            hidden_size=self.config.HIDDEN_SIZE,
            output_size=vocab_size,
        )
        self.model.load_state_dict(self.checkpoint)
        self.model.eval()

        # Set up the interface
        self.input = gr.Button("")
        self.output = gr.Audio(label="Generated Music")
        self.interface = gr.Interface(
            fn=self.generate_music,
            inputs=self.input,
            outputs=self.output,
            title="AI Music Generator",
            description="Generate a new song using a trained RNN model.",
        )

    def launch(self):
        """Start the Gradio interface."""
        self.interface.launch()

    @staticmethod
    def _strip_markers(abc_notation):
        """Return *abc_notation* without the literal start/end markers.

        ``str.strip(chars)`` treats its argument as a set of characters, so
        ``strip("<start>")`` would also eat leading/trailing notes such as
        ``a``, ``d`` or ``e``. Here the markers are removed only when they
        appear as an exact prefix/suffix, and surrounding whitespace is
        trimmed.
        """
        text = abc_notation.strip()
        if text.startswith(GradioApp.START_TOKEN):
            text = text[len(GradioApp.START_TOKEN):]
        if text.endswith(GradioApp.END_TOKEN):
            text = text[:-len(GradioApp.END_TOKEN)]
        return text.strip()

    def generate_music(self, input):
        """Generate a new song using the trained model.

        Args:
            input: Value supplied by the Gradio input component; it is not
                used. The name shadows the builtin but is kept so the
                method's signature stays unchanged.

        Returns:
            The result of ``abc_to_audio`` for the cleaned ABC notation,
            which the interface routes to its audio output component.
        """
        abc_notation = generate_ABC_notation(self.model, self.data_loader)
        abc_notation = self._strip_markers(abc_notation)
        return abc_to_audio(abc_notation)
if __name__ == "__main__":
    # Script entry point: build the app (which restores the model checkpoint)
    # and then start serving the Gradio interface.
    app = GradioApp()
    app.launch()