realantonvoronov commited on
Commit
8b9fcd0
1 Parent(s): bfb7c0b

move torch.Generator from switti init to pipeline call

Browse files
Files changed (3) hide show
  1. app.py +0 -1
  2. models/pipeline.py +1 -2
  3. models/switti.py +0 -1
app.py CHANGED
@@ -7,7 +7,6 @@ from models import SwittiPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- print("Device:", device)
11
  model_repo_id = "yresearch/Switti"
12
 
13
 
 
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
10
  model_repo_id = "yresearch/Switti"
11
 
12
 
models/pipeline.py CHANGED
@@ -113,8 +113,7 @@ class SwittiPipeline:
113
  if seed is None:
114
  rng = None
115
  else:
116
- switti.rng.manual_seed(seed)
117
- rng = switti.rng
118
 
119
  context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)
120
 
 
113
  if seed is None:
114
  rng = None
115
  else:
116
+ rng = torch.Generator(self.device).manual_seed(seed)
 
117
 
118
  context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)
119
 
models/switti.py CHANGED
@@ -72,7 +72,6 @@ class Switti(nn.Module):
72
  self.rope = rope
73
 
74
  self.num_stages_minus_1 = len(self.patch_nums) - 1
75
- self.rng = torch.Generator(device=device)
76
 
77
  # 1. input (word) embedding
78
  self.word_embed = nn.Linear(self.Cvae, self.C)
 
72
  self.rope = rope
73
 
74
  self.num_stages_minus_1 = len(self.patch_nums) - 1
 
75
 
76
  # 1. input (word) embedding
77
  self.word_embed = nn.Linear(self.Cvae, self.C)