| import pytest as pytest | |
| from grouped_sampling import GroupedSamplingPipeLine | |
| from available_models import AVAILABLE_MODELS | |
| from hanlde_form_submit import create_pipeline, on_form_submit | |
def test_on_form_submit():
    """Exercise on_form_submit with a regular prompt and with an empty one.

    The regular prompt must produce a result that is not None and has a
    non-zero length; the empty prompt must raise ValueError.
    """
    model = "gpt2"
    length = 10

    # Happy path: a real question has to come back with some output.
    question = "Answer yes or no, is the sky blue?"
    result = on_form_submit(model, length, question)
    assert result is not None
    assert len(result) > 0

    # Error path: an empty prompt is expected to be rejected.
    with pytest.raises(ValueError):
        on_form_submit(model, length, "")
def test_create_pipeline():
    """Check the attributes of the pipeline create_pipeline builds for "gpt2".

    The returned object must exist, report the requested model name, and
    have end_of_sentence_stop set to False on its wrapped model.
    """
    expected_name = "gpt2"
    pipe: GroupedSamplingPipeLine = create_pipeline(expected_name)

    assert pipe is not None
    assert pipe.model_name == expected_name
    assert pipe.wrapped_model.end_of_sentence_stop is False

    # Drop the local reference to the pipeline once the checks are done.
    del pipe
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of raising SystemExit, so
    # hand it to SystemExit explicitly; otherwise running this file directly
    # would exit with status 0 even when tests fail.
    raise SystemExit(pytest.main())