"""Sample text from a pretrained GPT-2 language model using a BERT-style vocabulary."""
import argparse
import os

import torch
import torch.nn.functional as F
from tqdm import trange
from transformers import GPT2LMHeadModel

# Characters that make up a lowercase ASCII "word" token.
_LOWERCASE_ASCII = frozenset('qwertyuiopasdfghjklzxcvbnm')

# Inclusive codepoint ranges treated as CJK ideographs.
_CJK_RANGES = (
    (0x4E00, 0x9FFF),
    (0x3400, 0x4DBF),
    (0x20000, 0x2A6DF),
    (0x2A700, 0x2B73F),
    (0x2B740, 0x2B81F),
    (0x2B820, 0x2CEAF),
    (0xF900, 0xFAFF),
    (0x2F800, 0x2FA1F),
)

# Text emitted in place of each special vocabulary token.
_SPECIAL_TOKEN_TEXT = {'[MASK]': '', '[CLS]': '\n\n', '[SEP]': '\n'}


def is_word(word):
    """Return True if every character of ``word`` is a lowercase ASCII letter.

    An empty string yields True because there is no offending character.
    """
    return all(char in _LOWERCASE_ASCII for char in word)


def _is_chinese_char(char):
    """Return True if ``char`` is a codepoint in one of the CJK ideograph blocks.

    This defines a "Chinese character" as anything in the CJK Unicode blocks:
    https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)

    Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    despite its name. The modern Korean Hangul alphabet is a different block,
    as are Japanese Hiragana and Katakana. Those alphabets are used to write
    space-separated words, so they are not treated specially here.
    """
    cp = ord(char)
    return any(low <= cp <= high for low, high in _CJK_RANGES)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The ``logits`` tensor is modified in place and also returned.

    Args:
        logits: 1-D tensor of logits (vocabulary size).
        top_k: if > 0, keep only the ``top_k`` tokens with the highest logits.
        top_p: if > 0.0, keep the smallest set of top tokens whose cumulative
            probability reaches ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written over every removed logit.

    Returns:
        The same ``logits`` tensor with removed entries set to ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now
    top_k = min(top_k, logits.size(-1))  # never ask topk for more entries than exist
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability is above the threshold ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... then shift the mask right so the first token crossing the
        # threshold is kept, and the best token is always kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    return logits


def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0,
                    repitition_penalty=1.0, device='cpu'):
    """Generate ``length`` tokens one at a time, re-feeding the sequence each step.

    Args:
        model: language model; ``model(input_ids=...)[0]`` is indexed as
            ``[0, -1, :]`` to obtain the logits of the last position.
        context: sequence of prompt token ids.
        length: number of tokens to generate.
        n_ctx: context size; only the last ``n_ctx - 1`` tokens are fed to the model.
        tokenizer: object providing ``convert_tokens_to_ids``; used to look up
            the id of ``'[UNK]'`` so it is never sampled.
        temperature: divisor applied to the logits before filtering.
        top_k: forwarded to :func:`top_k_top_p_filtering`.
        top_p: forwarded to :func:`top_k_top_p_filtering`.
        repitition_penalty: logits of tokens already present in the sequence
            are divided by this value (parameter name kept for compatibility).
        device: torch device on which the token tensor is created.

    Returns:
        A list of token ids: the prompt followed by the generated tokens.
    """
    generated = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
    # The [UNK] id does not change between steps, so look it up once.
    unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
    with torch.no_grad():
        for _ in trange(length):
            window = generated[0][-(n_ctx - 1):].unsqueeze(0)
            # Note: cached hidden states could be used instead; see fast_sample_sequence.
            outputs = model(input_ids=window)
            next_token_logits = outputs[0][0, -1, :]
            # Penalise every distinct token that already appears in the sequence.
            seen_ids = torch.unique(generated[0])
            next_token_logits[seen_ids] = next_token_logits[seen_ids] / repitition_penalty
            next_token_logits = next_token_logits / temperature
            next_token_logits[unk_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated.tolist()[0]


def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'):
    """Generate ``length`` tokens while reusing the model's cached state.

    The prompt (except its last token) is run through the model once; after
    that only the newest token is fed in together with the returned cache.
    Unlike :func:`sample_sequence`, no repetition penalty or ``[UNK]`` masking
    is applied here.

    Args:
        model: language model called as ``model(ids, past=...)``; its first two
            outputs are taken as ``(logits, past)``.
            NOTE(review): the ``past`` keyword depends on the installed
            transformers version - confirm before upgrading.
        context: sequence of prompt token ids.
        length: number of tokens to generate.
        temperature: divisor applied to the logits before filtering.
        top_k: forwarded to :func:`top_k_top_p_filtering`.
        top_p: forwarded to :func:`top_k_top_p_filtering`.
        device: torch device the input tensor is moved to.

    Returns:
        A list of token ids: the prompt followed by the generated tokens.
    """
    inputs = torch.LongTensor(context).view(1, -1).to(device)
    tokens = list(context)
    # The priming pass is inference too, so keep it under no_grad as well.
    with torch.no_grad():
        if len(context) > 1:
            _, past = model(inputs[:, :-1], None)[:2]
            prev = inputs[:, -1].view(1, -1)
        else:
            past = None
            prev = inputs
        for _ in trange(length):
            output, past = model(prev, past=past)[:2]
            logits = output[-1].squeeze(0) / temperature
            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
            tokens.append(next_token.item())
            prev = next_token.view(1, 1)
    return tokens


def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0,
             device='cpu',
             is_fast_pattern=False):
    """Dispatch to the cached or the plain sampling routine.

    When ``is_fast_pattern`` is true, :func:`fast_sample_sequence` is used and
    ``n_ctx``, ``tokenizer`` and ``repitition_penalty`` are ignored; otherwise
    :func:`sample_sequence` is used with all arguments.

    Returns:
        The list of token ids produced by the selected routine.
    """
    if is_fast_pattern:
        return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
                                    device=device)
    return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature,
                           top_k=top_k, top_p=top_p, repitition_penalty=repitition_penalty, device=device)


def _tokens_to_text(tokens):
    """Join vocabulary tokens into display text.

    A space is appended to a lowercase ASCII word that is directly followed by
    another one, special tokens are replaced by their text form, ``##``
    markers are removed, and the result is stripped.
    """
    pieces = []
    last_index = len(tokens) - 1
    for index, token in enumerate(tokens):
        # Make sure consecutive English words are separated by a space.
        if index < last_index and is_word(token) and is_word(tokens[index + 1]):
            token = token + ' '
        pieces.append(_SPECIAL_TOKEN_TEXT.get(token, token))
    return ''.join(pieces).replace('##', '').strip()


def smp_generate(pre_str):
    """Generate one text sample that continues the prompt ``pre_str``.

    Loads the tokenizer vocabulary from ``cache/vocab.txt`` and the model from
    ``pretrained``, samples with the settings hard-coded below, and returns
    the decoded text.

    Args:
        pre_str: prompt text.

    Returns:
        The decoded sample (prompt plus continuation) as a string.

    Raises:
        ValueError: if the prompt tokenizes to no tokens.
    """
    from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'  # GPUs this program may use
    length = 500
    temperature = 1
    topk = 8
    topp = 0
    repetition_penalty = 1.0
    model_path = 'pretrained'
    tokenizer_path = 'cache/vocab.txt'
    fast_pattern = True  # use the cached sampling routine

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_bert.BertTokenizer(vocab_file=tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx
    if length == -1:
        length = n_ctx

    context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(pre_str))
    if not context_tokens:
        # Both sampling routines need at least one prompt token to start from.
        raise ValueError('pre_str must contain at least one token')

    out = generate(
        n_ctx=n_ctx,
        model=model,
        context=context_tokens,
        length=length,
        is_fast_pattern=fast_pattern,
        tokenizer=tokenizer,
        temperature=temperature,
        top_k=topk,
        top_p=topp,
        repitition_penalty=repetition_penalty,
        device=device,
    )
    return _tokens_to_text(tokenizer.convert_ids_to_tokens(out))


if __name__ == '__main__':
    print(smp_generate('曹贼休走'))