# Unit tests for swarms.models.huggingface.HuggingFaceLLM (pytest + unittest.mock).
import pytest
import torch
from unittest.mock import Mock

from swarms.models.huggingface import HuggingFaceLLM
@pytest.fixture
def mock_torch():
    """Fixture: stand-in for the ``torch`` dependency of HuggingFaceLLM.

    Missing ``@pytest.fixture`` decorator added — without it, tests
    requesting ``mock_torch`` fail with "fixture not found".
    """
    return Mock()
@pytest.fixture
def mock_autotokenizer():
    """Fixture: stand-in for ``AutoTokenizer`` (its ``from_pretrained`` is
    asserted on by the tests).

    Missing ``@pytest.fixture`` decorator added — without it, tests
    requesting ``mock_autotokenizer`` fail with "fixture not found".
    """
    return Mock()
@pytest.fixture
def mock_automodelforcausallm():
    """Fixture: stand-in for ``AutoModelForCausalLM`` (its ``from_pretrained``
    is asserted on by the tests).

    Missing ``@pytest.fixture`` decorator added — without it, tests
    requesting ``mock_automodelforcausallm`` fail with "fixture not found".
    """
    return Mock()
@pytest.fixture
def mock_bitsandbytesconfig():
    """Fixture: stand-in for ``BitsAndBytesConfig`` used by the quantize path.

    Missing ``@pytest.fixture`` decorator added — without it, tests
    requesting ``mock_bitsandbytesconfig`` fail with "fixture not found".
    """
    return Mock()
@pytest.fixture
def hugging_face_llm(mock_torch, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig):
    """Fixture: a HuggingFaceLLM instance with heavy dependencies mocked out.

    Missing ``@pytest.fixture`` decorator added — without it, tests
    requesting ``hugging_face_llm`` fail with "fixture not found".

    NOTE(review): assigning mocks as *class attributes* only intercepts the
    real torch/transformers objects if HuggingFaceLLM resolves them via
    ``self``/class lookup; if the module uses its own module-level imports,
    these patches are no-ops — confirm against the implementation (consider
    ``unittest.mock.patch`` on the module instead).
    """
    HuggingFaceLLM.torch = mock_torch
    HuggingFaceLLM.AutoTokenizer = mock_autotokenizer
    HuggingFaceLLM.AutoModelForCausalLM = mock_automodelforcausallm
    HuggingFaceLLM.BitsAndBytesConfig = mock_bitsandbytesconfig
    return HuggingFaceLLM(model_id='test')
def test_init(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm): | |
assert hugging_face_llm.model_id == 'test' | |
mock_autotokenizer.from_pretrained.assert_called_once_with('test') | |
mock_automodelforcausallm.from_pretrained.assert_called_once_with('test', quantization_config=None) | |
def test_init_with_quantize(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig):
    """Quantized construction builds a BitsAndBytesConfig from the expected
    kwargs and forwards it to the model loader.

    Bug fix: the ``hugging_face_llm`` fixture has already constructed one
    (non-quantized) instance, so ``from_pretrained`` carries a prior call and
    ``assert_called_once_with`` would fail on the second construction. Reset
    the call records before the quantized construction.
    """
    quantization_config = {
        'load_in_4bit': True,
        'bnb_4bit_use_double_quant': True,
        'bnb_4bit_quant_type': "nf4",
        'bnb_4bit_compute_dtype': torch.bfloat16,
    }
    mock_bitsandbytesconfig.return_value = quantization_config

    # Drop call history from the fixture's earlier instantiation so the
    # "called once" assertions below check only the quantized construction.
    mock_autotokenizer.from_pretrained.reset_mock()
    mock_automodelforcausallm.from_pretrained.reset_mock()

    HuggingFaceLLM(model_id='test', quantize=True)

    mock_bitsandbytesconfig.assert_called_once_with(**quantization_config)
    mock_autotokenizer.from_pretrained.assert_called_once_with('test')
    mock_automodelforcausallm.from_pretrained.assert_called_once_with('test', quantization_config=quantization_config)
def test_generate_text(hugging_face_llm): | |
prompt_text = 'test prompt' | |
expected_output = 'test output' | |
hugging_face_llm.tokenizer.encode.return_value = torch.tensor([0]) # Mock tensor | |
hugging_face_llm.model.generate.return_value = torch.tensor([0]) # Mock tensor | |
hugging_face_llm.tokenizer.decode.return_value = expected_output | |
output = hugging_face_llm.generate_text(prompt_text) | |
assert output == expected_output | |