|
import os |
|
|
|
import argparse |
|
import logging |
|
|
|
import numpy as np |
|
import torch |
|
import datetime |
|
import gradio as gr |
|
|
|
from transformers import ( |
|
CTRLLMHeadModel, |
|
CTRLTokenizer, |
|
GPT2LMHeadModel, |
|
GPT2Tokenizer, |
|
OpenAIGPTLMHeadModel, |
|
OpenAIGPTTokenizer, |
|
TransfoXLLMHeadModel, |
|
TransfoXLTokenizer, |
|
XLMTokenizer, |
|
XLMWithLMHeadModel, |
|
XLNetLMHeadModel, |
|
XLNetTokenizer, |
|
) |
|
|
|
|
|
# Configure root logging once for the whole script: timestamp, level, logger name, message.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Fallback generation length used by adjust_length_to_model when neither the
# caller nor the model provides a positive limit.
MAX_LENGTH = 10000
|
|
|
# Maps each supported --model_type value to its (language-model class, tokenizer class) pair.
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}
|
|
|
def set_seed(args):
    """Draw a fresh random seed, print it, and apply it to NumPy and PyTorch.

    The seed is sampled from ``np.random.randint(100000)`` on every call, so
    ``args.seed`` is not consulted here; only ``args.n_gpu`` is read, to decide
    whether the CUDA generators are seeded as well.
    """
    fresh_seed = np.random.randint(100000)
    # Printed so that a particular generation can be identified by its seed.
    print('seed =', fresh_seed)

    np.random.seed(fresh_seed)
    torch.manual_seed(fresh_seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(fresh_seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    """Sanity-check a prompt for the CTRL model and return it unchanged.

    Logs a hint when ``args.temperature`` is above 0.7 and a warning when the
    encoded prompt does not begin with one of ``tokenizer.control_codes``.

    Args:
        args: namespace providing ``temperature``.
        _: unused (the model slot of the shared preprocessing signature).
        tokenizer: object providing ``encode`` and a ``control_codes`` mapping.
        prompt_text: the raw prompt.

    Returns:
        ``prompt_text`` exactly as received.
    """
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    # A prompt that encodes to nothing cannot start with a control code; treat it
    # like any other prompt lacking one instead of failing on encoded_prompt[0].
    starts_with_control_code = (
        len(encoded_prompt) > 0 and encoded_prompt[0] in tokenizer.control_codes.values()
    )
    if not starts_with_control_code:
        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text
|
|
|
|
|
def prepare_xlm_input(args, model, tokenizer, prompt_text):
    """Select the XLM language id on the model config and return the prompt unchanged.

    When the model config has a truthy ``use_lang_emb`` and a ``lang2id``
    mapping, ``model.config.lang_id`` is set to the id of ``args.xlm_language``.
    If that language is not in the mapping, the user is asked on stdin until a
    known language is entered. Otherwise the config is left untouched.
    """
    config = model.config
    uses_language_embeddings = hasattr(config, "use_lang_emb") and config.use_lang_emb
    if not (hasattr(config, "lang2id") and uses_language_embeddings):
        return prompt_text

    known_languages = config.lang2id.keys()
    chosen_language = args.xlm_language
    # Keep asking until the answer is one of the languages the model knows.
    while chosen_language not in known_languages:
        chosen_language = input("Using XLM. Select language in " + str(list(known_languages)) + " >>> ")

    config.lang_id = config.lang2id[chosen_language]
    return prompt_text
|
|
|
|
|
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    """Return the prompt with padding text prepended for XLNet.

    ``args.padding_text`` is used when it is non-empty; otherwise the
    module-level ``PADDING_TEXT`` is used.
    """
    # NOTE(review): PADDING_TEXT is not defined in the visible part of this
    # file - confirm it exists, otherwise an empty --padding_text raises NameError.
    padding = args.padding_text or PADDING_TEXT
    return padding + prompt_text
|
|
|
|
|
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    """Return the prompt with padding text prepended for Transformer-XL.

    ``args.padding_text`` is used when it is non-empty; otherwise the
    module-level ``PADDING_TEXT`` is used.
    """
    # NOTE(review): PADDING_TEXT is not defined in the visible part of this
    # file - confirm it exists, otherwise an empty --padding_text raises NameError.
    padding = args.padding_text or PADDING_TEXT
    return padding + prompt_text
|
|
|
|
|
# Model types whose prompts need extra handling before tokenization; each value
# is called as fn(args, model, tokenizer, prompt_text) and returns the prompt to encode.
# Model types absent from this mapping are encoded as-is.
PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}
|
|
|
|
|
def adjust_length_to_model(length, max_sequence_length):
    """Clamp the requested generation length to what the model supports.

    A negative ``length`` means "no explicit request". When the model reports a
    positive ``max_sequence_length`` it is used both as the default and as the
    upper bound; when it does not, a negative ``length`` falls back to
    ``MAX_LENGTH`` and any other value is returned unchanged.
    """
    if max_sequence_length > 0:
        # The model has a real limit: use it for "unset" and for over-long requests.
        if length < 0 or length > max_sequence_length:
            return max_sequence_length
        return length

    # No usable model limit: only an unset length needs a replacement.
    return MAX_LENGTH if length < 0 else length
|
|
|
|
|
def _build_arg_parser():
    """Build the command-line parser for the generation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default="</s>", help="Token at which lyrics generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--padding_text", type=str, default="", help="Padding lyrics for Transfo-XL and XLNet.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    return parser


def _encode_prompt(args, model, tokenizer, prompt_text):
    """Apply any model-specific preprocessing, then encode the prompt as a tensor on args.device."""
    prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
    if prepare_input is None:
        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
    else:
        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
        encoded_prompt = tokenizer.encode(
            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
        )
    return encoded_prompt.to(args.device)


def _truncate_at_stop_token(text, stop_token):
    """Return text cut just before the first stop_token, or unchanged when there is none."""
    if not stop_token:
        return text
    stop_index = text.find(stop_token)
    # str.find returns -1 when the token is absent; slicing with -1 would
    # silently drop the final character, so leave the text untouched instead.
    return text if stop_index == -1 else text[:stop_index]


def main():
    """Generate text from the command line, either once or interactively.

    Parses the CLI arguments, loads the requested model and tokenizer, and then
    generates continuations. With ``--prompt`` a single round is run. Without
    it, prompts are read from stdin ("Context >>> ") until ``stop`` is entered.

    Returns:
        list: every generated sequence (prompt followed by its continuation)
        produced during this run.

    Raises:
        KeyError: if ``--model_type`` is not one of the keys of ``MODEL_CLASSES``.
    """
    args = _build_arg_parser().parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    args.model_type = args.model_type.lower()
    try:
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        # Fill the placeholder so the user sees which model type was rejected.
        raise KeyError(
            "the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(
                args.model_type
            )
        ) from None

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
    logger.info(args)

    generated_sequences = []
    while True:
        # Reseed before every round, as the original loop did.
        set_seed(args)

        prompt_text = args.prompt
        while not prompt_text:
            prompt_text = input("Context >>> ")
        # "stop" is the interactive exit command; a --prompt of "stop" is still
        # treated as ordinary text to generate from.
        if not args.prompt and prompt_text == "stop":
            break

        encoded_prompt = _encode_prompt(args, model, tokenizer, prompt_text)

        output_sequences = model.generate(
            input_ids=encoded_prompt,
            max_length=args.length + len(encoded_prompt[0]),
            temperature=args.temperature,
            top_k=args.k,
            top_p=args.p,
            repetition_penalty=args.repetition_penalty,
            do_sample=True,
            num_return_sequences=args.num_return_sequences,
        )

        # Drop the extra batch dimension returned when several sequences are requested.
        if len(output_sequences.shape) > 2:
            output_sequences.squeeze_()

        # Length of the decoded (possibly preprocessed) prompt; it is the same for
        # every returned sequence, so compute it once.
        decoded_prompt_length = len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True))

        for generated_sequence in output_sequences:
            print("ruGPT:")
            text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)
            text = _truncate_at_stop_token(text, args.stop_token)

            # Replace the decoded prompt (which may include padding text) with the
            # prompt exactly as the user typed it.
            total_sequence = prompt_text + text[decoded_prompt_length:]

            generated_sequences.append(total_sequence)
            print(total_sequence)

        if args.prompt:
            break

    return generated_sequences
|
|
|
# Static text shown by the Gradio demo.
title = "ruGPT3 Song Writer"
description = "Generate russian songs via fine-tuned ruGPT3"

# Build and start the web demo: one single-line input box, one 20-line output box.
# NOTE(review): `process` is not defined in the visible part of this file -
# confirm it exists, otherwise this statement raises NameError when the module runs.
# NOTE(review): `examples` is passed to the input Textbox and `cache_examples` to
# launch(); verify both keyword placements against the installed Gradio version.
demo = gr.Interface(
    process,
    gr.inputs.Textbox(lines=1, label="Input text", examples="Как дела? Как дела? Это новый кадиллак"),
    gr.outputs.Textbox(lines=20, label="Output text"),
    title=title,
    description=description,
)
demo.launch(enable_queue=True, cache_examples=True)
|
|