myn0908 commited on
Commit
204fcce
1 Parent(s): a2fc6fa

fixing prompt code

Browse files
S2I/commons/controller.py CHANGED
@@ -56,7 +56,7 @@ class Sketch2ImageController():
56
 
57
  def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag, prompt_quality):
58
  self.load_pipeline(zero_options=options)
59
- prompt = prompt_template.replace("{prompt}", prompt)
60
 
61
  # if type_flag == 'live-sketch':
62
  # img = Image.fromarray(np.array(image["composite"])[:, :, -1])
@@ -79,7 +79,7 @@ class Sketch2ImageController():
79
  noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
80
 
81
  with torch.no_grad():
82
- output_image = self.pipe.generate(c_t, prompt, prompt_quality, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
83
 
84
  output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
85
 
 
56
 
57
  def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag, prompt_quality):
58
  self.load_pipeline(zero_options=options)
59
+ # prompt = prompt_template.replace("{prompt}", prompt)
60
 
61
  # if type_flag == 'live-sketch':
62
  # img = Image.fromarray(np.array(image["composite"])[:, :, -1])
 
79
  noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
80
 
81
  with torch.no_grad():
82
+ output_image = self.pipe.generate(c_t, prompt, prompt_quality, prompt_template, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)
83
 
84
  output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
85
 
S2I/modules/sketch2image.py CHANGED
@@ -13,9 +13,10 @@ class Sketch2ImagePipeline(PrimaryModel):
13
  super().__init__()
14
  self.timestep = torch.tensor([999], device="cuda").long()
15
 
16
- def generate(self, c_t, prompt=None, prompt_quality=None, prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
17
  self.from_pretrained(model_name=model_name, r=r)
18
  prompt_enhanced = self.automatic_enhance_prompt(prompt, prompt_quality)
 
19
  assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
20
 
21
  if half_model == 'float16':
 
13
  super().__init__()
14
  self.timestep = torch.tensor([999], device="cuda").long()
15
 
16
+ def generate(self, c_t, prompt=None, prompt_quality=None, prompt_template=None, prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
17
  self.from_pretrained(model_name=model_name, r=r)
18
  prompt_enhanced = self.automatic_enhance_prompt(prompt, prompt_quality)
19
+ prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)
20
  assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
21
 
22
  if half_model == 'float16':