|
--- |
|
{} |
|
--- |
|
Small dummy Llama 2-type model usable for unit/integration tests. Suitable for CPU-only machines; see [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio/blob/main/tests/integration/test_integration.py) for an example integration test. |
|
|
|
Model was created as follows: |
|
```python |
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Checkpoint whose config and tokenizer are used as the template, and the
# Hub repository that receives the shrunken model.
model_name = "h2oai/h2ogpt-4096-llama2-7b-chat"
repo_name = "MaxJeblick/llama2-0b-unit-test"

# Start from the template config and shrink every size-related field so the
# resulting model stays tiny.
config = AutoConfig.from_pretrained(model_name)
tiny_overrides = {
    "hidden_size": 12,
    "max_position_embeddings": 1024,
    "intermediate_size": 24,
    "num_attention_heads": 2,
    "num_hidden_layers": 2,
    "num_key_value_heads": 2,
}
for attribute, value in tiny_overrides.items():
    setattr(config, attribute, value)

# The tokenizer is taken unchanged from the template checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build the model from the modified config and report its size.
model = AutoModelForCausalLM.from_config(config)
print(model.num_parameters())  # 770_940

# Publish model, tokenizer and config to the same public repository.
for artifact in (model, tokenizer, config):
    artifact.push_to_hub(repo_name, private=False)
|
``` |
|
|
|
Below is a small example that runs in about 1 second. |
|
|
|
```python |
|
import torch |
|
from transformers import AutoModelForCausalLM |
|
|
|
|
|
def test_manual_greedy_generate():
    """Check that ``model.generate`` agrees with a hand-written greedy loop.

    The model's own dummy inputs are used as the prompt. ``generate`` is asked
    for ``max_new_tokens`` additional tokens, and the same number of tokens is
    then produced manually by repeatedly taking the argmax of the logits at
    the last position. Both token sequences must be identical.
    """
    max_new_tokens = 10

    # note this is on CPU!
    model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
    input_ids = model.dummy_inputs["input_ids"]
    # Read the prompt dimensions from the inputs instead of hard-coding them.
    batch_size, prompt_length = input_ids.shape

    # Request greedy decoding explicitly so the comparison below does not
    # depend on sampling defaults stored in the model's generation config.
    y = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)

    assert y.shape == (batch_size, prompt_length + max_new_tokens)

    # Manual greedy decoding: append the most likely next token at each step.
    # Gradients are never needed, so disable them once around the whole loop.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    # Token ids are integers: require an exact match, not a tolerance-based one.
    assert torch.equal(y, input_ids)
|
``` |
|
|
|
Tip: |
|
|
|
Use fixtures with session scope to load the model only once. This will decrease test runtime further. |
|
|
|
```python |
|
import pytest |
|
from transformers import AutoModelForCausalLM |
|
@pytest.fixture(scope="session")
def model():
    """Return the tiny test model in eval mode.

    The fixture is session-scoped, so the model is loaded once and shared by
    every test that requests it.
    """
    loaded = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test")
    return loaded.eval()
|
``` |
|
|