import time

import pytest

from tests.utils import wrap_test_forked
from src.enums import source_prefix, source_postfix
from src.prompter import generate_prompt

example_data_point0 = dict(instruction="Summarize",
                           input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
                           output="Ducks eat and swim at the lake.")
example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
                           output="Einstein.")
example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
                           output="Einstein.")
example_data_points = [example_data_point0, example_data_point1, example_data_point2]


@wrap_test_forked
def test_train_prompt(prompt_type='instruct', data_point=0):
    example_data_point = example_data_points[data_point]
    return generate_prompt(example_data_point, prompt_type, '', False, False, False)


@wrap_test_forked
def test_test_prompt(prompt_type='instruct', data_point=0):
    example_data_point = example_data_points[data_point]
    # drop the label so the prompt is rendered as at inference time
    example_data_point.pop('output', None)
    return generate_prompt(example_data_point, prompt_type, '', False, False, False)


@wrap_test_forked
def test_test_prompt2(prompt_type='human_bot', data_point=0):
    example_data_point = example_data_points[data_point]
    example_data_point.pop('output', None)
    res = generate_prompt(example_data_point, prompt_type, '', False, False, False)
    print(res, flush=True)
    return res


# Expected multi-turn prompts for each prompt_type, for the shared 3-turn history below.
prompt_fastchat = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT: Hi!</s>USER: How are you? ASSISTANT: I'm good</s>USER: Go to the market? ASSISTANT:"""

prompt_humanbot = """<human>: Hello!\n<bot>: Hi!\n<human>: How are you?\n<bot>: I'm good\n<human>: Go to the market?\n<bot>:"""

prompt_prompt_answer = "<|prompt|>Hello!<|endoftext|><|answer|>Hi!<|endoftext|><|prompt|>How are you?<|endoftext|><|answer|>I'm good<|endoftext|><|prompt|>Go to the market?<|endoftext|><|answer|>"

prompt_prompt_answer_openllama = "<|prompt|>Hello!<|answer|>Hi!<|prompt|>How are you?<|answer|>I'm good<|prompt|>Go to the market?<|answer|>"

prompt_mpt_instruct = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction
Hello!

### Response
Hi!

### Instruction
How are you?

### Response
I'm good

### Instruction
Go to the market?

### Response
"""

prompt_mpt_chat = """<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.
<|im_end|><|im_start|>user
Hello!<|im_end|><|im_start|>assistant
Hi!<|im_end|><|im_start|>user
How are you?<|im_end|><|im_start|>assistant
I'm good<|im_end|><|im_start|>user
Go to the market?<|im_end|><|im_start|>assistant
"""
Assistant:""" @wrap_test_forked @pytest.mark.parametrize("prompt_type,expected", [ ('vicuna11', prompt_fastchat), ('human_bot', prompt_humanbot), ('prompt_answer', prompt_prompt_answer), ('prompt_answer_openllama', prompt_prompt_answer_openllama), ('mptinstruct', prompt_mpt_instruct), ('mptchat', prompt_mpt_chat), ('falcon', prompt_falcon), ] ) def test_prompt_with_context(prompt_type, expected): prompt_dict = None # not used unless prompt_type='custom' langchain_mode = 'Disabled' chat = True model_max_length = 2048 memory_restriction_level = 0 keep_sources_in_context1 = False iinput = '' stream_output = False debug = False from src.prompter import Prompter from src.gen import history_to_context t0 = time.time() history = [["Hello!", "Hi!"], ["How are you?", "I'm good"], ["Go to the market?", None] ] print("duration1: %s %s" % (prompt_type, time.time() - t0), flush=True) t0 = time.time() context = history_to_context(history, langchain_mode, prompt_type, prompt_dict, chat, model_max_length, memory_restriction_level, keep_sources_in_context1) print("duration2: %s %s" % (prompt_type, time.time() - t0), flush=True) t0 = time.time() instruction = history[-1][0] # get prompt prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output) print("duration3: %s %s" % (prompt_type, time.time() - t0), flush=True) t0 = time.time() data_point = dict(context=context, instruction=instruction, input=iinput) prompt = prompter.generate_prompt(data_point) print(prompt) print("duration4: %s %s" % (prompt_type, time.time() - t0), flush=True) assert prompt == expected assert prompt.find(source_prefix) == -1 prompt_fastchat1 = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Go to the market? ASSISTANT:""" prompt_humanbot1 = """: Go to the market?\n:""" prompt_prompt_answer1 = "<|prompt|>Go to the market?<|endoftext|><|answer|>" prompt_prompt_answer_openllama1 = "<|prompt|>Go to the market?<|answer|>" prompt_mpt_instruct1 = """Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction Go to the market? ### Response """ prompt_mpt_chat1 = """<|im_start|>system A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers. <|im_end|><|im_start|>user Go to the market?<|im_end|><|im_start|>assistant """ prompt_falcon1 = """User: Go to the market? Assistant:""" @pytest.mark.parametrize("prompt_type,expected", [ ('vicuna11', prompt_fastchat1), ('human_bot', prompt_humanbot1), ('prompt_answer', prompt_prompt_answer1), ('prompt_answer_openllama', prompt_prompt_answer_openllama1), ('mptinstruct', prompt_mpt_instruct1), ('mptchat', prompt_mpt_chat1), ('falcon', prompt_falcon1), ] ) @wrap_test_forked def test_prompt_with_no_context(prompt_type, expected): prompt_dict = None # not used unless prompt_type='custom' chat = True iinput = '' stream_output = False debug = False from src.prompter import Prompter context = '' instruction = "Go to the market?" 
    # get prompt
    prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
    data_point = dict(context=context, instruction=instruction, input=iinput)
    prompt = prompter.generate_prompt(data_point)
    print(prompt)
    assert prompt == expected
    assert prompt.find(source_prefix) == -1


@wrap_test_forked
def test_source():
    prompt = "Who are you?%s\nFOO\n%s" % (source_prefix, source_postfix)
    assert prompt.find(source_prefix) >= 0
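

# --- Hedged sketch below; not part of the original suite. ---
# The tests above pass prompt_dict=None because it is only consulted for
# prompt_type='custom'.  This sketch shows how a custom prompt could be
# exercised through the same Prompter API used above.  The key names inside
# prompt_dict are an assumption about the custom prompt schema in
# src.prompter, so the test is skipped until they are confirmed against the
# actual implementation.
@pytest.mark.skip(reason="illustrative sketch; prompt_dict key names are assumed, not confirmed")
@wrap_test_forked
def test_custom_prompt_sketch():
    # assumed schema: mirrors the human_bot style with explicit role strings
    prompt_dict = dict(promptA='', promptB='',
                       PreInstruct='<human>: ',
                       PreInput=None,
                       PreResponse='<bot>:',
                       terminate_response=['<human>:', '<bot>:'],
                       chat_sep='\n',
                       chat_turn_sep='\n',
                       humanstr='<human>:',
                       botstr='<bot>:')
    from src.prompter import Prompter
    prompter = Prompter('custom', prompt_dict, chat=True, stream_output=False)
    data_point = dict(context='', instruction="Go to the market?", input='')
    prompt = prompter.generate_prompt(data_point)
    print(prompt, flush=True)
    # weak assertion on purpose: only checks the instruction made it into the prompt
    assert "Go to the market?" in prompt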