realantonvoronov commited on
Commit
5487649
1 Parent(s): 760fde0

remove torch.cuda.is_available from inner scripts

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. models/switti.py +4 -1
app.py CHANGED
@@ -7,6 +7,7 @@ from models import SwittiPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
10
  model_repo_id = "yresearch/Switti"
11
 
12
 
 
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ print("Device:", device)
11
  model_repo_id = "yresearch/Switti"
12
 
13
 
models/switti.py CHANGED
@@ -53,6 +53,7 @@ class Switti(nn.Module):
53
  use_swiglu_ffn=True,
54
  use_ar=False,
55
  use_crop_cond=True,
 
56
  ):
57
  super().__init__()
58
  # 0. hyperparameters
@@ -71,7 +72,7 @@ class Switti(nn.Module):
71
  self.rope = rope
72
 
73
  self.num_stages_minus_1 = len(self.patch_nums) - 1
74
- self.rng = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
75
 
76
  # 1. input (word) embedding
77
  self.word_embed = nn.Linear(self.Cvae, self.C)
@@ -392,6 +393,7 @@ class SwittiHF(Switti, PyTorchModelHubMixin):
392
  use_swiglu_ffn=True,
393
  use_ar=False,
394
  use_crop_cond=True,
 
395
  ):
396
  heads = depth
397
  width = depth * 64
@@ -406,4 +408,5 @@ class SwittiHF(Switti, PyTorchModelHubMixin):
406
  use_swiglu_ffn=use_swiglu_ffn,
407
  use_ar=use_ar,
408
  use_crop_cond=use_crop_cond,
 
409
  )
 
53
  use_swiglu_ffn=True,
54
  use_ar=False,
55
  use_crop_cond=True,
56
+ device='cuda',
57
  ):
58
  super().__init__()
59
  # 0. hyperparameters
 
72
  self.rope = rope
73
 
74
  self.num_stages_minus_1 = len(self.patch_nums) - 1
75
+ self.rng = torch.Generator(device=device)
76
 
77
  # 1. input (word) embedding
78
  self.word_embed = nn.Linear(self.Cvae, self.C)
 
393
  use_swiglu_ffn=True,
394
  use_ar=False,
395
  use_crop_cond=True,
396
+ device='cuda',
397
  ):
398
  heads = depth
399
  width = depth * 64
 
408
  use_swiglu_ffn=use_swiglu_ffn,
409
  use_ar=use_ar,
410
  use_crop_cond=use_crop_cond,
411
+ device=device,
412
  )