#======================================================================================= # https://huggingface.co/spaces/asigalov61/Imagen-POP-Music-Medley-Diffusion-Transformer #======================================================================================= import os import time as reqtime import datetime from pytz import timezone import torch from imagen_pytorch import Unet, Imagen, ImagenTrainer from imagen_pytorch.data import Dataset import spaces import gradio as gr import numpy as np import random import tqdm import TMIDIX import TPLOTS from midi_to_colab_audio import midi_to_colab_audio # ================================================================================================= @spaces.GPU def Generate_POP_Medley(input_num_medley_comps): print('=' * 70) print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) start_time = reqtime.time() print('=' * 70) print('Loading model...') DIM = 64 CHANS = 1 TSTEPS = 1000 DEVICE = 'cuda' # 'cpu' unet = Unet( dim = DIM, dim_mults = (1, 2, 4, 8), num_resnet_blocks = 1, channels=CHANS, layer_attns = (False, False, False, True), layer_cross_attns = False ) imagen = Imagen( condition_on_text = False, # this must be set to False for unconditional Imagen unets = unet, channels=CHANS, image_sizes = 128, timesteps = TSTEPS ) trainer = ImagenTrainer( imagen = imagen, split_valid_from_train = True # whether to split the validation dataset from the training ).to(DEVICE) print('=' * 70) print('Loading model checkpoint...') print('=' * 70) trainer.load('Imagen_POP909_64_dim_12638_steps_0.00983_loss.ckpt') print('=' * 70) print('Done!') print('=' * 70) print('Req number of medley compositions:', input_num_medley_comps) print('=' * 70) print('Generating...') images = trainer.sample(batch_size = input_num_medley_comps, return_pil_images = True) threshold = 128 imgs_array = [] for i in images: arr = np.array(i) farr = np.where(arr < threshold, 0, 1) imgs_array.append(farr) print('Done!') print('=' * 70) #=============================================================================== print('Converting images to scores...') medley_compositions_escores = [] for i in imgs_array: bmatrix = TPLOTS.images_to_binary_matrix([i]) score = TMIDIX.binary_matrix_to_original_escore_notes(bmatrix) medley_compositions_escores.append(score) print('Done!') print('=' * 70) print('Creating medley score...') medley_labels = ['Composition #' + str(i+1) for i in range(len(medley_compositions_escores))] medley_escore = TMIDIX.escore_notes_medley(medley_compositions_escores, medley_labels, pause_time_value=16) #=============================================================================== print('Rendering results...') print('=' * 70) print('Sample INTs', medley_escore[:15]) print('=' * 70) fn1 = "Imagen-POP-Music-Medley-Diffusion-Transformer-Composition" output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(medley_escore) detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score, output_signature = 'Imagen POP Music Medley', output_file_name = fn1, track_name='Project Los Angeles', list_of_MIDI_patches=patches, timings_multiplier=256 ) new_fn = fn1+'.mid' audio = midi_to_colab_audio(new_fn, soundfont_path=soundfont, sample_rate=16000, volume_scale=10, output_for_gradio=True ) print('Done!') print('=' * 70) #======================================================== output_midi_title = str(fn1) output_midi_summary = str(output_score[:3]) output_midi = str(new_fn) output_audio = (16000, audio) output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi, return_plt=True, timings_multiplier=256) print('Output MIDI file name:', output_midi) print('Output MIDI title:', output_midi_title) print('Output MIDI summary:', output_midi_summary) print('=' * 70) #======================================================== print('-' * 70) print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) print('-' * 70) print('Req execution time:', (reqtime.time() - start_time), 'sec') return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot # ================================================================================================= if __name__ == "__main__": PDT = timezone('US/Pacific') print('=' * 70) print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) print('=' * 70) soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2" app = gr.Blocks() with app: gr.Markdown("