File size: 1,851 Bytes
6e14e06
71e82b9
 
bbd0842
71e82b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910b02d
71e82b9
 
 
a397132
910b02d
 
71e82b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910b02d
71e82b9
 
 
910b02d
71e82b9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import torch
import torchaudio
from torch import nn

# Generator network architecture (must match the checkpoint whose weights are loaded further down)
class Generator(nn.Module):
    """Fully connected generator that maps a latent vector to 8192 output values.

    Layer stack: Linear(latent_dim -> 1024) -> ReLU -> Linear(1024 -> 4096) -> ReLU
    -> Linear(4096 -> 8192) -> Tanh, so every output value lies in [-1, 1].

    NOTE: the sub-module must stay named ``generator`` and keep this exact layer
    order. The saved state dict's keys (``generator.0.weight``, ...) depend on both.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = [latent_dim, 1024, 4096, 8192]
        last_index = len(widths) - 2
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # Hidden layers use ReLU; the final layer squashes into [-1, 1] with Tanh.
            layers.append(nn.Tanh() if idx == last_index else nn.ReLU())
        self.generator = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (last dimension == latent_dim) through the layer stack."""
        out = self.generator(x)
        return out


# Hyper-parameters and runtime placement for the pretrained generator.
latent_dim = 100
generator_model_path = 'noisyKickGAN/generator_model.pkl'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network on the chosen device, then restore the trained weights.
# map_location remaps tensors that were saved on another device (e.g. a GPU
# checkpoint opened on a CPU-only host).
# NOTE(review): torch.load is pickle-based, so only point it at a trusted checkpoint.
generator = Generator(latent_dim).to(device)
generator.load_state_dict(
    torch.load(generator_model_path, map_location=device)
)


def generate_kick_drums(num_generated_samples: int = 3, sample_rate: int = 16000) -> tuple:
    """Generate kick-drum clips with the module-level generator and save them as WAVs.

    Each clip gets one latent vector drawn from a standard normal distribution.
    All of them go through ``generator`` in a single batched forward pass.

    Args:
        num_generated_samples: How many clips to produce. The default of 3 matches
            the three audio outputs of the Gradio interface.
        sample_rate: Sample rate passed to ``torchaudio.save``. The default is 16000.
            NOTE(review): this should match the rate the GAN was trained at -- confirm.

    Returns:
        A tuple of paths ``generated_kick_1.wav`` ... ``generated_kick_N.wav``,
        relative to the current working directory, one per clip.

    Raises:
        ValueError: If ``num_generated_samples`` is negative.
    """
    if num_generated_samples < 0:
        raise ValueError(
            f"num_generated_samples must be >= 0, got {num_generated_samples}"
        )

    generator.eval()
    with torch.no_grad():
        # One (num_generated_samples, latent_dim) batch instead of N single-row
        # passes. Creating the noise on `device` avoids a host-to-device copy.
        noise = torch.randn(num_generated_samples, latent_dim, device=device)
        samples = generator(noise).cpu()

    output_files = []
    for i, sample in enumerate(samples, start=1):
        output_filename = f"generated_kick_{i}.wav"
        # torchaudio.save expects (channels, frames); each row becomes a mono clip.
        torchaudio.save(output_filename, sample.unsqueeze(0), sample_rate)
        output_files.append(output_filename)

    return tuple(output_files)

# Define Gradio interface
def gradio_interface():
    """Build the Gradio UI around ``generate_kick_drums`` and launch it.

    The interface takes no inputs and shows three audio players, one for each
    file path returned by ``generate_kick_drums``. ``launch(debug=True)`` blocks
    the calling thread while the server runs.

    NOTE(review): with no inputs, check what ``live=True`` triggers in the installed
    Gradio version. It may run generation on every page load.
    """
    num_outputs = 3  # must equal the number of paths generate_kick_drums returns
    # Labels are 1-based so they line up with the saved file names
    # (generated_kick_1.wav ... generated_kick_3.wav).
    audio_outputs = [
        gr.Audio(type='filepath', label=f"generated_kick_{i}")
        for i in range(1, num_outputs + 1)
    ]
    demo = gr.Interface(
        fn=generate_kick_drums,
        inputs=None,
        outputs=audio_outputs,
        live=True,
    )
    demo.launch(debug=True)

# Run the Gradio interface only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start a blocking server.
if __name__ == "__main__":
    gradio_interface()