z-uo commited on
Commit
9cfda96
1 Parent(s): cff9246
Files changed (2) hide show
  1. app.py +182 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision import models
7
+ from torchvision.transforms import ToTensor
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ import math
12
+ from obj2html import obj2html
13
+
14
+ minDepth=10
15
+ maxDepth=1000
16
+ def my_DepthNorm(x, maxDepth):
17
+ return maxDepth / x
18
+
19
+ def vete(v, vt):
20
+ if v == vt:
21
+ return str(v)
22
+ return str(v)+"/"+str(vt)
23
+
24
+ def create_obj(img, objPath='model.obj', mtlPath='model.mtl', matName='colored', useMaterial=False):
25
+ w = img.shape[1]
26
+ h = img.shape[0]
27
+
28
+ FOV = math.pi/4
29
+ D = (img.shape[0]/2)/math.tan(FOV/2)
30
+
31
+ if max(objPath.find('\\'), objPath.find('/')) > -1:
32
+ os.makedirs(os.path.dirname(mtlPath), exist_ok=True)
33
+
34
+ with open(objPath, "w") as f:
35
+ if useMaterial:
36
+ f.write("mtllib " + mtlPath + "\n")
37
+ f.write("usemtl " + matName + "\n")
38
+
39
+ ids = np.zeros((img.shape[1], img.shape[0]), int)
40
+ vid = 1
41
+
42
+ all_x = []
43
+ all_y = []
44
+ all_z = []
45
+
46
+ for u in range(0, w):
47
+ for v in range(h-1, -1, -1):
48
+
49
+ d = img[v, u]
50
+
51
+ ids[u, v] = vid
52
+ if d == 0.0:
53
+ ids[u, v] = 0
54
+ vid += 1
55
+
56
+ x = u - w/2
57
+ y = v - h/2
58
+ z = -D
59
+
60
+ norm = 1 / math.sqrt(x*x + y*y + z*z)
61
+
62
+ t = d/(z*norm)
63
+
64
+ x = -t*x*norm
65
+ y = t*y*norm
66
+ z = -t*z*norm
67
+
68
+ f.write("v " + str(x) + " " + str(y) + " " + str(z) + "\n")
69
+
70
+ for u in range(0, img.shape[1]):
71
+ for v in range(0, img.shape[0]):
72
+ f.write("vt " + str(u/img.shape[1]) +
73
+ " " + str(v/img.shape[0]) + "\n")
74
+
75
+ for u in range(0, img.shape[1]-1):
76
+ for v in range(0, img.shape[0]-1):
77
+
78
+ v1 = ids[u, v]
79
+ v3 = ids[u+1, v]
80
+ v2 = ids[u, v+1]
81
+ v4 = ids[u+1, v+1]
82
+
83
+ if v1 == 0 or v2 == 0 or v3 == 0 or v4 == 0:
84
+ continue
85
+
86
+ f.write("f " + vete(v1, v1) + " " +
87
+ vete(v2, v2) + " " + vete(v3, v3) + "\n")
88
+ f.write("f " + vete(v3, v3) + " " +
89
+ vete(v2, v2) + " " + vete(v4, v4) + "\n")
90
+
91
+ class UpSample(nn.Sequential):
92
+ def __init__(self, skip_input, output_features):
93
+ super(UpSample, self).__init__()
94
+ self.convA = nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1)
95
+ self.leakyreluA = nn.LeakyReLU(0.2)
96
+ self.convB = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1)
97
+ self.leakyreluB = nn.LeakyReLU(0.2)
98
+
99
+ def forward(self, x, concat_with):
100
+ up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
101
+ return self.leakyreluB( self.convB( self.convA( torch.cat([up_x, concat_with], dim=1) ) ) )
102
+
103
+ class Decoder(nn.Module):
104
+ def __init__(self, num_features=1664, decoder_width = 1.0):
105
+ super(Decoder, self).__init__()
106
+ features = int(num_features * decoder_width)
107
+
108
+ self.conv2 = nn.Conv2d(num_features, features, kernel_size=1, stride=1, padding=0)
109
+
110
+ self.up1 = UpSample(skip_input=features//1 + 256, output_features=features//2)
111
+ self.up2 = UpSample(skip_input=features//2 + 128, output_features=features//4)
112
+ self.up3 = UpSample(skip_input=features//4 + 64, output_features=features//8)
113
+ self.up4 = UpSample(skip_input=features//8 + 64, output_features=features//16)
114
+
115
+ self.conv3 = nn.Conv2d(features//16, 1, kernel_size=3, stride=1, padding=1)
116
+
117
+ def forward(self, features):
118
+ x_block0, x_block1, x_block2, x_block3, x_block4 = features[3], features[4], features[6], features[8], features[12]
119
+ x_d0 = self.conv2(F.relu(x_block4))
120
+
121
+ x_d1 = self.up1(x_d0, x_block3)
122
+ x_d2 = self.up2(x_d1, x_block2)
123
+ x_d3 = self.up3(x_d2, x_block1)
124
+ x_d4 = self.up4(x_d3, x_block0)
125
+ return self.conv3(x_d4)
126
+
127
+ class Encoder(nn.Module):
128
+ def __init__(self):
129
+ super(Encoder, self).__init__()
130
+ self.original_model = models.densenet169( pretrained=False )
131
+
132
+ def forward(self, x):
133
+ features = [x]
134
+ for k, v in self.original_model.features._modules.items(): features.append( v(features[-1]) )
135
+ return features
136
+
137
+ class PTModel(nn.Module):
138
+ def __init__(self):
139
+ super(PTModel, self).__init__()
140
+ self.encoder = Encoder()
141
+ self.decoder = Decoder()
142
+
143
+ def forward(self, x):
144
+ return self.decoder( self.encoder(x) )
145
+
146
+ model = PTModel().float()
147
+ path = "https://github.com/nicolalandro/DenseDepth/releases/download/0.1/nyu.pth"
148
+ model.load_state_dict(torch.hub.load_state_dict_from_url(path, progress=True))
149
+ model.eval()
150
+
151
+ def predict(inp):
152
+ torch_image = ToTensor()(inp)
153
+ images = torch_image.unsqueeze(0)
154
+
155
+ with torch.no_grad():
156
+ predictions = model(images)
157
+ output = np.clip(my_DepthNorm(predictions.numpy(), maxDepth=maxDepth), minDepth, maxDepth) / maxDepth
158
+ depth = output[0,0,:,:]
159
+
160
+ img = Image.fromarray(np.uint8(depth*255))
161
+
162
+ create_obj(depth, 'model.obj')
163
+ html_string = obj2html('model.obj', html_elements_only=True)
164
+
165
+ return img, html_string
166
+
167
+ st.title("Monocular Depth Estimation")
168
+
169
+ uploader = st.file_uploader('Upload your portrait here',type=['jpg','jpeg','png'])
170
+
171
+ if uploader is not None:
172
+ pil_image = Image.open(uploader)
173
+ pil_depth, html_string = predict(pil_image)
174
+
175
+ col1, col2 = st.columns(2)
176
+ with col1:
177
+ st.image(pil_image)
178
+ with col2:
179
+ st.image(pil_depth)
180
+
181
+ components.html(html_string)
182
+ st.markdown(html_string, unsafe_allow_html=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ streamlit
4
+ obj2html>=0.13