|
from transformers import TextGenerationPipeline |
|
from transformers.pipelines.text_generation import ReturnType |
|
|
|
|
|
|
|
|
|
|
|
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """``TextGenerationPipeline`` that applies a ``Prompter`` template and prompt-specific stopping.

    Compared to the base class:

    * ``preprocess`` wraps the raw text with ``Prompter.generate_prompt`` (when a prompter is in use)
      and remembers the final prompt in ``self.prompt_text``;
    * ``_forward`` attaches ``get_stopping(...)`` stopping criteria when a prompter is in use;
    * ``postprocess`` reduces each generated text to the bot response via ``Prompter.get_response``.
    """

    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=True,
                 use_prompter=True, prompter=None, prompt_type=None,
                 max_input_tokens=2048 - 256, **kwargs):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)

        :param args: forwarded to ``TextGenerationPipeline.__init__``
        :param debug: passed to the ``Prompter`` built here (unused if ``prompter`` is given)
        :param chat: passed to the ``Prompter`` built here (unused if ``prompter`` is given)
        :param stream_output: passed to the ``Prompter`` built here (unused if ``prompter`` is given)
        :param sanitize_bot_response: forwarded to ``Prompter.get_response`` in ``postprocess``
        :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
                            If use_prompter, then will make prompter and use it.
                            If omitted while ``prompter`` is given, ``prompter.prompt_type`` is used.
        :param max_input_tokens: stored on ``self.max_input_tokens``; not read by this class itself
        :param kwargs: forwarded to ``TextGenerationPipeline.__init__``
        """
        super().__init__(*args, **kwargs)
        self.prompt_text = None
        self.use_prompter = use_prompter
        self.prompt_type = prompt_type
        self.prompter = prompter
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
                if self.prompt_type is None:
                    # _forward() passes self.prompt_type to get_stopping(); without this fallback a
                    # caller supplying only a prompter would silently get no stopping criteria.
                    self.prompt_type = self.prompter.prompt_type
            else:
                self.prompter = Prompter(self.prompt_type, debug=debug, chat=chat, stream_output=stream_output)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens

    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """Wrap ``prompt_text`` with the prompter template, then run the base preprocessing.

        The final prompt is stored on ``self.prompt_text`` so ``postprocess`` can strip it from the
        generated text. ``handle_long_generation`` defaults to ``'hole'`` (per the transformers docs,
        left-truncate the input so there is room to generate).
        """
        data_point = dict(context='', instruction=prompt_text, input='')
        if self.prompter is not None:
            prompt_text = self.prompter.generate_prompt(data_point)
        self.prompt_text = prompt_text
        if handle_long_generation is None:
            # forced; has no effect when no trimming is needed
            handle_long_generation = 'hole'
        return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                  **generate_kwargs)

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """Run the base postprocessing, then reduce each ``generated_text`` to the bot response.

        With a prompter, ``Prompter.get_response`` does the extraction using ``self.prompt_text``.
        Without one, text between ``self.bot`` and the next ``self.human`` is kept if both are set;
        otherwise the text is returned unchanged.
        """
        records = super().postprocess(model_outputs, return_type=return_type,
                                      clean_up_tokenization_spaces=clean_up_tokenization_spaces)
        for rec in records:
            if self.use_prompter:
                outputs = rec['generated_text']
                outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.human:
                outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
            else:
                outputs = rec['generated_text']
            rec['generated_text'] = outputs
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """Attach prompt-type stopping criteria (when ``self.can_stop``) and run generation."""
        if self.can_stop:
            stopping_criteria = get_stopping(self.prompt_type, self.tokenizer, self.device, human=self.human,
                                             bot=self.bot)
            generate_kwargs['stopping_criteria'] = stopping_criteria
        return self.__forward(model_inputs, **generate_kwargs)

    def __forward(self, model_inputs, **generate_kwargs):
        """Generation step; appears to mirror transformers' ``TextGenerationPipeline._forward``.

        :param model_inputs: dict with ``input_ids``, optional ``attention_mask`` and ``prompt_text``
                             (``prompt_text`` is popped from this dict)
        :return: dict with ``generated_sequence`` reshaped to (input batch, sequences per input, ...),
                 plus ``input_ids`` and ``prompt_text``
        :raises ValueError: if the framework is ``"tf"`` but TensorFlow is not available
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # With a prefix, extend the length limits by the prefix length unless "new tokens" limits are used.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not available.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
import torch |
|
from transformers import StoppingCriteria, StoppingCriteriaList |
|
|
|
|
|
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once a stop-token sequence has been seen a required number of times.

    Stop ``i`` needs ``encounters[i % len(encounters)]`` matches before generation stops.
    NOTE: only batch row 0 of ``input_ids`` is inspected.
    """

    def __init__(self, stops=None, encounters=None, device="cuda"):
        """
        :param stops: list of 1-D token-id tensors; each is moved to ``device``
        :param encounters: match counts, cycled over ``stops`` by index
        :param device: device to place the stop tensors on
        """
        super().__init__()
        # None defaults avoid shared mutable default lists.
        stops = [] if stops is None else list(stops)
        encounters = [] if encounters is None else list(encounters)
        # Check emptiness first so len(encounters) == 0 can never reach the modulo.
        assert not stops or (encounters and len(stops) % len(encounters) == 0), \
            "Number of stops must be a multiple of the number of encounters"
        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True once any stop sequence has ended the sequence often enough."""
        seq = input_ids[0]
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            # Empty stops would compare against the whole sequence via [-0:], and stops longer than the
            # sequence cannot match (and would raise a shape error).
            if stop_len == 0 or stop_len > len(seq):
                continue
            if torch.all(stop == seq[-stop_len:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False
|
|
|
|
|
def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
    """Build a ``StoppingCriteriaList`` that halts generation on the prompt type's turn markers.

    Only ``human_bot``, ``instruct_vicuna`` and ``instruct_with_end`` define stop words; any other
    prompt type gets an empty ``StoppingCriteriaList``.

    :param prompt_type: ``PromptType`` name, int value, or ``str(value)``
    :param tokenizer: tokenizer used to encode stop words (called with ``return_tensors='pt'``)
    :param device: device the stop-token tensors are placed on
    :param human: human turn marker used for ``human_bot``
    :param bot: bot turn marker used for ``human_bot``
    :return: ``StoppingCriteriaList`` (possibly empty)
    """
    # Normalize to the enum name so int / str(int) spellings behave like elsewhere in the module.
    for ptype in PromptType:
        if prompt_type in (ptype.value, str(ptype.value), ptype.name):
            prompt_type = ptype.name
            break

    if prompt_type == PromptType.human_bot.name:
        # encounters cycle over stop_words: human markers need 1 hit, bot markers need 2.
        stop_words = [human, bot, '\n' + human, '\n' + bot]
        encounters = [1, 2]
    elif prompt_type == PromptType.instruct_vicuna.name:
        stop_words = [
            '### Human:',
            '\n### Human:',
            '\n### Human:\n',
            '### Assistant:',
            '\n### Assistant:',
            '\n### Assistant:\n',
        ]
        # NOTE(review): cycling [1, 2] over six words gives the Human variants 1/2/1 and the
        # Assistant variants 2/1/2; preserved as-is -- confirm this is intended.
        encounters = [1, 2]
    elif prompt_type == PromptType.instruct_with_end.name:
        stop_words = ['### End']
        encounters = [1]
    else:
        return StoppingCriteriaList()

    # Keep each stop word paired with its encounter count while filtering, so that dropping an entry
    # cannot misalign stop words, token ids and counts.
    has_pad = tokenizer._pad_token  # hidden attribute, as used originally
    stop_words_ids = []
    stop_encounters = []
    for i, stop_word in enumerate(stop_words):
        ids = tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze()
        if len(ids.shape) == 0:
            # single-token stop word: squeeze() produced a 0-d tensor
            ids = ids.reshape(1)
        if ids.shape[0] == 0:
            continue
        if has_pad and ids[0] == tokenizer.pad_token_id and len(ids) > 1:
            # avoid a pad token in front of the stop tokens
            ids = ids[1:]
        if stop_word.startswith('\n'):
            # drop the token produced by the leading '\n' that was prepended to the stop word
            ids = ids[1:]
        if ids.shape[0] == 0:
            # an empty stop sequence would match against the entire generated sequence
            continue
        stop_words_ids.append(ids)
        stop_encounters.append(encounters[i % len(encounters)])

    if not stop_words_ids:
        return StoppingCriteriaList()
    return StoppingCriteriaList(
        [StoppingCriteriaSub(stops=stop_words_ids, encounters=stop_encounters, device=device)])
|
import time |
|
from enum import Enum |
|
|
|
# Model type identifiers listed separately from Hugging Face hub model names.
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']


class PromptType(Enum):
    """Prompt templates known to this module.

    Code in this module accepts a prompt type as the member name, its int value, or ``str(value)``.
    """

    # generic / instruction-style templates
    plain = 0
    instruct = 1
    quality = 2
    # chat templates with explicit turn markers
    human_bot = 3
    dai_faq = 4
    summarize = 5
    simple_instruct = 6
    instruct_vicuna = 7
    instruct_with_end = 8
    human_bot_orig = 9
    prompt_answer = 10
    open_assistant = 11
    wizard_lm = 12
    wizard_mega = 13
    instruct_vicuna2 = 14
    instruct_vicuna3 = 15
    wizard2 = 16
    wizard3 = 17
|
|
|
|
|
# Known Hugging Face model names (or local type ids) per prompt type name.
prompt_type_to_model_name = {
    'plain': [
        'EleutherAI/gpt-j-6B',
        'EleutherAI/pythia-6.9b',
        'EleutherAI/pythia-12b',
        'EleutherAI/pythia-12b-deduped',
        'EleutherAI/gpt-neox-20b',
        'openlm-research/open_llama_7b_700bt_preview',
        'decapoda-research/llama-7b-hf',
        'decapoda-research/llama-13b-hf',
        'decapoda-research/llama-30b-hf',
        'decapoda-research/llama-65b-hf',
        'facebook/mbart-large-50-many-to-many-mmt',
        'philschmid/bart-large-cnn-samsum',
        'philschmid/flan-t5-base-samsum',
        'gpt2',
        'distilgpt2',
        'mosaicml/mpt-7b-storywriter',
        'mosaicml/mpt-7b-instruct',
        'mosaicml/mpt-7b-chat',
        'gptj',
        'llama',
        'gpt4all_llama',
    ],
    'prompt_answer': [
        'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
        'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
        'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
    ],
    'instruct': [],
    'instruct_with_end': ['databricks/dolly-v2-12b'],
    'quality': [],
    'human_bot': [
        'h2oai/h2ogpt-oasst1-512-12b',
        'h2oai/h2ogpt-oasst1-512-20b',
        'h2oai/h2ogpt-oig-oasst1-256-6_9b',
        'h2oai/h2ogpt-oig-oasst1-512-6_9b',
        'h2oai/h2ogpt-oig-oasst1-256-6.9b',
        'h2oai/h2ogpt-oig-oasst1-512-6.9b',
        'h2oai/h2ogpt-research-oasst1-512-30b',
        'h2oai/h2ogpt-oasst1-falcon-40b',
    ],
    'dai_faq': [],
    'summarize': [],
    'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
    'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
    'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
    "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
    "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
    "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
}

# Reverse lookups: model name -> prompt type name (exact, and lower-cased).
inv_prompt_type_to_model_name = {model.strip(): ptype
                                 for ptype, models in prompt_type_to_model_name.items() for model in models}
inv_prompt_type_to_model_lower = {model.strip().lower(): ptype
                                  for ptype, models in prompt_type_to_model_name.items() for model in models}

# All prompt type names, in enum order.
prompt_types_strings = [ptype.name for ptype in PromptType]

# Every accepted spelling of each prompt type: name, int value, and str(value).
# Comprehensions (instead of module-level for-loops) keep a stray loop variable out of the module namespace.
prompt_types = [form for ptype in PromptType for form in (ptype.name, ptype.value, str(ptype.value))]
|
|
|
|
|
def _prompt_type_name(prompt_type):
    """Map a prompt type given as name, int value or ``str(value)`` to its ``PromptType`` name (or None)."""
    for ptype in PromptType:
        if prompt_type in (ptype.value, str(ptype.value), ptype.name):
            return ptype.name
    return None


def get_prompt(prompt_type, chat, context, reduced):
    """Return the template pieces for ``prompt_type``.

    :param prompt_type: ``PromptType`` name, int value, or ``str(value)``
    :param chat: together with ``reduced``, drops the preamble of templates that have one
    :param context: if truthy, ``human_bot_orig`` omits its date/time preprompt
    :param reduced: drops preambles when ``chat`` is also set; for human/bot templates it also omits the
                    date/time preprompt and appends a space to the bot marker
    :return: ``(promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep,
             humanstr, botstr)``. ``promptA``/``promptB`` are preambles (``generate_prompt`` uses
             ``promptA`` when an input is present). A ``None`` marker means the template has none.
    :raises RuntimeError: if ``prompt_type`` is not a known prompt type
    """
    name = _prompt_type_name(prompt_type)
    # Several templates drop their preamble when both chat and reduced are set.
    use_preamble = not (chat and reduced)

    if name == PromptType.plain.name:
        promptA = promptB = PreInstruct = PreInput = PreResponse = ''
        terminate_response = []
        chat_sep = ''
        humanstr = ''
        botstr = ''
    elif name == PromptType.simple_instruct.name:
        promptA = promptB = PreInstruct = PreInput = PreResponse = None
        terminate_response = []
        chat_sep = '\n'
        humanstr = ''
        botstr = ''
    elif name in (PromptType.instruct.name, PromptType.instruct_with_end.name):
        promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if use_preamble else ''
        promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if use_preamble else ''
        PreInstruct = '\n### Instruction:\n'
        PreInput = '\n### Input:\n'
        PreResponse = '\n### Response:\n'
        terminate_response = ['### End'] if name == PromptType.instruct_with_end.name else None
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.quality.name:
        promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if use_preamble else ''
        promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if use_preamble else ''
        PreInstruct = '\n### Instruction:\n'
        PreInput = '\n### Input:\n'
        PreResponse = '\n### Response:\n'
        terminate_response = None
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name in (PromptType.human_bot.name, PromptType.human_bot_orig.name):
        human = '<human>:'
        bot = "<bot>:"
        if reduced or context or name == PromptType.human_bot.name:
            preprompt = ''
        else:
            cur_date = time.strftime('%Y-%m-%d')
            # NOTE(review): '%H' (24-hour) combined with '%p' (AM/PM) looks inconsistent; '%I' may have
            # been intended. Left unchanged because it alters the rendered prompt.
            cur_time = time.strftime('%H:%M:%S %p %Z')
            PRE_PROMPT = "Current Date: {}\nCurrent Time: {}\n\n"
            preprompt = PRE_PROMPT.format(cur_date, cur_time)
        start = human
        promptB = promptA = '%s%s ' % (preprompt, start)
        PreInstruct = ""
        PreInput = None
        PreResponse = bot + ' ' if reduced else bot
        terminate_response = [start, PreResponse]
        chat_sep = '\n'
        humanstr = human
        botstr = bot
    elif name == PromptType.dai_faq.name:
        promptA = ''
        promptB = 'Answer the following Driverless AI question.\n'
        PreInstruct = '\n### Driverless AI frequently asked question:\n'
        PreInput = None
        PreResponse = '\n### Driverless AI documentation answer:\n'
        terminate_response = ['\n\n']
        # chat_sep is a string for every other template; it used to alias the list above.
        chat_sep = '\n\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.summarize.name:
        promptA = promptB = PreInput = ''
        PreInstruct = '## Main Text\n\n'
        PreResponse = '\n\n## Summary\n\n'
        terminate_response = None
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.instruct_vicuna.name:
        promptA = promptB = ("A chat between a curious human and an artificial intelligence assistant. "
                             "The assistant gives helpful, detailed, and polite answers to the human's questions."
                             if use_preamble else '')
        PreInstruct = '\n### Human:\n'
        PreInput = None
        PreResponse = '\n### Assistant:\n'
        terminate_response = ['### Human:']
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.prompt_answer.name:
        prompt_tokens = "<|prompt|>"
        answer_tokens = "<|answer|>"
        start = prompt_tokens
        promptB = promptA = start
        PreInstruct = ""
        PreInput = None
        PreResponse = answer_tokens
        eos = '<|endoftext|>'
        terminate_response = [start, PreResponse, eos]
        chat_sep = eos
        humanstr = prompt_tokens
        botstr = answer_tokens
    elif name == PromptType.open_assistant.name:
        prompt_tokens = "<|prompter|>"
        answer_tokens = "<|assistant|>"
        start = prompt_tokens
        promptB = promptA = start
        PreInstruct = ""
        PreInput = None
        PreResponse = answer_tokens
        pend = "<|prefix_end|>"
        eos = "</s>"
        terminate_response = [start, PreResponse, pend, eos]
        chat_sep = eos
        humanstr = prompt_tokens
        botstr = answer_tokens
    elif name == PromptType.wizard_lm.name:
        promptB = promptA = ''
        PreInstruct = ""
        PreInput = None
        PreResponse = "\n\n### Response\n"
        eos = "</s>"
        terminate_response = [PreResponse, eos]
        chat_sep = eos
        humanstr = promptA
        botstr = PreResponse
    elif name == PromptType.wizard_mega.name:
        promptB = promptA = ''
        PreInstruct = '\n### Instruction:\n'
        PreInput = None
        PreResponse = '\n### Assistant:\n'
        terminate_response = [PreResponse]
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.instruct_vicuna2.name:
        promptA = promptB = ''
        PreInstruct = '\nHUMAN:\n'
        PreInput = None
        PreResponse = '\nASSISTANT:\n'
        terminate_response = ['HUMAN:']
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.instruct_vicuna3.name:
        promptA = promptB = ''
        PreInstruct = '\n### User:\n'
        PreInput = None
        PreResponse = '\n### Assistant:\n'
        terminate_response = ['### User:']
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.wizard2.name:
        promptB = promptA = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        PreInstruct = '\n### Instruction:\n'
        PreInput = None
        PreResponse = '\n### Response:\n'
        terminate_response = [PreResponse]
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    elif name == PromptType.wizard3.name:
        promptB = promptA = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
        PreInstruct = "USER: "
        PreInput = None
        PreResponse = "ASSISTANT: "
        terminate_response = [PreResponse]
        chat_sep = '\n'
        humanstr = PreInstruct
        botstr = PreResponse
    else:
        raise RuntimeError("No such prompt_type=%s" % (prompt_type,))

    return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
|
|
|
|
|
def generate_prompt(data_point, prompt_type, chat, reduced):
    """Render the full prompt text for one data point.

    :param data_point: dict with optional keys ``context``, ``instruction``, ``input``, ``output`` and
                       ``prompt_type`` (the latter overrides the ``prompt_type`` argument)
    :param prompt_type: prompt type (name, int value, or str of value); must be in ``prompt_types``
    :param chat: forwarded to ``get_prompt``
    :param reduced: if True, ``context`` is not prepended; also forwarded to ``get_prompt``
    :return: ``(prompt, pre_response, terminate_response, chat_sep)``; ``pre_response`` is the template's
             response marker, or ``''`` when it has none
    :raises AssertionError: if the prompt type is not in ``prompt_types``
    """
    context = data_point.get('context')
    if context is None:
        context = ''
    instruction = data_point.get('instruction')
    iinput = data_point.get('input')  # local name avoids shadowing the builtin input()
    output = data_point.get('output')
    prompt_type = data_point.get('prompt_type', prompt_type)
    assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
    promptA, promptB, PreInstruct, PreInput, PreResponse, \
        terminate_response, chat_sep, _humanstr, _botstr = get_prompt(prompt_type, chat, context, reduced)

    prompt = context if not reduced else ''

    # promptA is the preamble for when an input is present; promptB otherwise.
    if iinput and promptA:
        prompt += promptA
    elif promptB:
        prompt += promptB

    # Render instruction/input with whichever markers the template defines (None means no marker).
    # The branch order matters: earlier branches take precedence.
    body = None
    if instruction and PreInstruct is not None and iinput and PreInput is not None:
        body = f"{PreInstruct}{instruction}{PreInput}{iinput}"
    elif instruction and iinput and PreInstruct is None and PreInput is not None:
        body = f"{PreInput}{instruction}\n{iinput}"
    elif iinput and instruction and PreInput is None and PreInstruct is not None:
        body = f"{PreInstruct}{instruction}\n{iinput}"
    elif instruction and PreInstruct is not None:
        body = f"{PreInstruct}{instruction}"
    elif iinput and PreInput is not None:
        body = f"{PreInput}{iinput}"
    elif iinput and instruction:
        # no markers at all: join instruction and input inline
        body = f"{instruction}: {iinput}"
    elif iinput:
        body = f"{iinput}"
    elif instruction:
        body = f"{instruction}"
    if body is not None:
        prompt = inject_newline(prompt_type, prompt + body)

    if PreResponse is not None:
        prompt += f"{PreResponse}"
        pre_response = PreResponse
    else:
        pre_response = ''

    if output:
        prompt += f"{output}"

    return prompt, pre_response, terminate_response, chat_sep
|
|
|
|
|
def inject_newline(prompt_type, prompt):
    """Return ``prompt`` with a trailing newline, unless the prompt type is plain or simple_instruct.

    :param prompt_type: prompt type as name, int value, or str of value (``-1``/``'-1'`` also suppress
                        the newline, kept for backward compatibility)
    :param prompt: prompt text so far
    :return: the possibly extended prompt
    """
    # Match every accepted spelling (name, int value, str value) so that e.g. plain given as 0 or '0'
    # behaves the same as 'plain'.
    no_newline_types = [-1, '-1']
    for ptype in (PromptType.plain, PromptType.simple_instruct):
        no_newline_types.extend([ptype.name, ptype.value, str(ptype.value)])
    if prompt_type not in no_newline_types:
        prompt += '\n'
    return prompt
|
|
|
|
|
class Prompter(object):
    """Build prompts for one ``prompt_type`` and extract the model's answer from its raw output."""

    def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
                 allowed_repeat_line_length=10):
        """
        :param prompt_type: prompt type (name, int value, or str of value)
        :param debug: print prompts and raw/cleaned outputs
        :param chat: forwarded to the prompt template helpers
        :param stream_output: stored on ``self.stream_output``; not read within this class
        :param repeat_penalty: if True, ``get_response`` drops repeated lines that are at least
                               ``allowed_repeat_line_length`` characters long
        :param allowed_repeat_line_length: lines shorter than this are allowed to repeat
        """
        self.prompt_type = prompt_type
        # Rendering an empty data point validates prompt_type and captures the response marker.
        data_point = dict(instruction='', input='', output='')
        _, self.pre_response, self.terminate_response, self.chat_sep = \
            generate_prompt(data_point, prompt_type, chat, False)
        self.debug = debug
        self.chat = chat
        self.stream_output = stream_output
        self.repeat_penalty = repeat_penalty
        self.allowed_repeat_line_length = allowed_repeat_line_length
        self.prompt = None
        context = ""
        reduced = False
        self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
            self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
            get_prompt(prompt_type, chat, context, reduced)

    def generate_prompt(self, data_point):
        """Render ``data_point`` with this prompter's template, store it on ``self.prompt`` and return it."""
        reduced = False
        prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
        if self.debug:
            print("prompt: ", prompt, flush=True)
        self.prompt = prompt
        return prompt

    def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
        """Extract the bot response(s) from raw model output.

        :param outputs: one generated string or a sequence of them (the caller's sequence is not modified)
        :param prompt: prompt fed to the model; when given it is stored on ``self.prompt`` and used to cut
                       the echoed prompt off each output. When None, the text after ``self.botstr`` is used.
        :param sanitize_bot_response: if True, censor profanity with ``better_profanity``
        :return: the cleaned response; multiple outputs are numbered and joined with newlines
        """
        # Copy so the caller's list is not overwritten with cleaned text below.
        outputs = [outputs] if isinstance(outputs, str) else list(outputs)
        if self.debug:
            print("output:\n", '\n\n'.join(outputs), flush=True)
        if prompt is not None:
            self.prompt = prompt

        def clean_response(response):
            meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
            for word in meaningless_words:
                response = response.replace(word, "")
            if sanitize_bot_response:
                from better_profanity import profanity
                response = profanity.censor(response)
            response = response.strip("\n")
            return response

        def clean_repeats(response):
            lines = response.split('\n')
            new_lines = []
            seen = set()  # same membership as new_lines, but O(1) lookups
            for line in lines:
                # Short lines (e.g. blank lines) may repeat; longer lines are kept only once.
                if line not in seen or len(line) < self.allowed_repeat_line_length:
                    new_lines.append(line)
                    seen.add(line)
            if self.debug and len(lines) != len(new_lines):
                print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
            response = '\n'.join(new_lines)
            return response

        multi_output = len(outputs) > 1

        for oi, output in enumerate(outputs):
            if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
                output = clean_response(output)
            elif prompt is None:
                # No prompt to strip: use the text after the bot marker.
                if not self.botstr:
                    # Template has no bot marker (e.g. simple_instruct); str.split('') would raise.
                    output = clean_response(output)
                elif self.botstr in output:
                    if self.humanstr:
                        output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
                    else:
                        # keep text after the bot marker, up to the next bot marker
                        output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
                else:
                    # bot marker not present (yet): nothing to return
                    output = ""
            else:
                if self.pre_response:
                    # Prefer locating the prompt text; slicing by length alone is fragile.
                    outputi = output.find(prompt)
                    if outputi >= 0:
                        output = output[outputi + len(prompt):]
                        allow_terminate = True
                    else:
                        output = output[len(prompt) - len(self.pre_response):]
                        if self.pre_response in output:
                            # take text after the first pre_response only
                            output = output.split(self.pre_response)[1]
                            allow_terminate = True
                        else:
                            if output:
                                print("Failure of parsing or not enough output yet: %s" % output, flush=True)
                            allow_terminate = False
                else:
                    allow_terminate = True
                    output = output[len(prompt):]
                # clean after removing the prompt so pre_response removal is not affected
                output = clean_response(output).strip()
                if self.repeat_penalty:
                    output = clean_repeats(output).strip()
                if self.terminate_response and allow_terminate:
                    finds = [output.find(term) for term in self.terminate_response]
                    finds = [x for x in finds if x >= 0]
                    if finds:
                        # Cut at the earliest terminator occurrence, not the first one listed.
                        output = output[:min(finds)].strip()
                    else:
                        output = output.strip()
                else:
                    output = output.strip()
            if multi_output:
                # prefix each output with a counter header
                output = "\n=========== Output %d\n\n" % (1 + oi) + output
                if oi > 0:
                    output += '\n'
            outputs[oi] = output

        output = '\n'.join(outputs)
        if self.debug:
            print("outputclean:\n", '\n\n'.join(outputs), flush=True)
        return output
|
|