hugoycj commited on
Commit
e08a54a
1 Parent(s): 909ae17

Refactor model initialization and demo launch in app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -31
app.py CHANGED
@@ -85,26 +85,6 @@ def infer(
85
  temperature,
86
  ):
87
 
88
- score_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
89
-
90
- device = "cuda" if torch.cuda.is_available() else "cpu"
91
-
92
- parser = main_mcc.get_args_parser()
93
- parser.set_defaults(eval=True)
94
-
95
- args = parser.parse_args()
96
-
97
- model = mcc_model.get_mcc_model(
98
- occupancy_weight=1.0,
99
- rgb_weight=0.01,
100
- args=args,
101
- )
102
-
103
- if device == "cuda":
104
- model = model.cuda()
105
-
106
- misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)
107
-
108
  rgb = image
109
  obj = load_obj(point_cloud.name)
110
 
@@ -178,15 +158,33 @@ def infer(
178
 
179
  return temp_file_name
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- demo = gr.Interface(fn=infer,
183
- inputs=[gr.Image(label="Input Image"),
184
- gr.File(label="Pointcloud File"),
185
- gr.File(label="Segmentation File"),
186
- gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Granularity"),
187
- gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Temperature")
188
- ],
189
- outputs=[gr.outputs.File(label="Point Cloud Json")],
190
- examples=[["demo/quest2.jpg", "demo/quest2.obj", "demo/quest2_seg.png", 0.2, 0.1]],
191
- cache_examples=True)
192
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
85
  temperature,
86
  ):
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  rgb = image
89
  obj = load_obj(point_cloud.name)
90
 
 
158
 
159
  return temp_file_name
160
 
161
+ if __name__ == '__main__':
162
+ device = "cuda" if torch.cuda.is_available() else "cpu"
163
+
164
+ parser = main_mcc.get_args_parser()
165
+ parser.set_defaults(eval=True)
166
+
167
+ args = parser.parse_args()
168
+
169
+ model = mcc_model.get_mcc_model(
170
+ occupancy_weight=1.0,
171
+ rgb_weight=0.01,
172
+ args=args,
173
+ )
174
+
175
+ if device == "cuda":
176
+ model = model.cuda()
177
+
178
+ misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)
179
 
180
+ demo = gr.Interface(fn=infer,
181
+ inputs=[gr.Image(label="Input Image"),
182
+ gr.File(label="Pointcloud File"),
183
+ gr.File(label="Segmentation File"),
184
+ gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Granularity"),
185
+ gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Temperature")
186
+ ],
187
+ outputs=[gr.outputs.File(label="Point Cloud")],
188
+ examples=[["demo/quest2.jpg", "demo/quest2.obj", "demo/quest2_seg.png", 0.2, 0.1]],
189
+ cache_examples=True)
190
+ demo.launch(server_name="0.0.0.0", server_port=7860)