Update generator.py
Browse files- generator.py +1 -0
generator.py
CHANGED
|
@@ -177,6 +177,7 @@ def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda") -> Generator:
|
|
| 177 |
model = Model(model_args).to(device=device, dtype=torch.bfloat16)
|
| 178 |
state_dict = torch.load(ckpt_path)
|
| 179 |
model.load_state_dict(state_dict)
|
|
|
|
| 180 |
|
| 181 |
generator = Generator(model)
|
| 182 |
return generator
|
|
|
|
| 177 |
model = Model(model_args).to(device=device, dtype=torch.bfloat16)
|
| 178 |
state_dict = torch.load(ckpt_path)
|
| 179 |
model.load_state_dict(state_dict)
|
| 180 |
+
model.decoder = torch.compile(model.decoder, fullgraph=True, mode='max-autotune')
|
| 181 |
|
| 182 |
generator = Generator(model)
|
| 183 |
return generator
|