File size: 3,475 Bytes
5caedb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import tempfile
from unittest.mock import MagicMock, patch

import pytest
import torch

from llm_studio.src.utils.modeling_utils import (
    check_disk_space,
    load_checkpoint,
    save_checkpoint,
    unwrap_model,
)


def test_unwrap_model():
    """Check that ``unwrap_model`` hands back the module inside DataParallel."""
    # Create a dummy model
    model = torch.nn.Linear(10, 10)

    # Wrap it in DataParallel; the wrapper has to be a distinct object.
    wrapped_model = torch.nn.DataParallel(model)
    # Identity checks (`is`) are used on purpose: the test is about *which*
    # object is returned, not about any notion of module equality.
    assert wrapped_model is not model
    assert isinstance(wrapped_model, torch.nn.DataParallel)

    # Unwrapping must return the very same module instance that was wrapped.
    unwrapped = unwrap_model(wrapped_model)
    assert unwrapped is model
    assert not isinstance(unwrapped, torch.nn.DataParallel)


@pytest.mark.parametrize(
    "dtype",
    [
        torch.float32,
        torch.float16,
        torch.bfloat16,
        torch.int8,
        torch.uint8,
        torch.int16,
    ],
)
@pytest.mark.parametrize(
    "free_space,should_raise",
    [
        (1e12, False),  # Plenty of space
        (1, True),  # Not enough space
    ],
)
def test_check_disk_space(free_space, dtype, should_raise):
    """Run ``check_disk_space`` against a stubbed ``shutil.disk_usage``.

    The mocked model exposes a single 1000x1000 tensor of the given ``dtype``
    as its parameters. A ``ValueError`` is expected exactly when
    ``should_raise`` is true.
    """
    fake_model = MagicMock()
    fake_model.parameters.return_value = [torch.ones(1000, 1000, dtype=dtype)]

    # Only the third element of the stubbed tuple varies between cases.
    stubbed_usage = (0, 0, free_space)
    with patch("shutil.disk_usage", return_value=stubbed_usage):
        if not should_raise:
            check_disk_space(fake_model, "/dummy/path")  # Should not raise
            return
        with pytest.raises(ValueError):
            check_disk_space(fake_model, "/dummy/path")


class DummyModel(torch.nn.Module):
    """Tiny module with a linear ``backbone`` and an optional linear head.

    Args:
        use_classification_head: When true, a ``classification_head`` linear
            layer is registered in addition to ``backbone``.
    """

    def __init__(self, use_classification_head=True):
        super().__init__()
        self.backbone = torch.nn.Linear(10, 10)
        if not use_classification_head:
            return
        self.classification_head = torch.nn.Linear(10, 10)


@pytest.mark.parametrize("use_classification_head", [True, False])
def test_save_checkpoint(use_classification_head):
    """Check which files ``save_checkpoint`` leaves in the target directory.

    ``checkpoint.pth`` must always exist; ``classification_head.pth`` must
    exist exactly when the model was built with a classification head.
    """
    cfg = MagicMock()
    cfg.training.lora = False
    cfg.environment._local_rank = 0
    cfg.environment.use_deepspeed = False
    model = DummyModel(use_classification_head)

    with tempfile.TemporaryDirectory() as out_dir:
        save_checkpoint(model, out_dir, cfg)

        checkpoint_file = os.path.join(out_dir, "checkpoint.pth")
        head_file = os.path.join(out_dir, "classification_head.pth")

        assert os.path.exists(checkpoint_file)
        # Presence of the head file has to mirror the parametrized flag.
        assert os.path.exists(head_file) == use_classification_head


def test_load_checkpoint():
    """Smoke test: strict loading of a matching state dict must not raise."""
    cfg = MagicMock()
    cfg.environment._local_rank = 0
    cfg.environment.use_deepspeed = False
    cfg.architecture.pretrained_weights = "dummy_weights.pth"
    model = DummyModel()

    # torch.load is replaced by a stub that returns the model's own weights.
    stub_checkpoint = {"model": model.state_dict()}
    with patch("torch.load", return_value=stub_checkpoint):
        load_checkpoint(cfg, model, strict=True)


def test_load_checkpoint_mismatch():
    """Load head-less weights into a model that has a classification head.

    With ``strict=True`` a ``RuntimeError`` is expected; with ``strict=False``
    the same call is expected to complete without raising.
    """
    target_model = DummyModel(use_classification_head=True)
    headless_model = DummyModel(use_classification_head=False)

    cfg = MagicMock()
    cfg.environment._local_rank = 0
    cfg.environment.use_deepspeed = False
    cfg.architecture.pretrained_weights = "dummy_weights.pth"

    # torch.load is stubbed to return weights that lack the head's entries.
    stub_checkpoint = {"model": headless_model.state_dict()}
    with patch("torch.load", return_value=stub_checkpoint):
        with pytest.raises(RuntimeError):
            load_checkpoint(cfg, target_model, strict=True)
        load_checkpoint(cfg, target_model, strict=False)