ChenWu98 commited on
Commit
58ca927
·
1 Parent(s): 0f21f01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -59,7 +59,7 @@ class LocalBlend:
59
  for word in words_:
60
  ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
61
  alpha_layers[i, :, :, :, :, ind] = 1
62
- self.alpha_layers = alpha_layers.to(device)
63
  self.threshold = threshold
64
 
65
 
@@ -184,7 +184,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
184
  local_blend: Optional[LocalBlend]):
185
  super(AttentionControlEdit, self).__init__()
186
  self.batch_size = len(prompts)
187
- self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
188
  if type(self_replace_steps) is float:
189
  self_replace_steps = 0, self_replace_steps
190
  self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
@@ -199,7 +199,7 @@ class AttentionReplace(AttentionControlEdit):
199
  def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
200
  local_blend: Optional[LocalBlend] = None):
201
  super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
202
- self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
203
 
204
 
205
  class AttentionRefine(AttentionControlEdit):
@@ -213,7 +213,7 @@ class AttentionRefine(AttentionControlEdit):
213
  local_blend: Optional[LocalBlend] = None):
214
  super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
215
  self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
216
- self.mapper, alphas = self.mapper.to(device), alphas.to(device)
217
  self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
218
 
219
 
 
59
  for word in words_:
60
  ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
61
  alpha_layers[i, :, :, :, :, ind] = 1
62
+ self.alpha_layers = alpha_layers.to(device).to(torch_dtype)
63
  self.threshold = threshold
64
 
65
 
 
184
  local_blend: Optional[LocalBlend]):
185
  super(AttentionControlEdit, self).__init__()
186
  self.batch_size = len(prompts)
187
+ self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device).to(torch_dtype)
188
  if type(self_replace_steps) is float:
189
  self_replace_steps = 0, self_replace_steps
190
  self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
 
199
  def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
200
  local_blend: Optional[LocalBlend] = None):
201
  super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
202
+ self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device).to(torch_dtype)
203
 
204
 
205
  class AttentionRefine(AttentionControlEdit):
 
213
  local_blend: Optional[LocalBlend] = None):
214
  super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
215
  self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
216
+ self.mapper, alphas = self.mapper.to(device).to(torch_dtype), alphas.to(device).to(torch_dtype)
217
  self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
218
 
219