# WiggleGAN / app.py
# Rodrigo_Cobo
# fix issue with cycle half
# ce682d7
import os
import subprocess
import sys
import urllib.request

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
def calculate_depth(model_type, gan_type, dim, slider, img):
    """Estimate a depth map with MiDaS, then run the WiggleGAN scripts on it.

    The input image is written to ``Images/Input-Test/1.png`` and its depth
    map to ``Images/Input-Test/1_d.png``; ``main.py`` and
    ``WiggleResults/split.py`` are then executed as subprocesses.

    Args:
        model_type: MiDaS entry point name passed to ``torch.hub.load``
            ("DPT_Large" and "DPT_Hybrid" use the DPT transform, anything
            else the small transform).
        gan_type: Key selecting the saved GAN seed (e.g. "Cycle").
        dim: Output dimension as a string; forwarded to ``--imageDim`` and
            ``--dim``.
        slider: Value forwarded (via ``str``) to ``--wiggleDepth``.
        img: Image object exposing ``save(path, format)`` (a PIL image in
            the Gradio UI).

    Returns:
        A list ``[depth_image, gif_path, video_path, jpg_path]`` where
        ``depth_image`` is a PIL image and the rest are file paths.

    Raises:
        ValueError: If ``img`` is None.
        KeyError: If ``gan_type`` is not a known GAN training name.
        subprocess.CalledProcessError: If either script exits non-zero.
    """
    if img is None:
        raise ValueError("An input image is required to calculate the depth.")

    # Resolve the checkpoint seed first so an unknown GAN name fails before
    # any model is downloaded or any file is written.
    dict_saved_gans = {
        'Cycle': '74962_110',
        'Cycle(half)': '66942_110',
        'noCycle': '31219_110',
        'noCycle-noCr': '92332_110',
        'noCycle-noCr-noL1': '82122_110',
        'OnlyGen': '70944_110',
    }
    seed_load = dict_saved_gans[gan_type]

    # NOTE(review): 'temp' is not used in this function; presumably one of
    # the scripts below expects it to exist — confirm.
    os.makedirs('temp', exist_ok=True)
    os.makedirs("Images/Input-Test", exist_ok=True)

    filename = "Images/Input-Test/1.png"
    img.save(filename, "PNG")

    midas = torch.hub.load("intel-isl/MiDaS", model_type)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    midas.to(device)
    midas.eval()

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    if model_type in ("DPT_Large", "DPT_Hybrid"):
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform

    # Re-read the saved PNG with OpenCV and convert BGR -> RGB for MiDaS.
    img = cv2.imread(filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    input_batch = transform(img).to(device)

    with torch.no_grad():
        prediction = midas(input_batch)
        # Resize the prediction back to the input image's height/width.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    output = prediction.cpu().numpy()

    # Scale so the maximum maps to 255. Guard against an all-zero (or
    # non-positive) map, and clip so negative values cannot wrap around
    # when cast to uint8.
    max_value = float(np.max(output))
    if max_value > 0:
        scaled = np.clip(output * 255.0 / max_value, 0.0, 255.0)
    else:
        scaled = np.zeros_like(output)
    formatted = scaled.astype('uint8')

    out_im = Image.fromarray(formatted)
    out_im.save("Images/Input-Test/1_d.png", "PNG")

    c_images = '1'
    name_output = 'out'

    # Use the running interpreter and fail loudly: without check=True a
    # failed run would silently return stale or missing result files.
    subprocess.run(
        [
            sys.executable, "main.py",
            "--gan_type", 'WiggleGAN',
            "--expandGen", "4",
            "--expandDis", "4",
            "--batch_size", c_images,
            "--cIm", c_images,
            "--visdom", "false",
            "--wiggleDepth", str(slider),
            "--seedLoad", seed_load,
            "--gpu_mode", "false",
            "--imageDim", dim,
            "--name_wiggle", name_output,
        ],
        check=True,
    )
    subprocess.run(
        [sys.executable, "WiggleResults/split.py", "--dim", dim],
        check=True,
    )

    path_video = os.path.join(
        os.path.dirname(__file__), 'WiggleResults', name_output + '_0.mp4'
    )
    print(path_video)
    return [
        out_im,
        'WiggleResults/' + name_output + '_0.gif',
        path_video,
        'WiggleResults/' + name_output + '.jpg',
    ]
# Gradio UI: pick the models and an image, then run depth estimation + WiggleGAN.
with gr.Blocks() as demo:
    gr.Markdown(
        "Upload an image, choose the options below and then click "
        "**Calculate depth + Wiggle** to see the output."
    )

    ## Depth Estimation
    midas_models = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
    gan_models = ["Cycle", "Cycle(half)", "noCycle", "noCycle-noCr", "noCycle-noCr-noL1", "OnlyGen"]
    dim = ['256', '512', '1024']

    # Input order must match calculate_depth(model_type, gan_type, dim, slider, img).
    with gr.Row():
        inp = [gr.Dropdown(choices=midas_models, value="MiDaS_small", label="Depth estimation model type")]
        inp.append(gr.Dropdown(choices=gan_models, value="Cycle", label="Different GAN trainings"))
        inp.append(gr.Dropdown(choices=dim, value="256", label="Wiggle dimension result"))
        # The component API takes `value`, not `default`, for the initial value.
        inp.append(gr.Slider(1, 15, value=2, label='StepCycles', step=1))

    with gr.Row():
        inp.append(gr.Image(type="pil", label="Input"))
        out = [gr.Image(type="pil", label="depth_estimation")]

    # Output order must match the list returned by calculate_depth.
    with gr.Row():
        out.append(gr.Image(type="file", label="Output_wiggle_gif"))
        out.append(gr.Video(label="Output_wiggle_video"))
        out.append(gr.Image(type="file", label="Output_images"))

    btn = gr.Button("Calculate depth + Wiggle")
    btn.click(fn=calculate_depth, inputs=inp, outputs=out)

demo.launch()