myn0908 commited on
Commit
cc3415b
1 Parent(s): 204fcce

update enhance prompt

Browse files
S2I/commons/controller.py CHANGED
@@ -58,16 +58,11 @@ class Sketch2ImageController():
58
  self.load_pipeline(zero_options=options)
59
  # prompt = prompt_template.replace("{prompt}", prompt)
60
 
61
- # if type_flag == 'live-sketch':
62
- # img = Image.fromarray(np.array(image["composite"])[:, :, -1])
63
- # elif type_flag == 'url-sketch':
64
- # img = image["composite"]
65
-
66
- if type_flag == 'URL':
67
- img = image["composite"]
68
- else:
69
  img = Image.fromarray(np.array(image["composite"])[:, :, -1])
70
-
 
 
71
  img = img.convert("RGB")
72
  img = img.resize((512, 512))
73
 
@@ -83,9 +78,5 @@ class Sketch2ImageController():
83
 
84
  output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
85
 
86
- # if type_flag == 'live-sketch':
87
- # input_uri = self.pil_image_to_data_uri(Image.fromarray(255 - np.array(img)))
88
- # else:
89
- # input_uri = self.pil_image_to_data_uri(img)
90
 
91
  return output_pil
 
58
  self.load_pipeline(zero_options=options)
59
  # prompt = prompt_template.replace("{prompt}", prompt)
60
 
61
+ if type_flag == 'live-sketch':
 
 
 
 
 
 
 
62
  img = Image.fromarray(np.array(image["composite"])[:, :, -1])
63
+ elif type_flag == 'url-sketch':
64
+ img = image["composite"]
65
+
66
  img = img.convert("RGB")
67
  img = img.resize((512, 512))
68
 
 
78
 
79
  output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
80
 
 
 
 
 
81
 
82
  return output_pil
S2I/modules/models.py CHANGED
@@ -65,10 +65,10 @@ class PrimaryModel:
65
  return sd
66
  def from_pretrained(self, model_name, r):
67
  if self.global_medium_prompt is None:
68
- self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device='cuda')
69
 
70
  if self.global_long_prompt is None:
71
- self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device='cuda')
72
 
73
  if self.global_tokenizer is None:
74
  self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
 
65
  return sd
66
  def from_pretrained(self, model_name, r):
67
  if self.global_medium_prompt is None:
68
+ self.global_medium_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device='cuda' if torch.cuda.is_available() else 'cpu')
69
 
70
  if self.global_long_prompt is None:
71
+ self.global_long_prompt = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device='cuda' if torch.cuda.is_available() else 'cpu')
72
 
73
  if self.global_tokenizer is None:
74
  self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")
S2I/modules/sketch2image.py CHANGED
@@ -75,8 +75,8 @@ class Sketch2ImagePipeline(PrimaryModel):
75
  self.global_unet.set_adapters(["default"], weights=[r])
76
  set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
77
 
78
- def automatic_enhance_prompt(self, input_prompt, model_choice):
79
- if model_choice == "short-sentences":
80
  result = self.global_medium_prompt("Enhance the description: " + input_prompt)
81
  enhanced_text = result[0]['summary_text']
82
 
@@ -87,10 +87,6 @@ class Sketch2ImagePipeline(PrimaryModel):
87
  remaining_text = enhanced_text[match.end():].strip()
88
  modified_sentence = match.group(1).capitalize()
89
  enhanced_text = modified_sentence + ' ' + remaining_text
90
- else:
91
- result = self.global_long_prompt("Enhance the description: " + input_prompt)
92
- enhanced_text = result[0]['summary_text']
93
-
94
  return enhanced_text
95
 
96
  def _move_to_cpu(self, module):
 
75
  self.global_unet.set_adapters(["default"], weights=[r])
76
  set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])
77
 
78
+ def automatic_enhance_prompt(self, input_prompt, prompt_quality):
79
+ if prompt_quality:
80
  result = self.global_medium_prompt("Enhance the description: " + input_prompt)
81
  enhanced_text = result[0]['summary_text']
82
 
 
87
  remaining_text = enhanced_text[match.end():].strip()
88
  modified_sentence = match.group(1).capitalize()
89
  enhanced_text = modified_sentence + ' ' + remaining_text
 
 
 
 
90
  return enhanced_text
91
 
92
  def _move_to_cpu(self, module):
app.py CHANGED
@@ -260,6 +260,7 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
260
  show_download_button=True,
261
  )
262
  with gr.Group():
 
263
  prompt = gr.Textbox(label="Personalized Text", value="", show_label=True)
264
  with gr.Row():
265
  run_button = gr.Button("Generate 🪄", min_width=5, variant='primary')
@@ -267,22 +268,12 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
267
  clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
268
  with gr.Accordion("S2I Advances Option", open=True):
269
  with gr.Row():
270
- # input_type = gr.Radio(
271
- # choices=["live-sketch", "url-sketch"],
272
- # value="live-sketch",
273
- # label="Type Sketch2Image models",
274
- # interactive=True)
275
-
276
- input_type = gr.Textbox(
277
- label="Check URL or Real-time Input",
278
- interactive=True)
279
-
280
- prompt_quality = gr.Radio(
281
- choices=["short-sentences", "long-sentences"],
282
- value="short-sentences",
283
- label="Long/Short of Text Prompt",
284
  interactive=True)
285
-
286
  style = gr.Dropdown(
287
  label="Style",
288
  choices=controller.STYLE_NAMES,
 
260
  show_download_button=True,
261
  )
262
  with gr.Group():
263
+ use_enhancer = gr.Checkbox(label="Use Automatic Prompt High-Quality", value=False)
264
  prompt = gr.Textbox(label="Personalized Text", value="", show_label=True)
265
  with gr.Row():
266
  run_button = gr.Button("Generate 🪄", min_width=5, variant='primary')
 
268
  clear_button = gr.Button("Reset Sketch Session", min_width=10, variant='primary')
269
  with gr.Accordion("S2I Advances Option", open=True):
270
  with gr.Row():
271
+ input_type = gr.Radio(
272
+ choices=["live-sketch", "url-sketch"],
273
+ value="live-sketch",
274
+ label="Type Sketch2Image models",
 
 
 
 
 
 
 
 
 
 
275
  interactive=True)
276
+
277
  style = gr.Dropdown(
278
  label="Style",
279
  choices=controller.STYLE_NAMES,