Tech-Meld commited on
Commit
8bd959a
1 Parent(s): 13bac9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -37
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import gradio as gr
2
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
3
- from PIL import Image
4
  import spaces
5
  import torch
6
  import re
7
 
8
- model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cpu").eval()
9
  processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
10
 
11
  def modify_caption(caption: str) -> str:
@@ -26,55 +25,46 @@ def modify_caption(caption: str) -> str:
26
 
27
  def replace_fn(match):
28
  return replacers[match.group(0)]
29
-
30
  return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
31
 
 
32
  def create_captions_rich(images):
33
- # Debugging: Print out the type of 'images'
34
- print(f"Type of 'images': {type(images)}")
35
- if isinstance(images, tuple):
36
- print("Received a tuple, expected a file-like object.")
37
- # If it's a tuple, you can try accessing the first element as an example
38
- print(f"Type of 'images[0]': {type(images[0])}")
39
-
40
  captions = []
41
- for image_path in images:
42
- try:
43
- # If 'images' is a tuple, you might need to modify this part to extract the image file correctly
44
- image = Image.open(image_path).convert("RGB")
45
- prompt = "caption en"
46
- model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
47
- input_len = model_inputs["input_ids"].shape[-1]
48
-
49
- with torch.inference_mode():
50
- generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
51
- generation = generation[0][input_len:]
52
- decoded = processor.decode(generation, skip_special_tokens=True)
53
 
54
- modified_caption = modify_caption(decoded)
55
- captions.append(modified_caption)
56
- except Exception as e:
57
- captions.append(f"Error processing image: {e}")
58
  return captions
59
 
60
-
61
  css = """
62
  #mkd {
63
  height: 500px;
64
  overflow: auto;
65
- border: 8px solid #ccc;
66
  }
67
  """
68
 
69
  with gr.Blocks(css=css) as demo:
70
- gr.HTML("<h1><center>Finetuned PaliGemma for SD3 prompt generation.<center><h1>")
71
- with gr.Tab(label="Image to Prompt for SD3"):
72
- with gr.Row():
73
- with gr.Column():
74
- input_img = gr.Gallery(label="Input Images", type="pil", interactive=True)
75
- submit_btn = gr.Button(value="Start")
76
- output = gr.Textbox(label="Prompt", lines=10, interactive=True)
 
77
 
78
- submit_btn.click(create_captions_rich, [input_img], [output])
79
 
80
- demo.launch(debug=True)
 
1
  import gradio as gr
2
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
 
3
  import spaces
4
  import torch
5
  import re
6
 
7
+ model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cuda").eval()
8
  processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
9
 
10
  def modify_caption(caption: str) -> str:
 
25
 
26
  def replace_fn(match):
27
  return replacers[match.group(0)]
28
+
29
  return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
30
 
31
+ @spaces.GPU
32
  def create_captions_rich(images):
 
 
 
 
 
 
 
33
  captions = []
34
+ prompt = "caption en"
35
+
36
+ for image in images:
37
+ model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
38
+ input_len = model_inputs["input_ids"].shape[-1]
39
+
40
+ with torch.inference_mode():
41
+ generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
42
+ generation = generation[0][input_len:]
43
+ decoded = processor.decode(generation, skip_special_tokens=True)
 
 
44
 
45
+ modified_caption = modify_caption(decoded)
46
+ captions.append(modified_caption)
47
+
 
48
  return captions
49
 
 
50
  css = """
51
  #mkd {
52
  height: 500px;
53
  overflow: auto;
54
+ border: 16px solid #ccc;
55
  }
56
  """
57
 
58
  with gr.Blocks(css=css) as demo:
59
+ gr.HTML("<h1><center>Fine-tuned PaliGemma for SD3 Image Guided Prompt Generation.<center><h1>")
60
+
61
+ with gr.Tab(label="Image to Prompt for SD3."):
62
+ with gr.Row():
63
+ with gr.Column():
64
+ input_imgs = gr.Image(label="Input Images", type="pil", tool="editor", interactive=True, multiple=True)
65
+ submit_btn = gr.Button(value="Start")
66
+ outputs = gr.Text(label="Prompts", interactive=False)
67
 
68
+ submit_btn.click(create_captions_rich, [input_imgs], [outputs])
69
 
70
+ demo.launch(debug=True)