Tech-Meld commited on
Commit
24d193b
1 Parent(s): da7f6e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
3
  import spaces
4
  import torch
5
  import re
 
6
 
7
  model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cpu").eval()
8
  processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
@@ -28,23 +29,31 @@ def modify_caption(caption: str) -> str:
28
 
29
  return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
30
 
31
- def create_captions_rich(images):
32
  captions = []
33
  prompt = "caption en"
34
 
35
- for image in images:
 
 
 
 
 
 
36
  model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
37
  input_len = model_inputs["input_ids"].shape[-1]
38
 
39
- with torch.inference_mode():
40
- generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
41
- generation = generation[0][input_len:]
42
- decoded = processor.decode(generation, skip_special_tokens=True)
43
-
44
- modified_caption = modify_caption(decoded)
45
- captions.append(modified_caption)
 
 
46
 
47
- return captions
48
 
49
  css = """
50
  #mkd {
@@ -60,10 +69,10 @@ with gr.Blocks(css=css) as demo:
60
  with gr.Tab(label="Image to Prompt for SD3"):
61
  with gr.Row():
62
  with gr.Column():
63
- input_imgs = gr.Image(label="Input Images", type="pil", interactive=True)
64
  submit_btn = gr.Button(value="Start")
65
  outputs = gr.Textbox(label="Prompts", lines=10, interactive=False)
66
 
67
- submit_btn.click(create_captions_rich, inputs=[input_imgs], outputs=[outputs])
68
 
69
  demo.launch(debug=True)
 
3
  import spaces
4
  import torch
5
  import re
6
+ from PIL import Image
7
 
8
  model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cpu").eval()
9
  processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
 
29
 
30
  return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
31
 
32
+ def create_captions_rich(files):
33
  captions = []
34
  prompt = "caption en"
35
 
36
+ for file_path in files:
37
+ try:
38
+ image = Image.open(file_path.name)
39
+ except Exception as e:
40
+ captions.append(f"Error opening image: {e}")
41
+ continue
42
+
43
  model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
44
  input_len = model_inputs["input_ids"].shape[-1]
45
 
46
+ try:
47
+ with torch.no_grad():
48
+ generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
49
+ generation = generation[0][input_len:]
50
+ decoded = processor.decode(generation, skip_special_tokens=True)
51
+ modified_caption = modify_caption(decoded)
52
+ captions.append(modified_caption)
53
+ except Exception as e:
54
+ captions.append(f"Error generating caption: {e}")
55
 
56
+ return "\n".join(captions)
57
 
58
  css = """
59
  #mkd {
 
69
  with gr.Tab(label="Image to Prompt for SD3"):
70
  with gr.Row():
71
  with gr.Column():
72
+ input_files = gr.Files(label="Input Images")
73
  submit_btn = gr.Button(value="Start")
74
  outputs = gr.Textbox(label="Prompts", lines=10, interactive=False)
75
 
76
+ submit_btn.click(create_captions_rich, inputs=[input_files], outputs=[outputs])
77
 
78
  demo.launch(debug=True)