Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,658 Bytes
caabda6 59a2caa caabda6 500319a 5e340e8 500319a f084bcc 59a2caa 0a08063 b1e308f 6dfde8b d19a04f 6dfde8b 59a2caa 6dfde8b 3f93e88 6dfde8b 3f93e88 6dfde8b 3f93e88 6dfde8b 59a2caa 6dfde8b 59a2caa 3cf4680 6dfde8b 3cf4680 0a08063 f084bcc 28f8000 f084bcc 0a08063 59a2caa 3cf4680 5e340e8 59a2caa 5e340e8 3cf4680 8a80eb5 3cf4680 b46aa4b d19a04f 3cf4680 f084bcc 3664313 f084bcc 3cf4680 288c3b4 3cf4680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
# imports from gradio_demo.py
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np
from time import sleep
from subprocess import call
import pandas as pd
# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K
# test natten import:
import natten
import accelerate
from sample import zero_wrapper
# Import-time CUDA probe: place a one-element tensor on the GPU and print the device it reports.
# process_image prints zero.device again from inside a @spaces.GPU function for comparison.
# NOTE(review): the message expects 'cpu' here, presumably because under Hugging Face ZeroGPU
# a real CUDA device is only attached inside @spaces.GPU-decorated calls -- confirm.
zero = torch.Tensor([0]).cuda()
print("Zero Device = ",zero.device," <-- this probably says cpu") # <-- 'cpu' 🤔
from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect
# Base directory for relative project paths.
# NOTE(review): process_image assigns its own local CT_HOME with the same value, so this
# module-level copy may not be read anywhere -- confirm before removing.
CT_HOME = '.'
@spaces.GPU
def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    The original note on this function says it works whether the image is
    normalized on 0..1 or -1..1, but not 0..255.
    NOTE(review): the 'grey' rule selects every pixel whose three channels are
    equal and non-zero, so with -1..1 data pure black (-1, -1, -1) would also
    match -- confirm the intended input range.

    Args:
        img: A tensor, or anything ``ToTensor`` accepts. Tensors may be
            (C, H, W) with C == 1 or C >= 3 (only the first three channels
            are read), or (H, W), which is treated as three identical channels.
        mask_with: Which marker color denotes the mask:
            'white' -- all three channels equal 1;
            'blue'  -- third channel equals 1;
            'grey'  -- first channel non-zero and all three channels equal.

    Returns:
        A float tensor of shape (H, W): 1.0 where the marker matched, else 0.0.

    Raises:
        AssertionError: if ``mask_with`` is not one of the supported names.
        ValueError: if ``img`` is not 2-D or 3-D, or has exactly two channels.
    """
    assert mask_with in ['blue','white','grey'], f"unsupported mask_with: {mask_with!r}"
    print("\n in infer_mask_from_init_img: ")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print(" img.shape: ", img.shape)

    # Normalize to a (>=3, H, W) view so the channel tests below are always valid.
    # Previously a 2-D image crashed on channel indexing and any other rank left
    # `mask` unbound.
    if img.dim() == 2:
        channels = img.unsqueeze(0).expand(3, -1, -1)
    elif img.dim() == 3 and img.shape[0] == 1:
        channels = img.expand(3, -1, -1)
    elif img.dim() == 3 and img.shape[0] >= 3:
        channels = img
    else:
        raise ValueError(
            f"expected an (H, W) or (C, H, W) image with C == 1 or C >= 3, got shape {tuple(img.shape)}"
        )

    # shape of mask should be img shape without the channel dimension
    mask = torch.zeros(channels.shape[-2:])
    print(" mask.shape: ", mask.shape)

    red, green, blue = channels[0], channels[1], channels[2]
    if mask_with == 'white':
        mask[(red == 1) & (green == 1) & (blue == 1)] = 1
    elif mask_with == 'blue':
        mask[blue == 1] = 1
    elif mask_with == 'grey':
        mask[(red != 0) & (red == green) & (blue == green)] = 1
    return mask*1.0
@spaces.GPU
def count_notes_in_mask(img, mask):
    """Count the new notes that fall inside the mask.

    A pixel counts when its green channel (index 1) is greater than zero; each
    such pixel contributes its ``mask`` value to the total.

    Args:
        img: A (C, H, W) tensor, or anything ``ToTensor`` accepts. Tensors are
            used as-is, matching what ``infer_mask_from_init_img`` accepts;
            previously a tensor argument made ``ToTensor`` raise.
        mask: Tensor broadcastable against the (H, W) green channel.

    Returns:
        The masked sum as a Python number.
    """
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    green_on = img_t[1,:,:] > 0  # green channel
    new_notes = (mask * green_on).sum()
    return new_notes.item()
@spaces.GPU
def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,  # after ranking images by how many notes were in mask, which one should we grab?
                   ):
    """Pick one generated image file by how many notes it has in the masked region.

    Reads ``{PREFIX}_00000.png`` .. ``{PREFIX}_{num_to_gen-1:05d}.png``, scores
    each with ``count_notes_in_mask`` against the 'grey' mask inferred from
    ``init_img``, ranks them in ascending order, and returns the file at the
    requested percentile.

    Args:
        init_img: Image whose grey-marked area defines the mask.
        PREFIX: Filename prefix of the generated images.
        num_to_gen: How many generated files to score; must be at least 1.
        busyness: Percentile into the ascending ranking. 100 selects the file
            with the most notes, 0 the fewest. Values outside 0..100 are
            clamped; floats are accepted.

    Returns:
        The filename (str) of the selected image.

    Raises:
        ValueError: if ``num_to_gen`` is less than 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be >= 1, got {num_to_gen}")

    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    # Plain list of (score, filename) pairs instead of growing a DataFrame with
    # pd.concat on every iteration. The rectangular copy of each image that the
    # old table stored was never read, so it is no longer computed.
    scored = []
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # Context manager closes the file handle once the pixels have been read.
        with Image.open(filename) as gen_img:
            new_notes = count_notes_in_mask(gen_img, mask)
        scored.append((new_notes, filename))

    # Stable ascending sort: ties keep generation order, so the pick is deterministic.
    scored.sort(key=lambda pair: pair[0])

    # Clamp so an out-of-range percentile cannot index past the end or wrap around,
    # and force an int because slider values may arrive as floats.
    percentile = min(100, max(0, busyness))
    grab_index = int((len(scored)-1)*percentile//100)
    print("grab_index = ", grab_index)
    dense_filename = scored[grab_index][1]
    print("Grabbing filename = ", dense_filename)
    return dense_filename
# Minimal attribute bag used in place of parsed command-line arguments.
class Args:
    """Args-like object: every keyword passed to the constructor becomes an attribute."""

    def __init__(self, **kwargs):
        # Equivalent to calling setattr for each key/value pair.
        vars(self).update(kwargs)

    def __repr__(self):
        pairs = [f"{name}={val}" for name, val in vars(self).items()]
        return "Args(" + ", ".join(pairs) + ")"

    def __str__(self):
        pairs = [f"{name}={val}" for name, val in vars(self).items()]
        return "Args with attributes: " + ", ".join(pairs)
@spaces.GPU
def process_image(image, repaint, busyness):
    """Gradio handler: run the sampler on the edited piano-roll image and return the results.

    Steps: take the editor's composite image, crop it to 512x128 and fold it
    with ``rect_to_square``, save it as the sampler's init image, generate 64
    candidates via ``zero_wrapper``, pick one with ``grab_dense_gen`` according
    to ``busyness``, then convert the pick to MIDI, an embeddable MIDI player,
    and an audio file.

    Args:
        image: Editor value; a dict whose 'composite' entry is a PIL image or a
            numpy array.
        repaint: RePaint setting forwarded to the sampler arguments.
        busyness: Percentile (by notes generated in the masked region) used to
            choose among the candidates. This value was previously ignored, so
            the slider had no effect.

    Returns:
        Tuple of (generated image in rectangular layout, HTML string embedding
        the MIDI player in an iframe, path to the rendered audio file).
    """
    print("Process Image: Zero Device = ",zero.device)
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    print('Accelerator device:', device, flush=True)

    # get image ready and execute sampler
    print("image = ",image)
    image = image['composite']
    # if image is a numpy array convert to PIL
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    image = rect_to_square( image )
    # TODO: fixed filename is shared by all requests; concurrent calls would clobber it
    masked_img_file = 'gradio_masked_image.png'
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    num = 64  # number of images to generate; one is chosen by notes in the masked region
    CT_HOME = '.'
    CKPT = 'ckpt/256_chords_00130000.pth'
    PREFIX = 'gradiodemo'
    print("Reading init image from ", masked_img_file,", repaint = ",repaint)

    # The sampler is invoked in-process with an args-like object rather than by
    # shelling out to sample.py.
    args = Args(batch_size=num, checkpoint=CKPT, config=f'{CT_HOME}/configs/config_pop909_256x256_chords.json', n=num, prefix=PREFIX, init_image=masked_img_file, steps=100, seed_scale=0.0, repaint=repaint)
    print(" Now calling zero_wrapper with args = ",args,"\n")
    zero_wrapper(args, accelerator, device)

    # find gen'd image and convert to midi piano roll
    # int(): slider values may arrive as floats; the index math downstream needs an int.
    gen_file = grab_dense_gen(image, PREFIX, num_to_gen=num, busyness=int(busyness))
    gen_image = square_to_rect(Image.open(gen_file))
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # Double quotes would terminate the srcdoc="..." attribute below.
    srcdoc = srcdoc.replace("\"", "'")
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    # convert the midi to audio too
    # timidity's -Ow mode writes RIFF WAVE data, so the file carries a .wav
    # extension (it was previously mislabeled .mp3).
    audio_file = 'gradio_demo_out.wav'
    # Argument list (not a split string) so paths containing spaces stay intact.
    cmd = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(cmd))
    return_value = call(cmd)
    print("Return value = ", return_value)
    return gen_image, html, audio_file
def make_dict(x):
    """Wrap one image reference as an editor value: same item as background, composite, and sole layer."""
    return dict(background=x, composite=x, layers=[x])
# Gradio UI: an image editor plus two sliders feed process_image, which returns the
# generated piano-roll image, an HTML MIDI player, and rendered audio.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.ImageEditor(sources=["upload",'clipboard'], label="Input Piano Roll Image (White = Gen Notes Here)", value=make_dict('all_black.png'), brush=gr.Brush(colors=["#FFFFFF","#000000"])),
        gr.Slider(minimum=1, maximum=10, step=1, value=2, label="RePaint (Larger = More Notes, But Crazier. Also Slower.)"),
        gr.Slider(minimum=1, maximum=100, step=1, value=100, label="Busy-ness Percentile (Based on Notes Generated)"),
    ],
    outputs=[
        gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
        gr.HTML(label="MIDI Player"),
        gr.Audio(label="MIDI as Audio"),
    ],
    # Each example row is [editor value, RePaint setting, busy-ness percentile].
    examples=[[make_dict(y),1,100] for y in ['all_white.png','all_black.png','init_img_melody.png','init_img_accomp.png','init_img_cont.png',]]+
             [[make_dict(x),2,100] for x in ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png', '780_TOTAL_crop_draw.png','loop_middle_2.png']]+
             [[make_dict(z),3,100] for z in ['584_TOTAL_crop_draw.png','loop_middle.png']] +
             [[make_dict('ismir_mask_2.png'),6,100]],
)
# The stray trailing `|` that followed this call left an incomplete expression and is removed.
demo.queue().launch()