Spaces:
Running
on
Zero
Running
on
Zero
realantonvoronov
commited on
Commit
•
5487649
1
Parent(s):
760fde0
remove torch.cuda.is_available from inner scripts
Browse files- app.py +1 -0
- models/switti.py +4 -1
app.py
CHANGED
@@ -7,6 +7,7 @@ from models import SwittiPipeline
|
|
7 |
import torch
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
10 |
model_repo_id = "yresearch/Switti"
|
11 |
|
12 |
|
|
|
7 |
import torch
|
8 |
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
+
print("Device:", device)
|
11 |
model_repo_id = "yresearch/Switti"
|
12 |
|
13 |
|
models/switti.py
CHANGED
@@ -53,6 +53,7 @@ class Switti(nn.Module):
|
|
53 |
use_swiglu_ffn=True,
|
54 |
use_ar=False,
|
55 |
use_crop_cond=True,
|
|
|
56 |
):
|
57 |
super().__init__()
|
58 |
# 0. hyperparameters
|
@@ -71,7 +72,7 @@ class Switti(nn.Module):
|
|
71 |
self.rope = rope
|
72 |
|
73 |
self.num_stages_minus_1 = len(self.patch_nums) - 1
|
74 |
-
self.rng = torch.Generator(device=
|
75 |
|
76 |
# 1. input (word) embedding
|
77 |
self.word_embed = nn.Linear(self.Cvae, self.C)
|
@@ -392,6 +393,7 @@ class SwittiHF(Switti, PyTorchModelHubMixin):
|
|
392 |
use_swiglu_ffn=True,
|
393 |
use_ar=False,
|
394 |
use_crop_cond=True,
|
|
|
395 |
):
|
396 |
heads = depth
|
397 |
width = depth * 64
|
@@ -406,4 +408,5 @@ class SwittiHF(Switti, PyTorchModelHubMixin):
|
|
406 |
use_swiglu_ffn=use_swiglu_ffn,
|
407 |
use_ar=use_ar,
|
408 |
use_crop_cond=use_crop_cond,
|
|
|
409 |
)
|
|
|
53 |
use_swiglu_ffn=True,
|
54 |
use_ar=False,
|
55 |
use_crop_cond=True,
|
56 |
+
device='cuda',
|
57 |
):
|
58 |
super().__init__()
|
59 |
# 0. hyperparameters
|
|
|
72 |
self.rope = rope
|
73 |
|
74 |
self.num_stages_minus_1 = len(self.patch_nums) - 1
|
75 |
+
self.rng = torch.Generator(device=device)
|
76 |
|
77 |
# 1. input (word) embedding
|
78 |
self.word_embed = nn.Linear(self.Cvae, self.C)
|
|
|
393 |
use_swiglu_ffn=True,
|
394 |
use_ar=False,
|
395 |
use_crop_cond=True,
|
396 |
+
device='cuda',
|
397 |
):
|
398 |
heads = depth
|
399 |
width = depth * 64
|
|
|
408 |
use_swiglu_ffn=use_swiglu_ffn,
|
409 |
use_ar=use_ar,
|
410 |
use_crop_cond=use_crop_cond,
|
411 |
+
device=device,
|
412 |
)
|