# ANLG story-infilling evaluation script (Diffusion-LM):
# generates the missing middle sentence between obs1 and obs2, either with the
# diffusion infiller (MODE='diff') or an autoregressive GPT-2 baseline (MODE='ar').
import json
import os
import sys

import torch
from spacy.lang.en import English
from transformers import AutoModelForCausalLM

from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer

# Which ANLG split to decode; each path is a cleaned-up JSON dump of the split.
SPLIT = 'test'
if SPLIT == 'val':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
elif SPLIT == 'test':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
else:
    # raise instead of `assert False`: asserts are stripped under `python -O`.
    raise ValueError(f"invalid split: {SPLIT!r}")

with open(source_file, 'r') as f:
    # Mapping of example-id -> record; each record has at least 'obs1' and 'obs2'.
    sent_lst = json.load(f)

# spaCy English tokenizer, used only to whitespace-normalize prompts.
nlp = English()
tokenizer = nlp.tokenizer
# 'ar'  -> sample from the fine-tuned GPT-2 baseline below.
# 'diff' -> shell out to the diffusion infilling script per example.
MODE = 'ar'

# Shape of one entry in the source JSON, for reference:
# "00b9adb2-b3b6-4737-902a-50f308bac4b5-1": {
#     "gold_labels": [
#         "I put my baby in the car and drove around.",
#         "I realized he needed his blanket, which I had forgotten at a faraway hotel.",
#         "I took a drive to get my baby to sleep.",
#         "I took my baby for a drive and she fell asleep in the car."
#     ],
#     "obs1": "My baby would not go to sleep last night.",
#     "obs2": "I wound up driving for hours."
# }

print(len(sent_lst))

if MODE == 'ar':
    # Superseded checkpoint, kept for reference:
    # 'predictability/diff_models/roc_e=20_b=32_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill'
    model_name = 'predictability/diff_models/roc_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill_v2'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,  # path to the AR model trained for LMing this task.
    ).cuda()
    # tokenizer2 maps token-id -> word; `vocab` is the inverse (word -> id).
    tokenizer2 = load_tokenizer('roc', 'random',
                                'predictability/diffusion_models_v7/diff_roc_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart')
    vocab = {v: k for k, v in tokenizer2.items()}
    print(len(tokenizer2), len(vocab), 'loaded vocabs')

# NOTE: the 'sample' substring in this name switches the AR branch below into
# 50-way sampling mode (vs. 4-beam search).
outfile = 'ar_sample_full_test_v2.json'
filehandle = open(outfile, 'w')
def _ids_to_sentence(token_ids):
    """Decode model token ids into a plain sentence.

    Looks each id up in `tokenizer2` (id -> word), drops all 'PAD' tokens,
    strips a leading 'START' / trailing 'END' sentinel if present, and joins
    the remainder with single spaces. Guards against an empty decode so the
    sentinel checks cannot raise IndexError.
    """
    words = [tokenizer2[s] for s in token_ids]
    words = [w for w in words if w != 'PAD']
    if words and words[0] == 'START':
        words = words[1:]
    if words and words[-1] == 'END':
        words = words[:-1]
    return " ".join(words)


for idx, (key, val) in enumerate(sent_lst.items()):
    if MODE == 'diff':
        # Build "obs1 PAD*10 obs2" and let the diffusion model infill the PADs.
        partial_seq = f"{val['obs1']} " + "PAD " * 10 + f"{val['obs2']}"
        word_lst = [x.text for x in tokenizer(partial_seq)]
        partial_seq = " ".join(word_lst)
        print(partial_seq, idx)
        COMMAND = "python ../scripts/infill.py " \
                  "--model_path predictability/diffusion_models_v7/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long/ema_0.9999_800000.pt " \
                  " --batch_size 50 " \
                  f"--partial_seq \'{partial_seq}\' " \
                  f"--eval_task_ infill --notes {SPLIT}_{idx} " \
                  f"--out_dir ../anlg_results"
        # NOTE(review): shell-interpolated command; partial_seq comes from the
        # dataset, not user input, but quoting would break on embedded quotes.
        os.system(COMMAND)
        torch.cuda.empty_cache()
    elif MODE == 'ar':
        partial_seq = f"{val['obs1']} " + f"{val['obs2']}"
        print(partial_seq)
        # Encode the prompt: START id followed by word ids (UNK for OOV words).
        word_idx_lst = [vocab['START']] + [vocab.get(x.text, vocab['UNK']) for x in tokenizer(partial_seq)]
        init_prompt = torch.LongTensor(word_idx_lst).cuda().unsqueeze(0)
        print(init_prompt.shape)
        if 'sample' in outfile:
            # Sampling mode: draw 50 continuations of the same prompt.
            print('sampling 50 examples.')
            init_prompt = init_prompt.expand(50, -1)
            sample_out = model.generate(init_prompt, do_sample=True, max_length=64, top_k=len(vocab))
        else:
            # BUGFIX: the generate() keyword is `num_beams` (plural); the
            # original `num_beam=4` was not a valid argument, so beam search
            # was never actually configured.
            sample_out = model.generate(init_prompt, do_sample=False, num_beams=4, max_length=64, top_k=len(vocab))
        print(sample_out.shape)
        # Keep only the generated continuation, dropping the prompt tokens.
        sample_out = sample_out[:, init_prompt.size(1):]
        if 'sample' in outfile:
            out_dict = {'idx': idx,
                        'obs1': val['obs1'],
                        'obs2': val['obs2'],
                        'samples': [_ids_to_sentence(examp.tolist()) for examp in sample_out]}
        else:
            out_dict = {'idx': idx,
                        'obs1': val['obs1'],
                        'obs2': val['obs2'],
                        'sample': _ids_to_sentence(sample_out[0].tolist())}
        # One JSON object per line (jsonl-style output).
        print(json.dumps(out_dict), file=filehandle)

filehandle.close()
print(f'written to {outfile}')