Spaces:
Running
on
Zero
Running
on
Zero
realantonvoronov
commited on
Commit
•
8b9fcd0
1
Parent(s):
bfb7c0b
move torch.Generator from switti init to pipeline call
Browse files- app.py +0 -1
- models/pipeline.py +1 -2
- models/switti.py +0 -1
app.py
CHANGED
@@ -7,7 +7,6 @@ from models import SwittiPipeline
|
|
7 |
import torch
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
-
print("Device:", device)
|
11 |
model_repo_id = "yresearch/Switti"
|
12 |
|
13 |
|
|
|
7 |
import torch
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
10 |
model_repo_id = "yresearch/Switti"
|
11 |
|
12 |
|
models/pipeline.py
CHANGED
@@ -113,8 +113,7 @@ class SwittiPipeline:
|
|
113 |
if seed is None:
|
114 |
rng = None
|
115 |
else:
|
116 |
-
|
117 |
-
rng = switti.rng
|
118 |
|
119 |
context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)
|
120 |
|
|
|
113 |
if seed is None:
|
114 |
rng = None
|
115 |
else:
|
116 |
+
rng = torch.Generator(self.device).manual_seed(seed)
|
|
|
117 |
|
118 |
context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)
|
119 |
|
models/switti.py
CHANGED
@@ -72,7 +72,6 @@ class Switti(nn.Module):
|
|
72 |
self.rope = rope
|
73 |
|
74 |
self.num_stages_minus_1 = len(self.patch_nums) - 1
|
75 |
-
self.rng = torch.Generator(device=device)
|
76 |
|
77 |
# 1. input (word) embedding
|
78 |
self.word_embed = nn.Linear(self.Cvae, self.C)
|
|
|
72 |
self.rope = rope
|
73 |
|
74 |
self.num_stages_minus_1 = len(self.patch_nums) - 1
|
|
|
75 |
|
76 |
# 1. input (word) embedding
|
77 |
self.word_embed = nn.Linear(self.Cvae, self.C)
|