import logging
from functools import lru_cache

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, AutoModel

# Set up logging to debug startup issues
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
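
# NOTE: the relative paths below assume the umt5-base and phantomstep_* model
# folders have been downloaded next to this script (e.g. with git lfs or
# huggingface_hub's snapshot_download); adjust the paths if they live elsewhere.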

try:
    # Load text tokenizer and embedding model (umt5-base); lru_cache keeps the
    # loaded objects around so repeated generations reuse them instead of
    # reloading the weights from disk on every click
    @lru_cache(maxsize=1)
    def load_text_processor():
        logger.info("Loading text processor (umt5-base)...")
        tokenizer = AutoTokenizer.from_pretrained("./umt5-base")
        text_model = AutoModel.from_pretrained(
            "./umt5-base",
            use_safetensors=True,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        logger.info("Text processor loaded successfully.")
        return tokenizer, text_model

    # Load the transformer backbone (phantomstep_transformer)
    @lru_cache(maxsize=1)
    def load_transformer():
        logger.info("Loading transformer (phantomstep_transformer)...")
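        # NOTE: device_map handling differs between transformers and diffusers;
        # recent diffusers releases only accept device_map="balanced" for
        # pipelines, so if "auto" raises here, switch to "balanced" or drop the
        # argument and move the pipeline with .to("cuda") instead.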
        transformer = DiffusionPipeline.from_pretrained(
            "./phantomstep_transformer",
            use_safetensors=True,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        logger.info("Transformer loaded successfully.")
        return transformer

    # Load the DCAE for audio encoding/decoding (phantomstep_dcae)
    @lru_cache(maxsize=1)
    def load_dcae():
        logger.info("Loading DCAE (phantomstep_dcae)...")
        dcae = DiffusionPipeline.from_pretrained(
            "./phantomstep_dcae",
            use_safetensors=True,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        logger.info("DCAE loaded successfully.")
        return dcae

    # Load the vocoder for audio synthesis (phantomstep_vocoder)
    @lru_cache(maxsize=1)
    def load_vocoder():
        logger.info("Loading vocoder (phantomstep_vocoder)...")
        vocoder = DiffusionPipeline.from_pretrained(
            "./phantomstep_vocoder",
            use_safetensors=True,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        logger.info("Vocoder loaded successfully.")
        return vocoder

    # Generate music from a text prompt
    def generate_music(prompt, duration=20, seed=42):
        logger.info(f"Generating music with prompt: {prompt!r}, duration: {duration}s, seed: {seed}")
        torch.manual_seed(int(seed))  # seeds CPU and CUDA RNGs for reproducible sampling
        
        # Load all components (cached via lru_cache, so only the first call pays the load cost)
        tokenizer, text_model = load_text_processor()
        transformer = load_transformer()
        dcae = load_dcae()
        vocoder = load_vocoder()
        
        # Step 1: Process text prompt to embeddings
        logger.info("Processing text prompt to embeddings...")
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(text_model.device) for k, v in inputs.items()}
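        # Mean-pooling the encoder's last hidden state gives one fixed-size
        # embedding vector per prompt (a simple, common pooling choice)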
        with torch.no_grad():
            embeddings = text_model(**inputs).last_hidden_state.mean(dim=1)
        
        # Step 2: Pass embeddings through transformer
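        # NOTE: the call pattern used for the three pipelines below
        # (conditioning input plus num_inference_steps / audio_length_in_s,
        # reading .audios[0] from the result) mirrors diffusers audio pipelines
        # such as AudioLDM; the custom phantomstep pipelines are assumed to
        # expose the same interface.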
        logger.info("Generating with transformer...")
        transformer_output = transformer(
            embeddings,
            num_inference_steps=50,
            audio_length_in_s=duration
        ).audios[0]
        
        # Step 3: Decode audio features with DCAE
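        # (the DCAE, presumably a deep-compression autoencoder, maps the
        # transformer's latent output into features the vocoder can consume)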
        logger.info("Decoding with DCAE...")
        dcae_output = dcae(
            transformer_output,
            num_inference_steps=50,
            audio_length_in_s=duration
        ).audios[0]
        
        # Step 4: Synthesize final audio with vocoder
        logger.info("Synthesizing with vocoder...")
        audio = vocoder(
            dcae_output,
            num_inference_steps=50,
            audio_length_in_s=duration
        ).audios[0]
        
        # Save audio to a WAV file; soundfile expects float32/float64, so cast
        # in case the pipeline returned float16
        output_path = "output.wav"
        audio = np.asarray(audio, dtype=np.float32)
        sf.write(output_path, audio, 22050)  # 22,050 Hz sample rate
        logger.info("Music generation complete.")
        return output_path
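
    # The pipeline can also be exercised without the UI, e.g.:
    #   generate_music("A jazzy piano melody with a fast tempo", duration=15, seed=7)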

    # Gradio interface
    logger.info("Setting up Gradio interface...")
    with gr.Blocks(title="PhantomStep: Text-to-Music Generation šŸŽµ") as demo:
        gr.Markdown("# PhantomStep by GhostAI šŸš€")
        gr.Markdown("Enter a text prompt to generate music! šŸŽ¶")
        
        prompt_input = gr.Textbox(label="Text Prompt", placeholder="A jazzy piano melody with a fast tempo")
        duration_input = gr.Slider(label="Duration (seconds)", minimum=10, maximum=60, value=20, step=1)
        seed_input = gr.Number(label="Random Seed", value=42, precision=0)
        generate_button = gr.Button("Generate Music")
        
        audio_output = gr.Audio(label="Generated Music")
        
        generate_button.click(
            fn=generate_music,
            inputs=[prompt_input, duration_input, seed_input],
            outputs=audio_output
        )

    logger.info("Launching Gradio app...")
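    # 0.0.0.0 exposes the app on all network interfaces; 7860 is Gradio's default port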
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)

except Exception as e:
    logger.exception(f"Failed to start the application: {e}")
    raise