File size: 2,391 Bytes
d8eaf88
23e9852
61d09e6
 
 
10c55f7
61d09e6
fcda6f5
d8eaf88
fb6c0ad
e32d51c
d8eaf88
 
 
61d09e6
 
 
 
98d2c0f
61d09e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10c55f7
 
 
eac489d
6a7e68e
61d09e6
fb6c0ad
 
24dac97
fb6c0ad
e32d51c
23e9852
cdcc96d
 
fb6c0ad
61d09e6
98d2c0f
 
aa7a16e
98d2c0f
cdcc96d
0a1c1a2
9e2cd5a
fb6c0ad
 
 
 
 
9e2cd5a
 
 
fb6c0ad
23e9852
cdcc96d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import gradio as gr
import cv2
import torch
import urllib.request
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# torch.hub entry points of intel-isl/MiDaS that have already been loaded,
# keyed by entry name, so repeated button clicks do not reload them.
_MIDAS_HUB_CACHE = {}


def _midas_hub(entry):
    """Return the intel-isl/MiDaS hub entry point *entry*, loading it at most once."""
    if entry not in _MIDAS_HUB_CACHE:
        _MIDAS_HUB_CACHE[entry] = torch.hub.load("intel-isl/MiDaS", entry)
    return _MIDAS_HUB_CACHE[entry]


def _depth_to_uint8(depth):
    """Scale *depth* so that its maximum maps to 255 and return it as uint8."""
    peak = np.max(depth)
    # A zero, negative or non-finite maximum cannot be used as a divisor;
    # fall back to an all-black map instead of producing NaN/inf garbage.
    if not np.isfinite(peak) or peak <= 0:
        return np.zeros(depth.shape, dtype='uint8')
    # Clip before the cast so negative values do not wrap around in uint8.
    return np.clip(depth * 255 / peak, 0, 255).astype('uint8')


def calculate_depth(model_type: str, img: "Image.Image") -> str:
    """Run a MiDaS model on *img* and save the prediction as a greyscale JPEG.

    Args:
        model_type: Name of the intel-isl/MiDaS torch.hub entry point to use
            (the UI offers "DPT_Large", "DPT_Hybrid" and "MiDaS_small").
        img: Input picture as a PIL image.

    Returns:
        The path of the written depth image, ``'temp/image_depth.jpeg'``.

    Raises:
        ValueError: If *img* is ``None`` (no picture was supplied).
    """
    if img is None:
        raise ValueError("No input image provided.")

    os.makedirs('temp', exist_ok=True)

    # JPEG cannot store alpha or palette images, so normalise the mode first.
    rgb = img.convert("RGB")
    rgb.save("temp/image.jpg", "JPEG")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    midas = _midas_hub(model_type)
    midas.to(device)
    midas.eval()

    midas_transforms = _midas_hub("transforms")
    if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform

    # Use the in-memory RGB pixels directly instead of re-reading the JPEG.
    frame = np.array(rgb)
    input_batch = transform(frame).to(device)

    with torch.no_grad():
        prediction = midas(input_batch)

        # Resize the prediction back to the input's height and width.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=frame.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    output = prediction.cpu().numpy()

    out_im = Image.fromarray(_depth_to_uint8(output))
    out_im.save("temp/image_depth.jpeg", "JPEG")

    return 'temp/image_depth.jpeg'

def wiggle_effect(slider):
    """Return the two image paths for the wigglegram outputs.

    Placeholder implementation: *slider* is accepted but not used, and both
    entries are the path of the saved depth image.
    """
    depth_path = 'temp/image_depth.jpeg'
    return [depth_path, depth_path]



# Gradio UI: a depth-estimation section (model dropdown + input image ->
# depth image) followed by a wigglegram section (slider -> two images).
with gr.Blocks() as demo:
    gr.Markdown(
        "Upload an image, choose a model and click **Calculate depth**. "
        "Then click **Generate Wigglegram**."
    )

    midas_models = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]

    # gr.Dropdown replaces the deprecated gr.inputs.Dropdown; the initial
    # selection is passed as `value`.
    model_choice = gr.Dropdown(
        midas_models,
        value="MiDaS_small",
        label="Depth estimation model type",
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input")
        depth_image = gr.Image(type="file", label="depth_estimation")
    depth_btn = gr.Button("Calculate depth")
    # Input order must match calculate_depth(model_type, img).
    depth_btn.click(
        fn=calculate_depth,
        inputs=[model_choice, input_image],
        outputs=depth_image,
    )

    # The slider's initial value is given as `value`, not `default`.
    step_cycles = gr.Slider(1, 15, value=2, label='StepCycles', step=1)
    with gr.Row():
        # TODO change to gallery
        wiggle_outputs = [
            gr.Image(type="file", label="Output_images"),
            gr.Image(type="file", label="Output_wiggle"),
        ]
    wiggle_btn = gr.Button("Generate Wigglegram")
    wiggle_btn.click(fn=wiggle_effect, inputs=[step_cycles], outputs=wiggle_outputs)

demo.launch()