hugoycj
commited on
Commit
•
e08a54a
1
Parent(s):
909ae17
Refactor model initialization and demo launch in app.py
Browse files
app.py
CHANGED
@@ -85,26 +85,6 @@ def infer(
|
|
85 |
temperature,
|
86 |
):
|
87 |
|
88 |
-
score_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
|
89 |
-
|
90 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
91 |
-
|
92 |
-
parser = main_mcc.get_args_parser()
|
93 |
-
parser.set_defaults(eval=True)
|
94 |
-
|
95 |
-
args = parser.parse_args()
|
96 |
-
|
97 |
-
model = mcc_model.get_mcc_model(
|
98 |
-
occupancy_weight=1.0,
|
99 |
-
rgb_weight=0.01,
|
100 |
-
args=args,
|
101 |
-
)
|
102 |
-
|
103 |
-
if device == "cuda":
|
104 |
-
model = model.cuda()
|
105 |
-
|
106 |
-
misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)
|
107 |
-
|
108 |
rgb = image
|
109 |
obj = load_obj(point_cloud.name)
|
110 |
|
@@ -178,15 +158,33 @@ def infer(
|
|
178 |
|
179 |
return temp_file_name
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
-
demo = gr.Interface(fn=infer,
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
85 |
temperature,
|
86 |
):
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
rgb = image
|
89 |
obj = load_obj(point_cloud.name)
|
90 |
|
|
|
158 |
|
159 |
return temp_file_name
|
160 |
|
161 |
+
if __name__ == '__main__':
|
162 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
163 |
+
|
164 |
+
parser = main_mcc.get_args_parser()
|
165 |
+
parser.set_defaults(eval=True)
|
166 |
+
|
167 |
+
args = parser.parse_args()
|
168 |
+
|
169 |
+
model = mcc_model.get_mcc_model(
|
170 |
+
occupancy_weight=1.0,
|
171 |
+
rgb_weight=0.01,
|
172 |
+
args=args,
|
173 |
+
)
|
174 |
+
|
175 |
+
if device == "cuda":
|
176 |
+
model = model.cuda()
|
177 |
+
|
178 |
+
misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)
|
179 |
|
180 |
+
demo = gr.Interface(fn=infer,
|
181 |
+
inputs=[gr.Image(label="Input Image"),
|
182 |
+
gr.File(label="Pointcloud File"),
|
183 |
+
gr.File(label="Segmentation File"),
|
184 |
+
gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Granularity"),
|
185 |
+
gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Temperature")
|
186 |
+
],
|
187 |
+
outputs=[gr.outputs.File(label="Point Cloud")],
|
188 |
+
examples=[["demo/quest2.jpg", "demo/quest2.obj", "demo/quest2_seg.png", 0.2, 0.1]],
|
189 |
+
cache_examples=True)
|
190 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|