myn0908 committed
Commit 8811405
1 Parent(s): 136e8a5

enhance prompt module
S2I/commons/controller.py CHANGED
@@ -47,31 +47,27 @@ class Sketch2ImageController():
         self.pipe = Sketch2ImagePipeline()
         self.zero_options = zero_options
 
-    def update_canvas(self, use_line, use_eraser):
-        brush_size = 20 if use_eraser else 4
-        _color = "#ffffff" if use_eraser else "#000000"
-        return self.gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
-
-    def upload_sketch(self, file):
-        _img = Image.open(file.name).convert("L")
-        return self.gr.update(value=_img, source="upload", interactive=True)
-
     @staticmethod
     def pil_image_to_data_uri(img, format="PNG"):
         buffered = BytesIO()
         img.save(buffered, format=format)
         img_str = base64.b64encode(buffered.getvalue()).decode()
         return f"data:image/{format.lower()};base64,{img_str}"
 
-    def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag):
+    def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag, prompt_quality):
         self.load_pipeline(zero_options=options)
-        prompt = prompt_template.replace("{prompt}", prompt)
+        prompt_enhanced = self.automatic_enhance_prompt(prompt, prompt_quality)
+        prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)
 
-        if type_flag == 'live-sketch':
-            img = Image.fromarray(np.array(image["composite"])[:, :, -1])
-        elif type_flag == 'url-sketch':
+        # if type_flag == 'live-sketch':
+        #     img = Image.fromarray(np.array(image["composite"])[:, :, -1])
+        # elif type_flag == 'url-sketch':
+        #     img = image["composite"]
+
+        if type_flag == 'URL':
             img = image["composite"]
+        else:
+            img = Image.fromarray(np.array(image["composite"])[:, :, -1])
 
         img = img.convert("RGB")
         img = img.resize((512, 512))
@@ -84,14 +80,13 @@ class Sketch2ImageController():
         noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
 
         with torch.no_grad():
-            output_image = self.pipe.generate(c_t, prompt, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
+            output_image = self.pipe.generate(c_t, prompt_enhanced, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
 
         output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
 
-        if type_flag == 'live-sketch':
-            input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
-        else:
-            input_uri = self.pil_image_to_data_uri(img)
+        # if type_flag == 'live-sketch':
+        #     input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
+        # else:
+        #     input_uri = self.pil_image_to_data_uri(img)
 
-        return output_pil
-        # , self.gr.update(link=input_uri)
+        return output_pil
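
For readers following the type_flag change above: a minimal sketch of the two input paths the rewritten branch distinguishes, assuming the gradio canvas's "composite" value is an RGBA PIL image with the live strokes in its alpha channel (names below are illustrative, not from the repo):

    import numpy as np
    from PIL import Image

    composite = Image.new("RGBA", (512, 512))       # stand-in for image["composite"]
    alpha = np.array(composite)[:, :, -1]           # live strokes read from the alpha channel
    sketch = Image.fromarray(alpha)                 # grayscale sketch, as in the 'else' branch
    rgb = sketch.convert("RGB").resize((512, 512))  # normalization shared by both branches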
 
S2I/modules/models.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import copy
 import os
 from diffusers import DDPMScheduler
-from transformers import AutoTokenizer, CLIPTextModel
+from transformers import AutoTokenizer, CLIPTextModel, pipeline
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from peft import LoraConfig
 from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home
@@ -29,6 +29,8 @@ class PrimaryModel:
         self.global_tokenizer = None
         self.global_text_encoder = None
         self.global_scheduler = None
+        self.global_medium_prompt = None
+        self.global_long_prompt = None
 
     @staticmethod
     def _load_model(path, model_class, unet_mode=False):
@@ -62,9 +64,14 @@ class PrimaryModel:
         sd = torch.load(p_ckpt, map_location="cpu")
         return sd
     def from_pretrained(self, model_name, r):
+
+        if self.global_medium_prompt is None:
+            self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device='cuda')
+
+        if self.global_long_prompt is None:
+            self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device='cuda')
+
         if self.global_tokenizer is None:
-            # self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
-            #                                                       subfolder="tokenizer")
             self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
 
         if self.global_text_encoder is None:
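
The two new pipeline(...) calls are ordinary transformers summarization pipelines, loaded once and cached on the instance so repeated from_pretrained calls do not reload them. A minimal sketch of the same lazy-initialization pattern, assuming CUDA is available (only the task and model id come from the diff; the enhance() wrapper and module-level cache are illustrative):

    from transformers import pipeline

    _enhancer = None  # module-level cache, mirroring self.global_medium_prompt

    def enhance(prompt):
        global _enhancer
        if _enhancer is None:  # load the summarization model only on first use
            _enhancer = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device="cuda")
        result = _enhancer("Enhance the description: " + prompt)
        return result[0]["summary_text"]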
S2I/modules/sketch2image.py CHANGED
@@ -72,6 +72,25 @@ class Sketch2ImagePipeline(PrimaryModel):
         self.global_unet.set_adapters(["default"], weights=[r])
         set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
 
+    def automatic_enhance_prompt(self, input_prompt, model_choice):
+
+        if model_choice == "short-sentences":
+            result = self.global_medium_prompt("Enhance the description: " + input_prompt)
+            enhanced_text = result[0]['summary_text']
+
+            pattern = r'^.*?of\s+(.*?(?:\.|$))'
+            match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
+
+            if match:
+                remaining_text = enhanced_text[match.end():].strip()
+                modified_sentence = match.group(1).capitalize()
+                enhanced_text = modified_sentence + ' ' + remaining_text
+        else:
+            result = self.global_long_prompt("Enhance the description: " + input_prompt)
+            enhanced_text = result[0]['summary_text']
+
+        return enhanced_text
+
     def _move_to_cpu(self, module):
         module.to("cpu")
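
The short-sentences branch post-processes the summarizer's output with a regex that strips a leading "... of" preamble. A worked example on a hypothetical model output (the pattern and flags are copied from the diff; note the method relies on re being imported at the top of sketch2image.py, which this hunk does not show):

    import re

    enhanced_text = "Enhanced description of a cat sitting on a mat. It is fluffy."
    pattern = r'^.*?of\s+(.*?(?:\.|$))'
    match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
    if match:
        remaining_text = enhanced_text[match.end():].strip()  # 'It is fluffy.'
        modified_sentence = match.group(1).capitalize()       # 'A cat sitting on a mat.'
        enhanced_text = modified_sentence + ' ' + remaining_text
    print(enhanced_text)  # A cat sitting on a mat. It is fluffy.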
 
app.py CHANGED
@@ -118,7 +118,7 @@ def get_meta_from_image(input_img, type_image):
     # Convert the processed image back to PIL Image
     img_pil = Image.fromarray(processed_img.astype('uint8'))
 
-    return img_pil
+    return img_pil, 'URL'
 
 
 with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
@@ -267,10 +267,20 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
             clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
         with gr.Accordion("S2I Advances Option", open=True):
             with gr.Row():
-                input_type = gr.Radio(
-                    choices=["live-sketch", "url-sketch"],
-                    value="live-sketch",
-                    label="Type Sketch2Image models",
+                # input_type = gr.Radio(
+                #     choices=["live-sketch", "url-sketch"],
+                #     value="live-sketch",
+                #     label="Type Sketch2Image models",
+                #     interactive=True)
+
+                input_type = gr.Textbox(
+                    label="Check URL or Real-time Input",
+                    interactive=True)
+
+                prompt_quality = gr.Radio(
+                    choices=["short-sentences", "long-sentences"],
+                    value="short-sentences",
+                    label="Long/Short of Text Prompt",
                     interactive=True)
 
                 style = gr.Dropdown(
@@ -307,7 +317,7 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
         queue=False,
         api_name=False,
     )
-    inputs = [zero_gpu_options, image, prompt, prompt_temp, style, seed, val_r, half_model, model_options, input_type]
+    inputs = [zero_gpu_options, image, prompt, prompt_temp, style, seed, val_r, half_model, model_options, input_type, prompt_quality]
     outputs = [result]
     prompt.submit(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
@@ -328,8 +338,8 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
     val_r.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)
     run_button.click(fn=assign_gpu, inputs=inputs, outputs=outputs, api_name=False)
     image.change(assign_gpu, inputs=inputs, outputs=outputs, queue=False, api_name=False)
-    url_image.submit(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image])
-    url_image.change(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image])
+    url_image.submit(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image, input_type])
+    url_image.change(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image, input_type])
 if __name__ == '__main__':
     demo.queue()
     demo.launch(debug=True, share=False)
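
The get_meta_from_image change works because a gradio event handler that returns a tuple distributes the values across the outputs list in order: here the second value, 'URL', lands in the input_type Textbox, which artwork() later receives as type_flag. A minimal self-contained sketch of the same wiring (the components and the placeholder fetcher below are illustrative, not the repo's):

    import gradio as gr
    from PIL import Image

    def get_meta_from_image(url, type_image):
        img_pil = Image.new("RGB", (512, 512), "white")  # stand-in for the fetched sketch
        return img_pil, 'URL'                            # second value fills the input_type Textbox

    with gr.Blocks() as demo:
        url_image = gr.Textbox(label="Image URL")
        type_image = gr.Textbox(value="RGB", visible=False)
        image = gr.Image(label="Sketch")
        input_type = gr.Textbox(label="Check URL or Real-time Input", interactive=True)
        url_image.submit(fn=get_meta_from_image, inputs=[url_image, type_image], outputs=[image, input_type])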