ChenWu98 commited on
Commit
55bfc51
·
1 Parent(s): f9410ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -47,7 +47,7 @@ class LocalBlend:
47
  mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
48
  mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
49
  mask = mask.gt(self.threshold)
50
- mask = (mask[:1] + mask[1:]).float()
51
  x_t = x_t[:1] + mask * (x_t - x_t[:1])
52
  return x_t
53
 
@@ -221,7 +221,7 @@ def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: U
221
  if type(word_select) is int or type(word_select) is str:
222
  word_select = (word_select,)
223
  equalizer = torch.ones(len(values), 77)
224
- values = torch.tensor(values, dtype=torch.float32)
225
  for word in word_select:
226
  inds = ptp_utils.get_word_inds(text, word, tokenizer)
227
  equalizer[:, inds] = values
 
47
  mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
48
  mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
49
  mask = mask.gt(self.threshold)
50
+ mask = (mask[:1] + mask[1:]).to(x_t.dtype)
51
  x_t = x_t[:1] + mask * (x_t - x_t[:1])
52
  return x_t
53
 
 
221
  if type(word_select) is int or type(word_select) is str:
222
  word_select = (word_select,)
223
  equalizer = torch.ones(len(values), 77)
224
+ values = torch.tensor(values, dtype=torch_dtype)
225
  for word in word_select:
226
  inds = ptp_utils.get_word_inds(text, word, tokenizer)
227
  equalizer[:, inds] = values