import os import random import time import pickle import math from argparse import ArgumentParser from collections import defaultdict import string import csv from tqdm import tqdm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model from data import Dataset from model import Model from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask from predict_topic import predict from constants import * def main(args): with open(args.dataset_info, 'rb') as rf: dataset_info = pickle.load(rf) gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string) gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN}) gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0] gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device) gpt_model.eval() checkpoint = torch.load(args.ckpt, map_location=args.device) model_args = checkpoint['args'] conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway conditioning_model.load_state_dict(checkpoint['state_dict']) conditioning_model = conditioning_model.to(args.device) conditioning_model.eval() if args.verbose: print("=> loaded checkpoint '{}' (epoch {})" .format(args.ckpt, checkpoint['epoch'])) print('num params', num_params(conditioning_model)) input_texts, conditions, categories = [], [], [] if args.condition_file is not None: with open(args.condition_file, 'r') as rf: for line in rf: input_texts.append(line.strip().split('\t')[0]) conditions.append(line.strip().split('\t')[1]) categories.append(None) for cw in conditions[-1].split(): assert cw in dataset_info.word2index else: prefixes = [] with open(args.prefix_file, 'r') as rf: for line in rf: prefixes.append(line.strip()) condition_wordlists = [] for root, _, files in os.walk(args.wordlist_dir): for fname in files: words = [] with open(os.path.join(root, fname), 'r') as rf: for line in rf: word = line.strip() if word in dataset_info.word2index: words.append(word) else: if args.verbose: print('word not found:', word) condition_wordlists.append((' '.join(words), fname.split('.')[0])) for p in prefixes: for c, category in condition_wordlists: input_texts.append(p) conditions.append(c) categories.append(category) all_cr = [] pair_num = 0 for input_text, condition_words, category in tqdm(zip(input_texts, conditions, categories), total=len(conditions)): predict_function = predict condition_results = [] for i in range(0, args.sample_size, args.max_sample_batch): num_samples = min(args.max_sample_batch, args.sample_size - i) condition_results += predict_function(gpt_model, gpt_tokenizer, conditioning_model, [input_text for _ in range(num_samples)], condition_words, dataset_info, args.precondition_topk, args.topk, args.length_cutoff, condition_lambda=args.condition_lambda, device=args.device) all_cr.append((input_text, category, condition_results)) pair_num += 1 if args.max_pairs > 0 and pair_num >= args.max_pairs: break with open(args.log_file, 'w') as wf: writer = csv.DictWriter(wf, fieldnames=['category', 'input_text', 'generation']) writer.writeheader() for cr_group in all_cr: for cr in cr_group[2]: writer.writerow({'category': cr_group[1], 'input_text': cr_group[0], 'generation': cr}) if __name__=='__main__': parser = ArgumentParser() # DATA parser.add_argument('--ckpt', type=str, required=True) parser.add_argument('--log_file', type=str, required=True, help='file to write outputs to (csv format)') parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info') parser.add_argument('--model_string', type=str, default='gpt2-medium') parser.add_argument('--condition_file', type=str, default=None, help='file of inputs and conditions') parser.add_argument('--prefix_file', type=str, default=None, help='prefix set') parser.add_argument('--wordlist_dir', type=str, default=None, help='dir of bow wordlists for categories') parser.add_argument('--sample_size', type=int, default=3, help='samples per input text-condition pair') parser.add_argument('--max_sample_batch', type=int, default=3, help='max samples at a time') parser.add_argument('--max_pairs', type=int, default=-1, help='max input-condition pairs, for debugging quickly') parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning') parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step') parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model') parser.add_argument('--length_cutoff', type=int, default=80, help='max length') parser.add_argument('--seed', type=int, default=1, help='random seed') parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda']) parser.add_argument('--debug', action='store_true', default=False) parser.add_argument('--verbose', action='store_true', default=False) args = parser.parse_args() assert (args.condition_file is not None) != (args.prefix_file is not None and args.wordlist_dir is not None) # one of two interfaces for specifying random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) main(args)