02alexander commited on
Commit
35b6d9d
1 Parent(s): 34c96bd

fix pickle error

Browse files
Files changed (1) hide show
  1. app.py +42 -26
app.py CHANGED
@@ -168,31 +168,37 @@ def pipeline_callback(output_queue: SimpleQueue, pipe: Any, step_index: int, tim
168
  return callback_kwargs
169
 
170
  @spaces.GPU
171
- def generate_mvs(output_queue: SimpleQueue, input_image, sample_steps, sample_seed):
172
 
173
  seed_everything(sample_seed)
174
 
175
- z123_image = pipeline(
176
- input_image,
177
- num_inference_steps=sample_steps,
178
- callback_on_step_end=lambda *args, **kwargs: pipeline_callback(output_queue, *args, **kwargs),
179
- ).images[0]
 
 
180
 
181
- output_queue.put(("z123_image", z123_image))
182
-
183
- # sampling
184
- # z123_image = pipeline(
185
- # input_image,
186
- # num_inference_steps=sample_steps
187
- # ).images[0]
 
 
 
 
188
 
189
- # show_image = np.asarray(z123_image, dtype=np.uint8)
190
- # show_image = torch.from_numpy(show_image) # (960, 640, 3)
191
- # show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
192
- # show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
193
- # show_image = Image.fromarray(show_image.numpy())
 
194
 
195
- # return z123_image, show_image
196
 
197
  @spaces.GPU
198
  def make3d(output_queue: SimpleQueue, images: Image.Image):
@@ -292,13 +298,25 @@ def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
292
 
293
  yield stream.read()
294
 
295
- output_queue = SimpleQueue()
296
 
297
- mvs_thread = threading.Thread(target=generate_mvs, args=[output_queue, input_image, sample_steps, sample_seed])
298
- mvs_thread.start()
299
 
300
- while True:
301
- msg = output_queue.get()
 
 
 
 
 
 
 
 
 
 
 
 
302
  if msg[0] == "z123_image":
303
  z123_image = msg[1]
304
  break
@@ -307,8 +325,6 @@ def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
307
  entity = msg[2]
308
  rr.log(entity_path, entity)
309
  yield stream.read()
310
-
311
- mvs_thread.join()
312
 
313
  rr.log("z123image", rr.Image(z123_image))
314
  yield stream.read()
 
168
  return callback_kwargs
169
 
170
  @spaces.GPU
171
+ def generate_mvs(input_image, sample_steps, sample_seed):
172
 
173
  seed_everything(sample_seed)
174
 
175
+ def thread_target(output_queue, input_image, sample_steps):
176
+ z123_image = pipeline(
177
+ input_image,
178
+ num_inference_steps=sample_steps,
179
+ callback_on_step_end=lambda *args, **kwargs: pipeline_callback(output_queue, *args, **kwargs),
180
+ ).images[0]
181
+ output_queue.put(("z123_image", z123_image))
182
 
183
+ output_queue = SimpleQueue()
184
+ z123_thread = threading.Thread(
185
+ target=thread_target,
186
+ args=
187
+ [
188
+ output_queue,
189
+ input_image,
190
+ sample_steps,
191
+ ]
192
+ )
193
+ z123_thread.start()
194
 
195
+ while True:
196
+ msg = output_queue.get()
197
+ yield msg
198
+ if msg[0] == "z123_image":
199
+ break
200
+ z123_thread.join()
201
 
 
202
 
203
  @spaces.GPU
204
  def make3d(output_queue: SimpleQueue, images: Image.Image):
 
298
 
299
  yield stream.read()
300
 
301
+ # output_queue = SimpleQueue()
302
 
303
+ # mvs_thread = threading.Thread(target=generate_mvs, args=[output_queue, input_image, sample_steps, sample_seed])
304
+ # mvs_thread.start()
305
 
306
+ # while True:
307
+ # msg = output_queue.get()
308
+ # if msg[0] == "z123_image":
309
+ # z123_image = msg[1]
310
+ # break
311
+ # elif msg[0] == "log":
312
+ # entity_path = msg[1]
313
+ # entity = msg[2]
314
+ # rr.log(entity_path, entity)
315
+ # yield stream.read()
316
+
317
+ # mvs_thread.join()
318
+
319
+ for msg in generate_mvs(input_image, sample_steps, sample_seed):
320
  if msg[0] == "z123_image":
321
  z123_image = msg[1]
322
  break
 
325
  entity = msg[2]
326
  rr.log(entity_path, entity)
327
  yield stream.read()
 
 
328
 
329
  rr.log("z123image", rr.Image(z123_image))
330
  yield stream.read()