liuhaotian commited on
Commit
4f97a73
1 Parent(s): f8dc7a7

Show error

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -21,6 +21,9 @@ from datetime import datetime
21
  from huggingface_hub import hf_hub_download
22
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
23
 
 
 
 
24
 
25
  def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
26
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
@@ -294,7 +297,9 @@ def generate(task, language_instruction, grounding_texts, sketch_pad,
294
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
295
  # assert len(boxes) == len(grounding_texts)
296
  if len(boxes) != len(grounding_texts):
297
- assert len(boxes) > len(grounding_texts)
 
 
298
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
299
 
300
  boxes = (np.asarray(boxes) / 512).tolist()
@@ -749,6 +754,6 @@ with Blocks(
749
  )
750
 
751
  main.queue(concurrency_count=1, api_open=False)
752
- main.launch(share=False, show_api=False)
753
 
754
 
 
21
  from huggingface_hub import hf_hub_download
22
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
23
 
24
+ import sys
25
+ sys.tracebacklimit = 0
26
+
27
 
28
  def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
29
  cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
 
297
  grounding_texts = [x.strip() for x in grounding_texts.split(';')]
298
  # assert len(boxes) == len(grounding_texts)
299
  if len(boxes) != len(grounding_texts):
300
+ if len(boxes) < len(grounding_texts):
301
+ raise ValueError("""The number of boxes should be equal to the number of grounding objects.
302
+ Number of boxes drawn: {}, number of grounding tokens: {}""".format(len(boxes), len(grounding_texts)))
303
  grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
304
 
305
  boxes = (np.asarray(boxes) / 512).tolist()
 
754
  )
755
 
756
  main.queue(concurrency_count=1, api_open=False)
757
+ main.launch(share=False, show_api=False, show_error=True)
758
 
759