WiggleGAN / app.py
Rodrigo_Cobo
adapted some text stuff
9e2cd5a
raw history blame
No virus
2.39 kB
import os
import gradio as gr
import cv2
import torch
import urllib.request
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Loaded MiDaS networks/transforms keyed by model name, so the torch.hub
# download and model construction happen once per process instead of on every
# button click. At most one entry per model offered in the UI.
_MIDAS_CACHE = {}


def _load_midas(model_type):
    """Return ``(model, transform, device)`` for *model_type*, loading it once.

    The model comes from ``torch.hub.load("intel-isl/MiDaS", model_type)``, is
    moved to CUDA when available (CPU otherwise) and put in eval mode. The
    matching preprocessing transform comes from the same hub repo.
    """
    if model_type not in _MIDAS_CACHE:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        midas = torch.hub.load("intel-isl/MiDaS", model_type)
        midas.to(device)
        midas.eval()
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        # The DPT models use ``dpt_transform``; anything else uses
        # ``small_transform``.
        if model_type in ("DPT_Large", "DPT_Hybrid"):
            transform = midas_transforms.dpt_transform
        else:
            transform = midas_transforms.small_transform
        _MIDAS_CACHE[model_type] = (midas, transform, device)
    return _MIDAS_CACHE[model_type]


def _depth_to_uint8(depth):
    """Scale a depth prediction array to ``uint8`` in 0-255 for saving.

    As before, values are divided by the array's maximum so the largest
    prediction maps to 255. Negative values (e.g. from bicubic overshoot) are
    clipped to 0 first, because casting a negative float to ``uint8`` does not
    saturate. A map with no positive values (or an empty map) yields zeros
    instead of NaNs from a 0/0 division.
    """
    depth = np.clip(depth, 0, None)
    peak = depth.max() if depth.size else 0
    if peak <= 0:
        return np.zeros(depth.shape, dtype=np.uint8)
    return (depth * 255 / peak).astype(np.uint8)


def calculate_depth(model_type, img):
    """Estimate a depth map for *img* with a MiDaS model and save it as JPEG.

    Args:
        model_type: Model entry name in the ``intel-isl/MiDaS`` torch.hub repo
            (the UI offers "DPT_Large", "DPT_Hybrid" and "MiDaS_small").
        img: Input ``PIL.Image`` from the Gradio image component.

    Returns:
        The path ``'temp/image_depth.jpeg'`` of the saved depth map.

    Raises:
        ValueError: If no input image was provided.

    Side effects:
        Creates ``temp/`` if needed, writes the RGB input to
        ``temp/image.jpg`` and the depth map to ``temp/image_depth.jpeg``.
        The names are fixed, so each call overwrites the previous files.
    """
    if img is None:
        raise ValueError("No input image provided.")
    # makedirs(exist_ok=True) replaces the shell `mkdir` and its
    # check-then-create race.
    os.makedirs('temp', exist_ok=True)

    # JPEG cannot store alpha or palette data, so convert to RGB once and use
    # that both for the saved copy and as the model input.
    rgb = img.convert("RGB")
    filename = "temp/image.jpg"
    rgb.save(filename, "JPEG")
    # Use the in-memory pixels rather than re-reading the lossy JPEG. PIL
    # already yields RGB order, so no BGR->RGB conversion is needed.
    pixels = np.array(rgb)

    midas, transform, device = _load_midas(model_type)
    input_batch = transform(pixels).to(device)
    with torch.no_grad():
        prediction = midas(input_batch)
        # Resize the prediction back to the input image's height and width.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=pixels.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    formatted = _depth_to_uint8(prediction.cpu().numpy())
    Image.fromarray(formatted).save("temp/image_depth.jpeg", "JPEG")
    return 'temp/image_depth.jpeg'
def wiggle_effect(slider):
    """Placeholder for the wigglegram step.

    ``slider`` (the StepCycles value from the UI) is currently ignored; both
    output slots receive the path of the depth map that ``calculate_depth``
    writes.
    """
    depth_map_path = 'temp/image_depth.jpeg'
    return [depth_map_path] * 2
# Two-step UI: (1) estimate a depth map from an uploaded image,
# (2) generate the wigglegram from it.
with gr.Blocks() as demo:
    # The previous text told users to "start typing" and click a non-existent
    # **Run** button; describe the buttons that actually exist.
    gr.Markdown(
        "Upload an image, choose a depth estimation model and click "
        "**Calculate depth**, then click **Generate Wigglegram** to see the output."
    )

    # --- Step 1: depth estimation ------------------------------------------
    midas_models = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
    # Blocks-native component; `value=` (not the legacy `default=` of the
    # deprecated `gr.inputs` namespace) sets the initial selection.
    model_dropdown = gr.Dropdown(
        midas_models, value="MiDaS_small", label="Depth estimation model type"
    )
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input")
        # "filepath": calculate_depth returns a path string.
        depth_image = gr.Image(type="filepath", label="depth_estimation")
    depth_button = gr.Button("Calculate depth")
    # Input order must match calculate_depth(model_type, img).
    depth_button.click(
        fn=calculate_depth,
        inputs=[model_dropdown, input_image],
        outputs=depth_image,
    )

    # --- Step 2: wigglegram ------------------------------------------------
    # `value=` is the initial position; `default=` is not the Blocks keyword.
    step_slider = gr.Slider(1, 15, value=2, step=1, label='StepCycles')
    with gr.Row():
        wiggle_outputs = [
            gr.Image(type="filepath", label="Output_images"),  # TODO: change to gallery
            gr.Image(type="filepath", label="Output_wiggle"),
        ]
    wiggle_button = gr.Button("Generate Wigglegram")
    wiggle_button.click(fn=wiggle_effect, inputs=[step_slider], outputs=wiggle_outputs)

demo.launch()