fix model parallel (#816)
Browse files
src/axolotl/utils/models.py
CHANGED
@@ -442,14 +442,7 @@ def load_model(
|
|
442 |
if cfg.ddp and not load_in_8bit:
|
443 |
model.to(f"cuda:{cfg.local_rank}")
|
444 |
|
445 |
-
if (
|
446 |
-
torch.cuda.device_count() > 1
|
447 |
-
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
448 |
-
and (cfg.load_in_4bit)
|
449 |
-
):
|
450 |
-
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
451 |
-
# so let's only set it for the 4bit, see
|
452 |
-
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
453 |
setattr(model, "is_parallelizable", True)
|
454 |
setattr(model, "model_parallel", True)
|
455 |
|
|
|
442 |
if cfg.ddp and not load_in_8bit:
|
443 |
model.to(f"cuda:{cfg.local_rank}")
|
444 |
|
445 |
+
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
setattr(model, "is_parallelizable", True)
|
447 |
setattr(model, "model_parallel", True)
|
448 |
|