"""Tests for ``create_pipeline`` and ``on_form_submit`` from ``hanlde_form_submit``."""
import pytest

from grouped_sampling import GroupedSamplingPipeLine

# NOTE(review): AVAILABLE_MODELS is not referenced in this module; the import is
# kept in case loading ``available_models`` has side effects — remove if not.
from available_models import AVAILABLE_MODELS  # noqa: F401
from hanlde_form_submit import create_pipeline, on_form_submit

# Model used by every test in this module.
MODEL_NAME = "gpt2"


def test_on_form_submit():
    """A non-empty prompt yields a non-empty output."""
    output_length = 10
    prompt = "Answer yes or no, is the sky blue?"
    output = on_form_submit(MODEL_NAME, output_length, prompt)
    assert output is not None
    assert len(output) > 0


def test_on_form_submit_empty_prompt():
    """An empty prompt is rejected with ``ValueError``.

    Kept separate from the happy-path test so each behaviour is reported
    independently when one of them fails.
    """
    output_length = 10
    with pytest.raises(ValueError):
        on_form_submit(MODEL_NAME, output_length, "")


def test_create_pipeline():
    """``create_pipeline`` builds a pipeline for the requested model.

    The pipeline is expected to record the model name and to have
    ``wrapped_model.end_of_sentence_stop`` disabled.
    """
    pipeline: GroupedSamplingPipeLine = create_pipeline(MODEL_NAME)
    assert pipeline is not None
    assert pipeline.model_name == MODEL_NAME
    assert pipeline.wrapped_model.end_of_sentence_stop is False


if __name__ == "__main__":
    # Run only this file, and propagate pytest's exit code so failures are
    # visible to the shell / CI when the module is executed directly.
    raise SystemExit(pytest.main([__file__]))