import os
import argparse

import torch
import torch.nn.functional as F
from tqdm import trange
from transformers import GPT2LMHeadModel
import gradio as gr
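
# This script loads a GPT2LMHeadModel together with a BERT-style vocabulary,
# samples a continuation of a user-supplied prefix, and serves the result
# through a simple Gradio interface.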


def is_word(word):
    # True only if every character is a lowercase ASCII letter, i.e. an English word token.
    for item in list(word):
        if item not in 'qwertyuiopasdfghjklzxcvbnm':
            return False
    return True


def _is_chinese_char(char):
    """Checks whether `char` is a CJK character."""
    # Treats a "Chinese character" as anything in the CJK Unicode blocks
    # (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
    # Note that these blocks do not cover Japanese Hiragana/Katakana or Korean Hangul.
    cp = ord(char)
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
            (cp >= 0x3400 and cp <= 0x4DBF) or
            (cp >= 0x20000 and cp <= 0x2A6DF) or
            (cp >= 0x2A700 and cp <= 0x2B73F) or
            (cp >= 0x2B740 and cp <= 0x2B81F) or
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or
            (cp >= 0x2F800 and cp <= 0x2FA1F)):
        return True

    return False


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (vocabulary size,)
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # only a single (unbatched) logits vector is supported
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep the first token above the threshold.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
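

# Usage sketch for top_k_top_p_filtering (illustrative values, not taken from this script):
#   logits = torch.randn(vocab_size)                       # 1-D logits from the LM head
#   filtered = top_k_top_p_filtering(logits, top_k=30, top_p=0.95)
#   probs = F.softmax(filtered, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1)   # sample from the truncated distribution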


def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0,
                    repetition_penalty=1.0, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in trange(length):
            # Feed at most the last n_ctx - 1 tokens back into the model.
            inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
            outputs = model(**inputs)
            next_token_logits = outputs[0][0, -1, :]
            # Penalize tokens that have already been generated.
            for token_id in set(generated[0].tolist()):
                next_token_logits[token_id] /= repetition_penalty
            next_token_logits = next_token_logits / temperature
            # Never sample the [UNK] token.
            next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated.tolist()[0]


def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'):
    # Faster sampling loop: caches the model's past key/values so each step only
    # feeds the newest token. The `past=` keyword and tuple-style outputs used
    # here follow the older transformers API.
    inputs = torch.LongTensor(context).view(1, -1).to(device)
    if len(context) > 1:
        # Run the prefix (all but the last token) once to build the cache.
        _, past = model(inputs[:, :-1], None)[:2]
        prev = inputs[:, -1].view(1, -1)
    else:
        past = None
        prev = inputs
    generated = [] + context
    with torch.no_grad():
        for i in trange(length):
            output = model(prev, past=past)
            output, past = output[:2]
            output = output[-1].squeeze(0) / temperature
            filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
            generated.append(next_token.item())
            prev = next_token.view(1, 1)
    return generated
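

# Note: sample_sequence re-encodes the trailing context window at every step and
# supports the repetition penalty and [UNK] masking, while fast_sample_sequence
# trades those features for speed by reusing the cached past key/values.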


def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0,
             repetition_penalty=1.0, device='cpu', is_fast_pattern=False):
    if is_fast_pattern:
        return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
                                    device=device)
    else:
        return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature,
                               top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device)
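

# Minimal call sketch (mirrors how smp_generate below drives generation; values are illustrative):
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("some prefix"))
#   out = generate(n_ctx=model.config.n_ctx, model=model, context=ids, length=100,
#                  tokenizer=tokenizer, top_k=8, is_fast_pattern=True, device=device)
#   tokens = tokenizer.convert_ids_to_tokens(out)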


def smp_generate(pre_str):
    from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
    length = 500
    batch_size = 1
    nsamples = 1
    temperature = 1
    topk = 8
    topp = 0
    repetition_penalty = 1.0
    model_path = 'pretrained'
    tokenizer_path = 'cache/vocab.txt'
    save_samples = False
    save_samples_path = '.'
    fast_pattern = True
    prefix = pre_str

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Note: the tokenizer and model are reloaded on every request.
    tokenizer = tokenization_bert.BertTokenizer(vocab_file=tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if length == -1:
        length = model.config.n_ctx

    raw_text = prefix
    context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
    generated = 0
    for _ in range(nsamples // batch_size):
        out = generate(
            n_ctx=n_ctx,
            model=model,
            context=context_tokens,
            length=length,
            is_fast_pattern=fast_pattern, tokenizer=tokenizer,
            temperature=temperature, top_k=topk, top_p=topp,
            repetition_penalty=repetition_penalty, device=device
        )
        for _ in range(batch_size):
            generated += 1
            text = tokenizer.convert_ids_to_tokens(out)
            # Re-insert spaces between consecutive English word tokens.
            for i, item in enumerate(text[:-1]):
                if is_word(item) and is_word(text[i + 1]):
                    text[i] = item + ' '
            # Map special tokens back to plain text.
            for i, item in enumerate(text):
                if item == '[MASK]':
                    text[i] = ''
                elif item == '[CLS]':
                    text[i] = '\n\n'
                elif item == '[SEP]':
                    text[i] = '\n'
            info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
            text = ''.join(text).replace('##', '').strip()
            return text


def format_text(text):
    return "<p>" + text.replace("\n", "<br>") + "</p>"


# gr.inputs / gr.outputs follow the pre-3.0 Gradio API.
input_textbox = gr.inputs.Textbox(label="输入前缀")        # "input prefix"
output_textbox = gr.outputs.Textbox(label="生成文言文")    # "generated Classical Chinese"


html_content = """
<div style="display: flex; flex-direction: column-reverse;">
    <div style="flex-grow: 1; overflow-y: auto;">
        {output}
    </div>
    <div style="margin-top: 10px;">
        {input}
    </div>
</div>
"""


iface = gr.Interface(fn=smp_generate, inputs=input_textbox, outputs=output_textbox,
                     title="文言文生成器")  # "Classical Chinese generator"

iface.launch()