drscotthawley committed on
Commit
3cf4680
1 Parent(s): 6dfde8b

more code full gui

Browse files
Files changed (1) hide show
  1. app.py +97 -4
app.py CHANGED
@@ -59,10 +59,103 @@ def count_notes_in_mask(img, mask):
59
  return new_notes.item()
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- def greet(name):
64
- return "Hello " + name + "!!"
 
 
 
 
 
65
 
66
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
67
- demo.launch()
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  return new_notes.item()
60
 
61
 
62
def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,  # after ranking images by how many notes were in mask, which one should we grab?
                   ):
    """Choose one generated image file, ranked by note count inside the mask.

    Opens the files ``{PREFIX}_00000.png`` .. ``{PREFIX}_{num_to_gen-1:05d}.png``,
    scores each with ``count_notes_in_mask`` against the mask inferred from
    ``init_img``, sorts the scores in ascending order and returns the filename
    found at the ``busyness`` percentile of that ranking.

    Args:
        init_img: image handed to ``infer_mask_from_init_img`` (with
            ``mask_with='grey'``) to obtain the mask.
        PREFIX: filename prefix of the generated images.
        num_to_gen: how many generated files to inspect; must be >= 1.
        busyness: percentile (0-100) into the ascending ranking; 100 selects
            the last row (highest note count), 0 the first. Values outside
            that range are clamped.

    Returns:
        The filename of the selected generated image.

    Raises:
        ValueError: if ``num_to_gen`` is less than 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be at least 1, got {num_to_gen}")

    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    # Collect all rows first and build the DataFrame once; concatenating
    # inside the loop re-copies the whole frame on every iteration.
    rows = []
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        gen_img = Image.open(filename)
        gen_img_rect = square_to_rect(gen_img)
        new_notes = count_notes_in_mask(gen_img, mask)
        rows.append([filename, new_notes, gen_img_rect])
    df = pd.DataFrame(rows, columns=['filename', 'new_notes', 'img_rect'])

    # sort df by new_notes column,
    df = df.sort_values(by='new_notes', ascending=True)

    # Clamp so an out-of-range percentile cannot index past either end, and
    # force an int because a float percentile would make iloc fail.
    busyness = min(max(busyness, 0), 100)
    grab_index = int((len(df) - 1) * busyness // 100)
    print("grab_index = ", grab_index)
    dense_filename = df.iloc[grab_index]['filename']
    print("Grabbing filename = ", dense_filename)
    return dense_filename
86
 
 
 
87
 
88
+
89
def process_image(image, repaint, busyness):
    """Run the sampler on an edited piano-roll image and package the result.

    Steps: take the ``'composite'`` entry of ``image``, crop it to 512x128,
    convert it with ``rect_to_square`` and save it to disk; run ``sample.py``
    as a subprocess to generate a batch of images; pick one of them with
    ``grab_dense_gen``; convert that image to MIDI, to an embeddable HTML
    MIDI player, and to audio via ``timidity``.

    Args:
        image: mapping with a ``'composite'`` entry holding the image, either
            a numpy array or a PIL image (presumably the value produced by
            the Gradio image editor -- confirm against the caller).
        repaint: value forwarded to the sampler's ``--repaint`` option.
        busyness: percentile (0-100) forwarded to ``grab_dense_gen`` to choose
            which generated image to return.

    Returns:
        Tuple ``(gen_image, html, audio_file)``: the selected image after
        ``square_to_rect``, an ``<iframe>`` HTML string wrapping the MIDI
        player, and the path of the rendered audio file.
    """
    # get image ready and execute sampler
    print("image = ", image)
    image = image['composite']
    # if image is a numpy array convert to PIL
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    image = rect_to_square(image)
    #mask = infer_mask_from_init_img( image )
    masked_img_file = 'gradio_masked_image.png'  # TODO: could allow for clobber at scale
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    num = 64  # number of images to generate; we'll take the one with the most notes in the masked region
    bs = num
    # NOTE(review): DEVICES is defined but never applied to the subprocess
    # environment -- confirm whether the sampler should be pinned to a GPU.
    DEVICES = 'CUDA_VISIBLE_DEVICES=3'
    USER = 'shawley'
    RUN_HOME = f'/runs/{USER}/k-diffusion/pop909/full_chords'
    CKPT = f'{RUN_HOME}/256_chords_00130000.pth'
    PREFIX = 'gradiodemo'
    # !echo {DEVICES} {CT_HOME} {CKPT} {PREFIX} {masked_img_file}
    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)

    # Build the argv list directly so paths containing spaces stay intact;
    # the joined string is only used for logging.
    args = [
        '/home/shawley/envs/hs/bin/python', f'{CT_HOME}/sample.py',
        '--batch-size', f'{bs}',
        '--checkpoint', CKPT,
        '--config', f'{CT_HOME}/configs/config_pop909_256x256_chords.json',
        '-n', f'{num}',
        '--prefix', PREFIX,
        '--init-image', masked_img_file,
        '--steps=100',
        f'--repaint={repaint}',
    ]
    cmd = ' '.join(args)
    print("Will run command: ", cmd)
    #call(cmd, shell=True)
    print("Calling: ", args)
    return_value = call(args)
    print("Return value = ", return_value)
    if return_value != 0:
        # Do not abort: files from an earlier run may still be on disk, but
        # make it visible that what follows may be stale.
        print("WARNING: sampler exited with non-zero status; generated images may be missing or stale")

    # find gen'd image and convert to midi piano roll
    #gen_file = f'{PREFIX}_00000.png'
    # Forward busyness so the percentile slider actually affects the choice.
    gen_file = grab_dense_gen(image, PREFIX, num_to_gen=num, busyness=busyness)
    gen_image = square_to_rect(Image.open(gen_file))
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # The player markup is embedded inside a double-quoted attribute below,
    # so its own double quotes are swapped for single quotes.
    srcdoc = srcdoc.replace("\"", "'")
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    # convert the midi to audio too
    # NOTE(review): timidity's -Ow mode writes RIFF WAVE data although the
    # output name ends in .mp3 -- confirm this is intentional.
    audio_file = 'gradio_demo_out.mp3'
    audio_args = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    cmd = ' '.join(audio_args)
    print("Converting midi to audio with: ", cmd)
    return_value = call(audio_args)
    print("Return value = ", return_value)
    if return_value != 0:
        print("WARNING: timidity exited with non-zero status; audio file may be missing or stale")

    return gen_image, html, audio_file
140
+
141
+ # def greet(name):
142
+ # return "Hello " + name + "!!"
143
+
144
+ # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
145
+ # demo.launch()
146
+
147
+
148
+
149
# --- Gradio UI wiring -------------------------------------------------------
# Inputs: an image editor for the piano-roll mask plus two sliders that map to
# process_image's `repaint` and `busyness` arguments.
_piano_roll_editor = gr.ImageEditor(
    sources=["upload", 'clipboard'],
    label="Input Piano Roll Image (White = Gen Notes Here)",
    value=make_dict('all_black.png'),
    brush=gr.Brush(colors=["#FFFFFF", "#000000"]),
)
_repaint_slider = gr.Slider(
    minimum=1, maximum=10, step=1, value=2,
    label="RePaint (Larger = More Notes, But Crazier. Also Slower.)",
)
_busyness_slider = gr.Slider(
    minimum=1, maximum=100, step=1, value=100,
    label="Busy-ness Percentile (Based on Notes Generated)",
)

# Outputs: mirror the (gen_image, html, audio_file) tuple from process_image.
_output_components = [
    gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
    gr.HTML(label="MIDI Player"),
    gr.Audio(label="MIDI as Audio"),
]

# Example rows as (image file, repaint value); every example uses busyness 100.
_example_specs = [
    ('all_white.png', 1),
    ('all_black.png', 1),
    ('init_img_melody.png', 1),
    ('init_img_accomp.png', 1),
    ('init_img_cont.png', 1),
    ('584_TOTAL_crop.png', 2),
    ('780_TOTAL_crop_bg.png', 2),
    ('780_TOTAL_crop_draw.png', 2),
    ('loop_middle_2.png', 2),
    ('584_TOTAL_crop_draw.png', 3),
    ('loop_middle.png', 3),
    ('ismir_mask_2.png', 6),
]
_example_rows = [[make_dict(img_name), repaint_val, 100]
                 for img_name, repaint_val in _example_specs]

demo = gr.Interface(
    fn=process_image,
    inputs=[_piano_roll_editor, _repaint_slider, _busyness_slider],
    outputs=_output_components,
    examples=_example_rows,
)
demo.queue().launch()