Small dummy Llama 2-type model usable for unit and integration tests. It is suitable for CPU-only machines; see H2O LLM Studio for an example integration test.

The model was created as follows:

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

repo_name = "MaxJeblick/llama2-0b-unit-test"
model_name = "h2oai/h2ogpt-4096-llama2-7b-chat"
config = AutoConfig.from_pretrained(model_name)
config.hidden_size = 12
config.max_position_embeddings = 1024
config.intermediate_size = 24
config.num_attention_heads = 2
config.num_hidden_layers = 2
config.num_key_value_heads = 2

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_config(config)
print(model.num_parameters())  # 770_940

model.push_to_hub(repo_name, private=False)
tokenizer.push_to_hub(repo_name, private=False)
config.push_to_hub(repo_name, private=False)
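
The parameter count printed above can be cross-checked by hand. The sketch below is not part of the original creation script; it assumes the standard Llama 2 vocabulary size of 32,000 (not set explicitly in the config overrides), untied input and output embeddings, and no bias terms. The helper name `llama_param_count` is hypothetical.

```python
def llama_param_count(vocab_size, hidden_size, intermediate_size,
                      num_layers, num_heads, num_kv_heads):
    """Count trainable parameters of a Llama-style decoder.

    Assumes untied embeddings, bias-free linear layers, and RMSNorm
    layers with a single weight vector each.
    """
    head_dim = hidden_size // num_heads
    kv_dim = head_dim * num_kv_heads

    embed = vocab_size * hidden_size        # input token embeddings
    lm_head = vocab_size * hidden_size      # output projection (untied)

    # q_proj and o_proj are hidden x hidden; k_proj and v_proj are hidden x kv_dim
    attention = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # gate_proj, up_proj, down_proj
    mlp = 3 * hidden_size * intermediate_size
    # input and post-attention RMSNorm weights
    norms = 2 * hidden_size

    per_layer = attention + mlp + norms
    final_norm = hidden_size
    return embed + lm_head + num_layers * per_layer + final_norm


# vocab_size=32000 is an assumption; the other values come from the config above.
total = llama_param_count(
    vocab_size=32000,
    hidden_size=12,
    intermediate_size=24,
    num_layers=2,
    num_heads=2,
    num_kv_heads=2,
)
print(total)  # 770940
```

Almost all of the parameters (768,000 of 770,940) sit in the two embedding matrices, which is why the model stays tiny despite keeping the full tokenizer vocabulary.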

Below is a small example that runs in about 1 second.

import torch
from transformers import AutoModelForCausalLM


def test_manual_greedy_generate():
    max_new_tokens = 10

    # note this is on CPU!
    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
    input_ids = model.dummy_inputs["input_ids"]

    y = model.generate(input_ids, max_new_tokens=max_new_tokens)

    assert y.shape == (3, input_ids.shape[1] + max_new_tokens)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids)

        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    assert torch.allclose(y, input_ids)

Tip:

Use pytest fixtures with session scope to load the model only once. This decreases test runtime further.

import pytest
from transformers import AutoModelForCausalLM


@pytest.fixture(scope="session")
def model():
    return AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()