#!/usr/bin/env python # coding=utf-8 import inspect import logging import nltk from typing import Tuple import torch from transformers import ( AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, CTRLLMHeadModel, CTRLTokenizer, GenerationMixin, GPT2LMHeadModel, GPT2Tokenizer, GPTJForCausalLM, LlamaForCausalLM, LlamaTokenizer, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OPTForCausalLM, TransfoXLLMHeadModel, TransfoXLTokenizer, XLMTokenizer, XLMWithLMHeadModel, XLNetLMHeadModel, XLNetTokenizer, AutoModelForSeq2SeqLM, ) from transformers.modeling_outputs import CausalLMOutputWithPast from forbidden import FORBIDDEN_NOUN logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop MODEL_CLASSES = { "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), "ctrl": (CTRLLMHeadModel, CTRLTokenizer), "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), "xlnet": (XLNetLMHeadModel, XLNetTokenizer), "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), "xlm": (XLMWithLMHeadModel, XLMTokenizer), "gptj": (GPTJForCausalLM, AutoTokenizer), "bloom": (BloomForCausalLM, BloomTokenizerFast), "llama": (LlamaForCausalLM, LlamaTokenizer), "opt": (OPTForCausalLM, GPT2Tokenizer), } FORBIDDEN_NOUN = set(FORBIDDEN_NOUN) class Translator: def __init__(self, model_name): self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name) def translate(self, text): inputs = self.tokenizer(text, return_tensors="pt", padding=True) outputs = self.model.generate(**inputs) translated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) return translated_text def __call__(self, text): return self.translate(text) # # Functions to prepare models' input # def prepare_ctrl_input(args, _, tokenizer, prompt_text): if args["temperature"] > 0.7: pass encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): pass return prompt_text def prepare_xlm_input(args, model, tokenizer, prompt_text): # kwargs = {"language": None, "mask_token_id": None} # Set the language use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb if hasattr(model.config, "lang2id") and use_lang_emb: available_languages = model.config.lang2id.keys() if args["xlm_language"] in available_languages: language = args["xlm_language"] else: language = None while language not in available_languages: language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ") model.config.lang_id = model.config.lang2id[language] # kwargs["language"] = tokenizer.lang2id[language] return prompt_text def prepare_xlnet_input(args, _, tokenizer, prompt_text): prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else "" prompt_text = prefix + prompt_text return prompt_text def prepare_transfoxl_input(args, _, tokenizer, prompt_text): prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else "" prompt_text = prefix + prompt_text return prompt_text PREPROCESSING_FUNCTIONS = { "ctrl": prepare_ctrl_input, "xlm": prepare_xlm_input, "xlnet": prepare_xlnet_input, "transfo-xl": prepare_transfoxl_input, } def adjust_length_to_model(length, max_sequence_length): if length < 0 and max_sequence_length > 0: length = max_sequence_length elif 0 < max_sequence_length < length: length = max_sequence_length # No generation bigger than model size elif length < 0: length = MAX_LENGTH # avoid infinite loop return length def sparse_model_config(model_config): embedding_size = None if hasattr(model_config, "hidden_size"): embedding_size = model_config.hidden_size elif hasattr(model_config, "n_embed"): embedding_size = model_config.n_embed elif hasattr(model_config, "n_embd"): embedding_size = model_config.n_embd num_head = None if hasattr(model_config, "num_attention_heads"): num_head = model_config.num_attention_heads elif hasattr(model_config, "n_head"): num_head = model_config.n_head if embedding_size is None or num_head is None or num_head == 0: raise ValueError("Check the model config") num_embedding_size_per_head = int(embedding_size / num_head) if hasattr(model_config, "n_layer"): num_layer = model_config.n_layer elif hasattr(model_config, "num_hidden_layers"): num_layer = model_config.num_hidden_layers else: raise ValueError("Number of hidden layers couldn't be determined from the model config") return num_layer, num_head, num_embedding_size_per_head def generate_past_key_values(model, batch_size, seq_len): num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config) if model.config.model_type == "bloom": past_key_values = tuple( ( torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len) .to(model.dtype) .to(model.device), torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head) .to(model.dtype) .to(model.device), ) for _ in range(num_block_layers) ) else: past_key_values = tuple( ( torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) .to(model.dtype) .to(model.device), torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) .to(model.dtype) .to(model.device), ) for _ in range(num_block_layers) ) return past_key_values def prepare_jit_inputs(inputs, model, tokenizer): batch_size = len(inputs) dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt") dummy_input = dummy_input.to(model.device) if model.config.use_cache: dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1) dummy_input["attention_mask"] = torch.cat( [ torch.zeros(dummy_input["attention_mask"].shape[0], 1) .to(dummy_input["attention_mask"].dtype) .to(model.device), dummy_input["attention_mask"], ], -1, ) return dummy_input class _ModelFallbackWrapper(GenerationMixin): __slots__ = ("_optimized", "_default") def __init__(self, optimized, default): self._optimized = optimized self._default = default def __call__(self, *args, **kwargs): if kwargs["past_key_values"] is None and self._default.config.use_cache: kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0) kwargs.pop("position_ids", None) for k in list(kwargs.keys()): if kwargs[k] is None or isinstance(kwargs[k], bool): kwargs.pop(k) outputs = self._optimized(**kwargs) lm_logits = outputs[0] past_key_values = outputs[1] fixed_output = CausalLMOutputWithPast( loss=None, logits=lm_logits, past_key_values=past_key_values, hidden_states=None, attentions=None, ) return fixed_output def __getattr__(self, item): return getattr(self._default, item) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs ): return self._default.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs ) def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. """ return self._default._reorder_cache(past_key_values, beam_idx) def remove_tokens_before_copula(text): sentences = text.split(",") result = [sentences[0]] for sentence in sentences[1:]: tokens = nltk.word_tokenize(sentence) target_indices = [i for i, token in enumerate(tokens) if token.lower() in ["is", "are", "am"]] if target_indices: last_target_index = target_indices[-1] result.append(tokens[last_target_index + 1:]) else: result.append(tokens) all_sentences = [" ".join(sen) for sen in result[1:]] all_sentences.insert(0, result[0]) result_text = ",".join(all_sentences) return result_text def generate_prompt( prompt_text, args, zh_en_translator, nlp, model, tokenizer, distributed_state, ): max_seq_length = getattr(model.config, "max_position_embeddings", 0) args["length"] = adjust_length_to_model(args["length"], max_sequence_length=max_seq_length) while(1): prompt_text = zh_en_translator(prompt_text) # only support single input. # Different models need different input formatting and/or extra arguments requires_preprocessing = args["model_type"] in PREPROCESSING_FUNCTIONS.keys() if requires_preprocessing: prepare_input = PREPROCESSING_FUNCTIONS.get(args["model_type"]) preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: tokenizer_kwargs = {"add_space_before_punct_symbol": True} else: tokenizer_kwargs = {} encoded_prompt = tokenizer.encode( preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs ) else: prefix = args["prefix"] if args["prefix"] else args["padding_text"] encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") encoded_prompt = encoded_prompt.to(distributed_state.device) if encoded_prompt.size()[-1] == 0: input_ids = None else: input_ids = encoded_prompt if args["jit"]: jit_input_texts = ["enable jit"] jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) torch._C._jit_set_texpr_fuser_enabled(False) model.config.return_dict = False if hasattr(model, "forward"): sig = inspect.signature(model.forward) else: sig = inspect.signature(model.__call__) jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None) traced_model = torch.jit.trace(model, jit_inputs, strict=False) traced_model = torch.jit.freeze(traced_model.eval()) traced_model(*jit_inputs) traced_model(*jit_inputs) model = _ModelFallbackWrapper(traced_model, model) generated_sequences = [] for generated_sequence_idx in range(args["num_return_sequences"]): repeat_gen_time = 0 while(1): repeat_gen_time = repeat_gen_time + 1 generated_sequence = model.generate( input_ids=input_ids, length_penalty=args["length_penalty"], max_length=args["length"] + len(encoded_prompt[0]), temperature=args["temperature"], top_k=args["k"], top_p=args["p"], repetition_penalty=args["repetition_penalty"], do_sample=True, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id ) # Remove the n_sequence dimension when returning single sequence if len(generated_sequence.shape) >1: generated_sequence.squeeze_() generated_sequence = generated_sequence.tolist() # Decode text text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) # Remove all text after the stop token text = text[: text.find(args["stop_token"]) if args["stop_token"] else None] # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing total_sequence = ( prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] ) break total_sequence = remove_tokens_before_copula(total_sequence) generated_sequences.append(total_sequence) return generated_sequences if __name__ == "__main__": generate_prompt()