File size: 3,941 Bytes
e706d2b
f43d01b
e706d2b
b9f0115
 
29de13b
 
 
 
5b63b45
29de13b
f43d01b
 
 
 
a5f7c39
f43d01b
dc5a8c5
a770601
 
 
 
 
 
998456b
f43d01b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77bb1b9
 
 
f43d01b
 
b9f0115
77bb1b9
e699699
 
 
 
23e1708
e699699
 
 
b9f0115
 
 
 
 
 
a83472e
 
b9f0115
e706d2b
5dc6bd9
8d19b59
 
 
 
 
 
 
 
 
 
 
 
 
 
b9f0115
5f32f95
5dc6bd9
b9f0115
5dc6bd9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import os
from fastai.vision.all import *
import gradio as gr

############### HF ###########################

HF_TOKEN = os.getenv('HF_TOKEN')

hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "savtadepth-flags-V2")

############## DVC ################################

PROD_MODEL_PATH = "src/models"
TRAIN_PATH = "src/data/processed/train/bathroom"
TEST_PATH = "src/data/processed/test/bathroom"

if  os.path.isdir(".dvc"):
    print("Running DVC")
    os.system("dvc config cache.type copy")
    os.system("dvc config core.no_scm true")
    if os.system(f"dvc pull {PROD_MODEL_PATH} {TRAIN_PATH } {TEST_PATH }") != 0:
        exit("dvc pull failed")
    os.system("rm -r .dvc")
# .apt/usr/lib/dvc			

############## Inference ##############################

class ImageImageDataLoaders(DataLoaders):
    """Basic wrapper around several `DataLoader`s with factory methods for Image to Image problems"""
    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(cls, path, filenames, label_func, valid_pct=0.2, seed=None, item_transforms=None,
                        batch_transforms=None, **kwargs):
        """Create from list of `fnames` in `path`s with `label_func`."""
        datablock = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImageBW)),
                              get_y=label_func,
                              splitter=RandomSplitter(valid_pct, seed=seed),
                              item_tfms=item_transforms,
                              batch_tfms=batch_transforms)
        res = cls.from_dblock(datablock, filenames, path=path, **kwargs)
        return res


def get_y_fn(x):
    y = str(x.absolute()).replace('.jpg', '_depth.png')
    y = Path(y)

    return y


def create_data(data_path):
    fnames = get_files(data_path/'train', extensions='.jpg')
    data = ImageImageDataLoaders.from_label_func(data_path/'train', seed=42, bs=4, num_workers=0, filenames=fnames, label_func=get_y_fn)
    return data

data = create_data(Path('src/data/processed'))
learner = unet_learner(data,resnet34, metrics=rmse, wd=1e-2, n_out=3, loss_func=MSELossFlat(), path='src/')
learner.load('model')

def gen(input_img):
   return PILImageBW.create((learner.predict(input_img))[0]).convert('L')

################### Gradio Web APP ################################

title = "SavtaDepth WebApp"

description = """
<p>
<center>
Savta Depth is a collaborative Open Source Data Science project for monocular depth estimation - Turn 2d photos into 3d photos. To test the model and code please check out the link bellow.
<img src="https://huggingface.co/spaces/kingabzpro/savtadepth/resolve/main/examples/cover.png" alt="logo" width="250"/>
</center>
</p>
"""
article = "<p style='text-align: center'><a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>SavtaDepth Project from OperationSavta</a></p><p style='text-align: center'><a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>Google Colab Demo</a></p></center></p>"

examples = [
    ["examples/00008.jpg"],
    ["examples/00045.jpg"],
]
favicon = "examples/favicon.ico"
thumbnail = "examples/SavtaDepth.png"


def main():
	iface = gr.Interface(
	                     gen, 
	                     gr.inputs.Image(shape=(640,480),type='numpy'), 
	                     "image", 
	                     title = title, 
	                     flagging_options=["incorrect", "worst","ambiguous"], 
	                     allow_flagging = "manual",
	                     flagging_callback=hf_writer,
	                     description = description, 
	                     article = article, 
	                     examples = examples, 
	                     theme ="peach",
	                     allow_screenshot=True
	                     )

	iface.launch(enable_queue=True)
# enable_queue=True,auth=("admin", "pass1234")

if __name__ == '__main__':
	main()