#!/usr/bin/env python
# coding=utf-8
import inspect
import logging
import nltk
from typing import Tuple
import torch
from transformers import (
AutoTokenizer,
BloomForCausalLM,
BloomTokenizerFast,
CTRLLMHeadModel,
CTRLTokenizer,
GenerationMixin,
GPT2LMHeadModel,
GPT2Tokenizer,
GPTJForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
OpenAIGPTLMHeadModel,
OpenAIGPTTokenizer,
OPTForCausalLM,
TransfoXLLMHeadModel,
TransfoXLTokenizer,
XLMTokenizer,
XLMWithLMHeadModel,
XLNetLMHeadModel,
XLNetTokenizer,
AutoModelForSeq2SeqLM,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from forbidden import FORBIDDEN_NOUN
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
MODEL_CLASSES = {
"gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
"ctrl": (CTRLLMHeadModel, CTRLTokenizer),
"openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
"xlnet": (XLNetLMHeadModel, XLNetTokenizer),
"transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
"gptj": (GPTJForCausalLM, AutoTokenizer),
"bloom": (BloomForCausalLM, BloomTokenizerFast),
"llama": (LlamaForCausalLM, LlamaTokenizer),
"opt": (OPTForCausalLM, GPT2Tokenizer),
}
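# Illustrative sketch (an assumption; this helper is not used elsewhere in the file): each
# MODEL_CLASSES entry pairs a causal-LM class with its tokenizer, so a checkpoint can be
# loaded generically. The "gpt2" checkpoint name is only an example.
def _load_model_example(model_type="gpt2", checkpoint="gpt2"):
    model_class, tokenizer_class = MODEL_CLASSES[model_type]
    tokenizer = tokenizer_class.from_pretrained(checkpoint)
    model = model_class.from_pretrained(checkpoint)
    return model, tokenizer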
FORBIDDEN_NOUN = set(FORBIDDEN_NOUN)
class Translator:
    """Thin wrapper around a Hugging Face seq2seq checkpoint, used to translate prompts
    (e.g. Chinese to English) before generation."""
    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def translate(self, text):
inputs = self.tokenizer(text, return_tensors="pt", padding=True)
outputs = self.model.generate(**inputs)
translated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return translated_text
def __call__(self, text):
return self.translate(text)
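# Illustrative usage of Translator (a sketch; the checkpoint name below is an assumption,
# any seq2seq translation model on the Hub would be used the same way):
#
#     zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
#     english_prompt = zh_en_translator("月球上骑自行车的猫")  # -> e.g. "a cat riding a bicycle on the moon"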
#
# Functions to prepare models' input
#
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args["temperature"] > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code, so you won't get good results.")
    return prompt_text
def prepare_xlm_input(args, model, tokenizer, prompt_text):
# kwargs = {"language": None, "mask_token_id": None}
# Set the language
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
if hasattr(model.config, "lang2id") and use_lang_emb:
available_languages = model.config.lang2id.keys()
if args["xlm_language"] in available_languages:
language = args["xlm_language"]
else:
language = None
while language not in available_languages:
language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
model.config.lang_id = model.config.lang2id[language]
# kwargs["language"] = tokenizer.lang2id[language]
return prompt_text
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else ""
prompt_text = prefix + prompt_text
return prompt_text
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else ""
prompt_text = prefix + prompt_text
return prompt_text
PREPROCESSING_FUNCTIONS = {
"ctrl": prepare_ctrl_input,
"xlm": prepare_xlm_input,
"xlnet": prepare_xlnet_input,
"transfo-xl": prepare_transfoxl_input,
}
def adjust_length_to_model(length, max_sequence_length):
if length < 0 and max_sequence_length > 0:
length = max_sequence_length
elif 0 < max_sequence_length < length:
length = max_sequence_length # No generation bigger than model size
elif length < 0:
length = MAX_LENGTH # avoid infinite loop
return length
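# Worked examples (illustrative): adjust_length_to_model(-1, 1024) -> 1024 (fall back to the
# model maximum), adjust_length_to_model(2048, 1024) -> 1024 (clamp to the model maximum),
# adjust_length_to_model(-1, 0) -> MAX_LENGTH (no model maximum known).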
def sparse_model_config(model_config):
    """Read the embedding size, attention-head count and layer count from a model config
    (handling the different attribute names used across architectures) and return
    (num_layers, num_heads, head_dim)."""
    embedding_size = None
if hasattr(model_config, "hidden_size"):
embedding_size = model_config.hidden_size
elif hasattr(model_config, "n_embed"):
embedding_size = model_config.n_embed
elif hasattr(model_config, "n_embd"):
embedding_size = model_config.n_embd
num_head = None
if hasattr(model_config, "num_attention_heads"):
num_head = model_config.num_attention_heads
elif hasattr(model_config, "n_head"):
num_head = model_config.n_head
if embedding_size is None or num_head is None or num_head == 0:
raise ValueError("Check the model config")
num_embedding_size_per_head = int(embedding_size / num_head)
if hasattr(model_config, "n_layer"):
num_layer = model_config.n_layer
elif hasattr(model_config, "num_hidden_layers"):
num_layer = model_config.num_hidden_layers
else:
raise ValueError("Number of hidden layers couldn't be determined from the model config")
return num_layer, num_head, num_embedding_size_per_head
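# For example (illustrative, assuming the standard GPT-2 base config with n_layer=12,
# n_head=12, n_embd=768), sparse_model_config would return (12, 12, 64).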
def generate_past_key_values(model, batch_size, seq_len):
    """Build an empty past_key_values cache in the layout the model expects: BLOOM folds the
    batch and head dimensions together, while other decoder-only models use tensors of shape
    (batch_size, num_heads, seq_len, head_dim) for both keys and values."""
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
    if model.config.model_type == "bloom":
past_key_values = tuple(
(
torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
.to(model.dtype)
.to(model.device),
torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
.to(model.dtype)
.to(model.device),
)
for _ in range(num_block_layers)
)
else:
past_key_values = tuple(
(
torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
.to(model.dtype)
.to(model.device),
torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
.to(model.dtype)
.to(model.device),
)
for _ in range(num_block_layers)
)
return past_key_values
def prepare_jit_inputs(inputs, model, tokenizer):
    """Build a dummy batch for torch.jit tracing: tokenized inputs plus, when caching is
    enabled, an empty one-token past_key_values and a matching extra attention-mask column."""
    batch_size = len(inputs)
dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
dummy_input = dummy_input.to(model.device)
if model.config.use_cache:
dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
dummy_input["attention_mask"] = torch.cat(
[
torch.zeros(dummy_input["attention_mask"].shape[0], 1)
.to(dummy_input["attention_mask"].dtype)
.to(model.device),
dummy_input["attention_mask"],
],
-1,
)
return dummy_input
class _ModelFallbackWrapper(GenerationMixin):
    """Wrap a torch.jit-traced model so it can still be driven through GenerationMixin's
    generate(), delegating attribute lookups and input preparation to the original model."""
    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default
def __call__(self, *args, **kwargs):
if kwargs["past_key_values"] is None and self._default.config.use_cache:
kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
kwargs.pop("position_ids", None)
for k in list(kwargs.keys()):
if kwargs[k] is None or isinstance(kwargs[k], bool):
kwargs.pop(k)
outputs = self._optimized(**kwargs)
lm_logits = outputs[0]
past_key_values = outputs[1]
fixed_output = CausalLMOutputWithPast(
loss=None,
logits=lm_logits,
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
return fixed_output
def __getattr__(self, item):
return getattr(self._default, item)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
):
return self._default.prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
)
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return self._default._reorder_cache(past_key_values, beam_idx)
def remove_tokens_before_copula(text):
    """For every comma-separated segment after the first, keep only the words that follow the
    last copula ("is"/"are"/"am"); segments without a copula keep all of their tokens."""
    sentences = text.split(",")
    result = [sentences[0]]
for sentence in sentences[1:]:
tokens = nltk.word_tokenize(sentence)
target_indices = [i for i, token in enumerate(tokens) if token.lower() in ["is", "are", "am"]]
if target_indices:
last_target_index = target_indices[-1]
result.append(tokens[last_target_index + 1:])
else:
result.append(tokens)
all_sentences = [" ".join(sen) for sen in result[1:]]
all_sentences.insert(0, result[0])
result_text = ",".join(all_sentences)
return result_text
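# Illustrative behaviour (derived from the logic above; note the rejoined segments lose the
# space after each comma):
#     remove_tokens_before_copula("a photo, the style is impressionist, colors are warm")
#     -> "a photo,impressionist,warm"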
def generate_prompt(
    prompt_text,
    args,
    zh_en_translator,
    nlp,
    model,
    tokenizer,
    distributed_state,
):
    """Translate the prompt to English, apply any model-specific preprocessing, and sample
    args["num_return_sequences"] continuations, returned as a list of strings. The nlp
    argument is accepted but not used in this function."""
    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
    args["length"] = adjust_length_to_model(args["length"], max_sequence_length=max_seq_length)
    while True:
        prompt_text = zh_en_translator(prompt_text)
        # Only a single prompt string is supported per call.
# Different models need different input formatting and/or extra arguments
requires_preprocessing = args["model_type"] in PREPROCESSING_FUNCTIONS.keys()
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args["model_type"])
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else:
tokenizer_kwargs = {}
encoded_prompt = tokenizer.encode(
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
)
else:
            prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else ""
encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(distributed_state.device)
if encoded_prompt.size()[-1] == 0:
input_ids = None
else:
input_ids = encoded_prompt
if args["jit"]:
jit_input_texts = ["enable jit"]
jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
torch._C._jit_set_texpr_fuser_enabled(False)
model.config.return_dict = False
if hasattr(model, "forward"):
sig = inspect.signature(model.forward)
else:
sig = inspect.signature(model.__call__)
jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
traced_model = torch.jit.trace(model, jit_inputs, strict=False)
traced_model = torch.jit.freeze(traced_model.eval())
traced_model(*jit_inputs)
traced_model(*jit_inputs)
model = _ModelFallbackWrapper(traced_model, model)
generated_sequences = []
for generated_sequence_idx in range(args["num_return_sequences"]):
            repeat_gen_time = 0  # retry counter; incremented below but not otherwise used
            while True:
                repeat_gen_time = repeat_gen_time + 1
generated_sequence = model.generate(
input_ids=input_ids,
length_penalty=args["length_penalty"],
max_length=args["length"] + len(encoded_prompt[0]),
temperature=args["temperature"],
top_k=args["k"],
top_p=args["p"],
repetition_penalty=args["repetition_penalty"],
do_sample=True,
num_return_sequences=1,
pad_token_id=tokenizer.pad_token_id
)
# Remove the n_sequence dimension when returning single sequence
                if len(generated_sequence.shape) > 1:
generated_sequence.squeeze_()
generated_sequence = generated_sequence.tolist()
# Decode text
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
                # Remove all text after the stop token (only if a stop token is set and found)
                if args["stop_token"] and args["stop_token"] in text:
                    text = text[: text.find(args["stop_token"])]
# Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
total_sequence = (
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
)
break
total_sequence = remove_tokens_before_copula(total_sequence)
generated_sequences.append(total_sequence)
return generated_sequences
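# Illustrative sketch of the args dict generate_prompt expects (the key names are taken from
# the lookups above; the example values are assumptions, not defaults defined in this file):
#
#     args = {
#         "model_type": "gpt2", "length": 60, "length_penalty": 1.0, "temperature": 0.9,
#         "k": 50, "p": 0.9, "repetition_penalty": 1.2, "num_return_sequences": 1,
#         "prefix": "", "padding_text": "", "stop_token": None, "jit": False,
#     }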
if __name__ == "__main__":
    # generate_prompt() has no default arguments; it must be called by a driver that supplies
    # the prompt, the args dict, a Translator, the model, the tokenizer and the distributed state.
    pass