PicturesOfMIDI / app.py
drscotthawley's picture
added more code
6dfde8b
raw
history blame
No virus
2.01 kB
# imports from gradio_demo.py
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np
from time import sleep
from subprocess import call
import pandas as pd
# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K
from .pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from .square_to_rect import square_to_rect
def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    Pixel values are compared against exactly 0 and 1, so the image must not
    be on a 0..255 scale (original author's note: this works whether the
    image is normalized on 0..1 or -1..1).

    Args:
        img: A tensor of shape (C, H, W) or (H, W), or anything that
            torchvision's ``ToTensor`` accepts. Greyscale input (2-D, or a
            single channel) is treated as having identical R, G and B values.
        mask_with: Which marking convention to look for:
            'white' - pixels whose R, G and B all equal 1;
            'blue'  - pixels whose blue channel equals 1;
            'grey'  - pixels with R == G == B and R != 0 (note that pure
                      white pixels satisfy this test too).

    Returns:
        A float tensor of shape (H, W) holding 1 where the image is marked
        and 0 elsewhere.

    Raises:
        AssertionError: if ``mask_with`` is not one of the supported names.
        ValueError: if the image is not 2-D or 3-D, or has two channels.
    """
    assert mask_with in ['blue', 'white', 'grey'], \
        f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}"
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print("img.shape: ", img.shape)

    # Normalize the layout to (C, H, W) so the channel indexing below is
    # valid for greyscale input as well.
    if img.dim() == 2:
        channels = img.unsqueeze(0)
    elif img.dim() == 3:
        channels = img
    else:
        raise ValueError(
            f"expected a 2-D (H, W) or 3-D (C, H, W) image, got shape {tuple(img.shape)}"
        )
    if channels.shape[0] == 1:
        # A single grey channel is equivalent to R == G == B.
        channels = channels.expand(3, -1, -1)
    elif channels.shape[0] < 3:
        raise ValueError(
            f"expected 1 or at least 3 channels, got {channels.shape[0]}"
        )
    red, green, blue = channels[0], channels[1], channels[2]

    # Shape of the mask is the image shape without the channel dimension.
    mask = torch.zeros(channels.shape[-2:])
    print("mask.shape: ", mask.shape)
    if mask_with == 'white':
        mask[(red == 1) & (green == 1) & (blue == 1)] = 1
    elif mask_with == 'blue':
        mask[blue == 1] = 1
    elif mask_with == 'grey':
        mask[(red != 0) & (red == green) & (blue == green)] = 1
    return mask
def count_notes_in_mask(img, mask):
    """Count the number of new notes in the mask.

    A pixel counts when its green channel is strictly positive; each such
    pixel contributes the mask's value at that position to the total.

    Args:
        img: A (C, H, W) tensor with at least two channels, or anything that
            torchvision's ``ToTensor`` accepts.
        mask: An (H, W) tensor weighting each pixel (1 inside the mask and 0
            outside for a binary mask).

    Returns:
        The total as a Python number (``Tensor.item()`` of the sum).
    """
    # Accept already-converted tensors too, the same way
    # infer_mask_from_init_img does; ToTensor rejects tensor input.
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    green_active = img_t[1, :, :] > 0  # green channel
    new_notes = (mask * green_active).sum()
    return new_notes.item()
def greet(name):
    """Return the greeting "Hello <name>!!" for the given string."""
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
# Wire `greet` into a text-in / text-out Gradio interface and start serving it.
# This runs as soon as the module is executed.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
demo.launch()