Fix OOM
Browse files- models/vallex.py +1 -1
models/vallex.py
CHANGED
@@ -676,7 +676,7 @@ class VALLE(VALLF):
|
|
676 |
y_emb[:, prefix_len:] += embedding_layer(samples)
|
677 |
|
678 |
assert len(codes) == self.num_quantizers
|
679 |
-
del text_language_id, prompt_language_id, y_emb, x, y_pos, xy_pos, xy_dec, logits, samples
|
680 |
gc.collect()
|
681 |
return torch.stack(codes, dim=-1)
|
682 |
|
|
|
676 |
y_emb[:, prefix_len:] += embedding_layer(samples)
|
677 |
|
678 |
assert len(codes) == self.num_quantizers
|
679 |
+
del text_language_id, prompt_language_id, y_emb, x, y_pos, xy_pos, xy_dec, logits, samples, kv_cache, x_attn_mask, y_attn_mask, xy_attn_mask
|
680 |
gc.collect()
|
681 |
return torch.stack(codes, dim=-1)
|
682 |
|