Spaces:
Runtime error
Runtime error
yupeng.zhou
commited on
Commit
•
14f69c4
1
Parent(s):
a5df616
fix
Browse files
app.py
CHANGED
@@ -110,7 +110,7 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
|
|
110 |
encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
|
111 |
# 判断随机数是否大于0.5
|
112 |
if cur_step <5:
|
113 |
-
hidden_states = self.__call2__(attn, hidden_states,
|
114 |
else: # 256 1024 4096
|
115 |
random_number = random.random()
|
116 |
if cur_step <20:
|
|
|
110 |
encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
|
111 |
# 判断随机数是否大于0.5
|
112 |
if cur_step <5:
|
113 |
+
hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
|
114 |
else: # 256 1024 4096
|
115 |
random_number = random.random()
|
116 |
if cur_step <20:
|