File size: 1,725 Bytes
b748937
 
 
 
 
68be913
f5a5094
b748937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a36cca
e494e7b
b748937
 
 
 
 
 
 
 
 
 
 
 
 
 
f5a5094
2b3745f
 
e494e7b
 
f5a5094
 
 
171ee2b
f5a5094
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from PIL import Image
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2 as cv
import gradio as gr

from transformers import DPTForDepthEstimation, DPTFeatureExtractor

# Hub checkpoint shared by the depth network and its matching preprocessor.
DPT_CHECKPOINT = "Intel/dpt-hybrid-midas"

# Both objects are built once at import time and reused by every getImage call.
feature_extractor = DPTFeatureExtractor.from_pretrained(DPT_CHECKPOINT)
model = DPTForDepthEstimation.from_pretrained(DPT_CHECKPOINT, low_cpu_mem_usage=True)


def getImage(image):
    """Return ``image`` with one depth region blurred and the other kept sharp.

    A depth map is predicted with the module-level ``model`` and
    ``feature_extractor``, resized to the input resolution, scaled to 8 bits
    and split in two with Otsu's threshold. Pixels whose depth value lies
    above the threshold are copied unchanged from the input; all remaining
    pixels come from a Gaussian-blurred copy of the input.

    Args:
        image: A PIL image (the Gradio input below is declared with
            ``type="pil"``).

    Returns:
        A NumPy array with the same shape as ``np.asarray(image)`` holding the
        selectively blurred picture.
    """
    # prepare image for the model
    inputs = feature_extractor(images=image, return_tensors="pt")

    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth

    # interpolate to original size; PIL reports (width, height) while
    # interpolate() expects (height, width), hence the reversal
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )

    # One image in, so index away the batch and channel axes explicitly;
    # a blanket squeeze() would also drop a height or width of 1.
    output = prediction[0, 0].cpu().numpy()

    # Scale to 0..255. Bicubic resampling can undershoot below zero, and a
    # negative float cast to uint8 wraps around to a large value, so clip
    # first. An all-zero (or NaN-peaked) map would divide by zero, so fall
    # back to a flat map in that case.
    peak = output.max()
    if peak > 0:
        scaled = np.clip(output * 255 / peak, 0, 255)
    else:
        scaled = np.zeros_like(output)
    depth = scaled.astype("uint8")

    # create blurred version of original image
    pixels = np.asarray(image)
    blurred = cv.GaussianBlur(pixels, (99, 99), sigmaX=0)

    # separate foreground from background: with THRESH_BINARY_INV, pixels
    # above the Otsu threshold become 0 and the rest become 255.
    # NOTE(review): presumably larger depth values mean "nearer" for this
    # model, which makes the zero region the foreground -- confirm.
    _, thresh = cv.threshold(depth, 0, 255, cv.THRESH_BINARY_INV + cv.THRESH_OTSU)

    # Restore the sharp original wherever the mask is 0. A single boolean-mask
    # assignment replaces the former per-channel, per-pixel Python loops.
    sharp = thresh == 0
    blurred[sharp] = pixels[sharp]
    return blurred


# Gradio UI: a single PIL image goes in, getImage processes it, and the
# result is shown through the stock "image" output component.
demo = gr.Interface(fn=getImage, inputs=gr.Image(type="pil"), outputs="image")

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()