File size: 4,251 Bytes
e71a2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import pytest
import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *

from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


@pytest.mark.forked
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    config = model.config
    assert isinstance(model, DistributedBloomForCausalLM)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")

        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            for t in range(embs.shape[1]):
                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")

        del model, embs, recurrent_outputs

        if REF_NAME:
            ref_model = transformers.BloomForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
            if config.vocab_size < ref_model.config.vocab_size:
                ref_model.resize_token_embeddings(config.vocab_size)
                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
            # note: this creates a dummy mask to make the test compatible with older transformer versions
            # prior to https://github.com/huggingface/transformers/pull/17837
            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs, dummy_mask
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
            assert False


@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"

    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = BloomForCausalLM.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search are not identical to HF in multibatch mode"