johnrobinsn committed on
Commit
831c6b2
1 Parent(s): a69b761

Initial Push

Browse files
__pycache__/depth_viewer.cpython-39.pyc ADDED
Binary file (5.31 kB). View file
 
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import open3d as o3d
7
+ from pathlib import Path
8
+ from depth_viewer import depthviewer2html
9
+
10
# Hugging Face Hub checkpoint shared by the preprocessor and the network.
_CHECKPOINT = "Intel/dpt-large"

# Both objects are created once at import time and reused by every request.
feature_extractor = DPTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = DPTForDepthEstimation.from_pretrained(_CHECKPOINT)
12
+
13
def process_image(image_path):
    """Estimate a depth map for an image file and return the 3D viewer markup.

    Args:
        image_path: Path of the input image on disk (the Gradio input
            component is configured with ``type="filepath"``).

    Returns:
        A one-element list containing the HTML string produced by
        ``depthviewer2html``, one entry per output component.
    """
    # Open inside a context manager so the file handle is released, and
    # normalise to 3-channel RGB: grayscale, palette or RGBA uploads would
    # otherwise reach the model and the viewer with a different channel layout.
    with Image.open(Path(image_path)) as source:
        image = source.convert("RGB")

    # prepare image for the model
    encoding = feature_extractor(image, return_tensors="pt")

    # forward pass (inference only, so no gradients are tracked)
    with torch.no_grad():
        outputs = model(**encoding)
        predicted_depth = outputs.predicted_depth

    # interpolate to original size; PIL reports (width, height) while
    # interpolate() takes (height, width), hence the reversal
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    output = prediction.cpu().numpy()

    # Rescale so the largest value maps to 255. Guard against an all-zero map
    # (division by zero) and clip before the cast: bicubic resampling can
    # overshoot slightly below zero, and a negative float does not saturate
    # when converted to uint8.
    peak = float(np.max(output))
    if peak > 0:
        scaled = np.clip(output * 255 / peak, 0, 255)
    else:
        scaled = np.zeros_like(output)
    depth = scaled.astype("uint8")

    h = depthviewer2html(image, depth)
    return [h]
37
+
38
# Header text displayed above the demo.
title = "3d Visualization of Depth Maps Generated using MiDaS"
description = (
    "Improved 3D interactive depth viewer using Three.js embedded in a Gradio app. "
    "For more details see the <a href='https://colab.research.google.com/drive/1l2l8U7Vhq9RnvV2tHyfhrXKNuHfmb4IP?usp=sharing'>Colab Notebook.</a>"
)

# Each example row supplies one value per input component.
examples = [
    [sample]
    for sample in (
        "examples/owl1.jpg",
        "examples/marsattacks.jpg",
        "examples/kitten.jpg",
    )
]

# The function receives a file path and returns raw HTML for the viewer.
_image_input = gr.Image(type="filepath", label="Input Image")
_viewer_output = gr.HTML(label="Depth Viewer", elem_id="depth-viewer")

# NOTE(review): the selector below has a stray ':' after the id, so browsers
# most likely discard the whole rule. Confirm the intended height before
# correcting it -- depthviewer2html emits an iframe with height="600", which a
# working 300px rule would not fit.
iface = gr.Interface(
    fn=process_image,
    inputs=[_image_input],
    outputs=[_viewer_output],
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never",
    cache_examples=False,
    css="#depth-viewer: {height:300px;}",
)

iface.launch(debug=True, enable_queue=False)
depth_viewer.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import base64
3
+ import numpy as np
4
+
5
# Stand-alone HTML page that renders a colour image as a 3D point cloud, with
# each point pushed along z by the matching depth-map sample. The two
# triple-brace tokens are placeholders that depthviewer2html replaces with
# base64 data URLs before the page is embedded in an iframe.
_viewer_html = '''
<html>

<head>
  <style>
    body {
      overflow: hidden;
      margin: 0;
    }
  </style>
  <script>
    var image_url = "{{{image_url_marker}}}";
    var depth_url = "{{{depth_url_marker}}}";

    window.onload = (e) => {
      var scene = new THREE.Scene();
      scene.background = new THREE.Color( 0xffffff );
      var camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 1000);
      camera.position.set(0, 0, 2);
      var renderer = new THREE.WebGLRenderer({
        antialias: true
      });
      renderer.setSize(window.innerWidth, window.innerHeight);
      document.body.appendChild(renderer.domElement);

      var controls = new THREE.OrbitControls(camera, renderer.domElement);

      var reset = document.getElementById('reset')
      reset.addEventListener('click', e => controls.reset())

      function vertexShader() {
        return document.getElementById("vshader").text
      }

      function fragmentShader() {
        return document.getElementById("fshader").text
      }

      var texture = new THREE.TextureLoader().load(image_url, t => {
        var w = t.image.width;
        var h = t.image.height;
        var max = Math.max(w, h);

        // Plane keeps the image aspect ratio (longer side = 1 unit) and is
        // subdivided once per pixel so roughly every pixel becomes a point.
        var planeGeometry = new THREE.PlaneGeometry(w / max, h / max, w, h);
        var depth = new THREE.TextureLoader().load(depth_url);

        // Declared with var so it does not leak into the global scope.
        var uniforms = {
          image: { type: "t", value: texture },
          depth: { type: "t", value: depth }
        }

        let planeMaterial = new THREE.ShaderMaterial({
          uniforms: uniforms,
          fragmentShader: fragmentShader(),
          vertexShader: vertexShader(),
          side: THREE.DoubleSide
        });

        var points = new THREE.Points(planeGeometry, planeMaterial)

        points.position.set(0, 0, 0)

        scene.add(points)

        render();
      });

      function render() {
        requestAnimationFrame(render);
        renderer.render(scene, camera);
      }
    }
  </script>
</head>

<body>
  <!-- Pinned three.js release: the unversioned threejs.org URLs follow the
       latest build, and newer releases no longer provide the global-script
       OrbitControls this page relies on. -->
  <script src="https://cdn.jsdelivr.net/npm/three@0.140.0/build/three.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/three@0.140.0/examples/js/controls/OrbitControls.js"></script>

  <script id="vshader" type="x-shader/x-vertex">
    uniform sampler2D depth;
    varying vec2 vUv;

    void main() {
      // Use the plane's own texture coordinates. Deriving them from the
      // vertex position with a single aspect factor is only correct when the
      // image is at least as wide as it is tall; portrait images sampled a
      // cropped centre region instead of the full picture.
      vUv = uv;

      vec3 pos = position;
      pos.z = texture2D(depth, vUv).r;

      // Spread points with smaller depth samples further out in x/y.
      float s = 2.0 - pos.z;
      pos.x = pos.x * s;
      pos.y = pos.y * s;

      vec4 modelViewPosition = modelViewMatrix * vec4(pos, 1.0);
      gl_Position = projectionMatrix * modelViewPosition;
      gl_PointSize = 2.0;
    }
  </script>
  <script id="fshader" type="x-shader/x-fragment">
    uniform sampler2D image;
    varying vec2 vUv;

    void main() {
      gl_FragColor = texture2D(image, vUv);
    }
  </script>

  <div style="position:absolute">
    <button id="reset">Reset</button>
  </div>
</body>

</html>
'''

# Placeholder tokens embedded in the template above.
image_url_marker = '{{{image_url_marker}}}'
depth_url_marker = '{{{depth_url_marker}}}'
149
+
150
def depthviewer2html(image, depth):
    """Build an ``<iframe>`` snippet showing *image* as a depth-displaced 3D view.

    Args:
        image: Colour picture in RGB channel order -- a PIL image or anything
            ``np.array`` can convert to an array.
        depth: Depth map for the same picture, as an array OpenCV can write
            to PNG (the caller in this project passes ``uint8``).

    Returns:
        An HTML string containing an iframe whose ``src`` is a base64
        ``data:`` URL of the viewer page with both pictures inlined.

    Raises:
        ValueError: If OpenCV reports that it could not encode a picture.
    """
    # PIL images in grayscale, palette or alpha modes do not convert to the
    # three-channel array the colour swap below expects; normalise them first.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    # OpenCV encoders assume BGR channel order, so swap from RGB before
    # writing the JPEG (otherwise red and blue appear exchanged).
    image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    image_data_url = _encode_data_url(image_bgr, '.jpg', 'image/jpeg')

    # PNG is lossless, so depth samples reach the shader unaltered.
    depth_data_url = _encode_data_url(np.array(depth), '.png', 'image/png')

    vhtml = (
        _viewer_html
        .replace(image_url_marker, image_data_url)
        .replace(depth_url_marker, depth_data_url)
    )

    # The whole page travels as a data URL so no extra route has to serve it.
    e = base64.b64encode(bytes(vhtml, 'utf-8')).decode('utf-8')
    url = f'data:text/html;base64,{e}'
    h = f'<iframe src="{url}" height="600" width="100%"></iframe>'
    return h


def _encode_data_url(pixels, extension, mime_type):
    """Encode an image array with OpenCV and wrap it in a base64 data URL."""
    ok, buffer = cv2.imencode(extension, pixels)
    if not ok:
        raise ValueError(f'OpenCV could not encode the image as {extension}')
    return f'data:{mime_type};base64,' + base64.b64encode(buffer).decode('utf-8')
examples/kitten.jpg ADDED
examples/marsattacks.jpg ADDED
examples/owl1.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ git+https://github.com/nielsrogge/transformers.git@add_dpt_redesign#egg=transformers
3
+ numpy
4
+ Pillow
5
+ gradio==3.0b8
6
+ opencv-python