"""inBERTolate: pad out a document by generating text between its sentences.

Every sentence of the input is followed by up to ``max_len`` tokens produced
by a RoBERTa masked language model, using parallel-sequential Gibbs-style
sampling adapted from https://github.com/nyu-dl/bert-gen.
"""
import argparse

import nltk
import torch
import numpy as np
import gradio as gr
from nltk import sent_tokenize
from transformers import (
    RobertaTokenizer,
    RobertaForMaskedLM,
    LogitsProcessorList,
    TopKLogitsWarper,
    TemperatureLogitsWarper,
    TypicalLogitsWarper,
)

nltk.download('punkt')
# Newer NLTK releases look up the 'punkt_tab' resource in sent_tokenize;
# on older releases this id is unknown and nltk.download just reports an
# error and returns False instead of raising.
nltk.download('punkt_tab')

device = "cuda" if torch.cuda.is_available() else "cpu"
pretrained = "roberta-large" if device == "cuda" else "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(pretrained)
model = RobertaForMaskedLM.from_pretrained(pretrained)
model = model.to(device)

# Default generation settings (also used as the UI slider defaults).
max_len = 20
top_k = 100
temperature = 1.0  # float: TemperatureLogitsWarper rejects non-float values
typical_p = 0
burnin = 250
max_iter = 500


# adapted from https://github.com/nyu-dl/bert-gen
def generate_step(out: object,
                  gen_idx: int,
                  top_k: int = top_k,
                  temperature: float = temperature,
                  typical_p: float = typical_p,
                  sample: bool = False) -> list:
    """
    Generate a token for position ``gen_idx`` of every sequence in the batch.

    Args:
      - out: model output exposing ``.logits`` of size
        batch_size x seq_len x vocab_size (e.g. what ``model(**inputs)``
        returns)
      - gen_idx (int): location for which to generate
      - top_k (int): if >0, only keep the top k most probable words
      - temperature (float): sampling temperature; 0 disables scaling
      - typical_p (float): if >0 use typical sampling (values >= 1 are
        clamped to 0.999)
      - sample (bool): if True, sample from the warped distribution;
        otherwise take the argmax
    Returns:
      - list: batch_size token ids
    """
    logits = out.logits[:, gen_idx]
    warpers = LogitsProcessorList()
    if temperature:
        # The warper validates its argument with isinstance(..., float), so
        # ints (e.g. a default of 1) would raise without this cast.
        warpers.append(TemperatureLogitsWarper(float(temperature)))
    top_k = int(top_k)  # UI sliders may deliver floats; the warper needs int
    if top_k > 0:
        warpers.append(TopKLogitsWarper(top_k))
    if typical_p > 0:
        # TypicalLogitsWarper requires a mass strictly inside (0, 1).
        if typical_p >= 1:
            typical_p = 0.999
        warpers.append(TypicalLogitsWarper(typical_p))
    logits = warpers(None, logits)

    if sample:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    else:
        next_tokens = torch.argmax(logits, dim=-1)

    return next_tokens.tolist()


# adapted from https://github.com/nyu-dl/bert-gen
def parallel_sequential_generation(seed_text: str,
                                   seed_end_text: str,
                                   max_len: int = max_len,
                                   top_k: int = top_k,
                                   temperature: float = temperature,
                                   typical_p: float = typical_p,
                                   max_iter: int = max_iter,
                                   burnin: int = burnin) -> str:
    """
    Generate text consistent with preceding and following text.

    ``max_len`` mask tokens are placed between the two texts.  At each
    iteration one of them is chosen uniformly at random and replaced by a
    token predicted by the masked language model given everything else.

    Args:
      - seed_text (str): preceding text
      - seed_end_text (str): following text
      - max_len (int): number of tokens to generate
      - top_k (int): if >0, only sample from the top k most probable words
      - temperature (float): sampling temperature
      - typical_p (float): if >0 use typical sampling
      - max_iter (int): number of iterations in MCMC
      - burnin: during burn-in period, sample from full distribution;
        afterwards take argmax
    Returns:
      - string: generated text to insert between seed_text and seed_end_text
    """
    max_len = int(max_len)
    if max_len <= 0:
        # Nothing to generate; also avoids indexing an empty mask array.
        return ''

    inp = tokenizer(seed_text + tokenizer.mask_token * max_len + seed_end_text,
                    return_tensors='pt')
    masked_tokens = np.where(
        inp['input_ids'][0].numpy() == tokenizer.mask_token_id)[0]
    seed_len = masked_tokens[0]
    inp = inp.to(device)

    # Pure inference: skip building autograd graphs on every forward pass.
    with torch.no_grad():
        for ii in range(int(max_iter)):
            kk = np.random.randint(0, max_len)
            idxs = generate_step(model(**inp),
                                 gen_idx=seed_len + kk,
                                 top_k=top_k if (ii >= burnin) else 0,
                                 temperature=temperature,
                                 typical_p=typical_p,
                                 sample=(ii < burnin))
            inp['input_ids'][0][seed_len + kk] = idxs[0]

    tokens = inp['input_ids'].cpu().numpy()[0][masked_tokens]
    # skip_special_tokens drops bos/eos as before, and also any position the
    # random walk never visited (still <mask>), so no literal "<mask>" text
    # leaks into the output.
    return tokenizer.decode(tokens, skip_special_tokens=True)


def inbertolate(doc: str,
                max_len: int = max_len,
                top_k: int = top_k,
                temperature: float = temperature,
                typical_p: float = typical_p,
                max_iter: int = max_iter,
                burnin: int = burnin) -> str:
    """
    Pad out document generating every other sentence.

    Each line of ``doc`` is split into sentences; after every sentence
    (including the last one of a line) generated text is inserted that is
    conditioned on that sentence and the next one.  Line breaks are kept.

    Args:
      - doc (str): document text
      - max_len (int): number of tokens to insert between sentences
      - top_k (int): if >0, only sample from the top k most probable words
      - temperature (float): sampling temperature
      - typical_p (float): if >0 use typical sampling
      - max_iter (int): number of iterations in MCMC
      - burnin: during burn-in period, sample from full distribution;
        afterwards take argmax
    Returns:
      - string: the expanded document
    """
    # Gradio sliders can hand back floats; the generator needs integers here.
    max_len = int(max_len)
    top_k = int(top_k)
    max_iter = int(max_iter)
    burnin = int(burnin)

    pieces = []
    for para in doc.split('\n'):
        sentences = sent_tokenize(para)
        if not sentences:
            # Empty / whitespace-only line: preserve the line break only.
            pieces.append('\n')
            continue
        # Pair each sentence with its successor; the last one with ''.
        followers = sentences[1:] + ['']
        for current, following in zip(sentences, followers):
            pieces.append(current + ' ')
            pieces.append(parallel_sequential_generation(
                current,
                following,
                max_len=max_len,
                top_k=top_k,
                temperature=float(temperature),
                typical_p=typical_p,
                burnin=burnin,
                max_iter=max_iter) + ' ')
        pieces.append('\n')
    return ''.join(pieces)


demo = gr.Interface(
    fn=inbertolate,
    title="inBERTolate",
    description=f"Hit your word count by using BERT ({pretrained}) to pad out your essays!",
    inputs=[
        gr.Textbox(label="Text", lines=10),
        gr.Slider(label="Maximum length to insert between sentences",
                  minimum=1, maximum=40, step=1, value=max_len),
        gr.Slider(label="Top k", minimum=0, maximum=200, step=1, value=top_k),
        gr.Slider(label="Temperature", minimum=0, maximum=2, value=temperature),
        gr.Slider(label="Typical p", minimum=0, maximum=1, value=typical_p),
        gr.Slider(label="Maximum iterations", minimum=0, maximum=1000, step=1,
                  value=max_iter),
        gr.Slider(label="Burn-in", minimum=0, maximum=500, step=1,
                  value=burnin),
    ],
    outputs=gr.Textbox(label="Expanded text", lines=30))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int)
    # server_name is a host/interface string such as "0.0.0.0", not an int.
    parser.add_argument('--server', type=str)
    args = parser.parse_args()
    demo.launch(server_name=args.server or '0.0.0.0', server_port=args.port)