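"""MBR decoding and cleanup for aNLG infilling outputs.

For each example, selects the candidate infill with the highest average
BLEU against the other candidates (minimum Bayes risk), writes the results
as JSON lines, and then fills any leftover UNK tokens with a BERT masked LM.
"""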
import os, sys, json
import glob
import ast
from functools import partial

sys.path.insert(0, 'e2e-metrics')  # makes metrics.pymteval importable

import numpy as np
import torch
from metrics.pymteval import BLEUScore, NISTScore
from nltk.translate.meteor_score import meteor_score
from spacy.lang.en import English
from transformers import BertForMaskedLM, BertTokenizer
MODE = sys.argv[1]        # 'ar' or 'diff'
SPLIT = sys.argv[2]       # 'val' or 'test'
OUT_PATH = sys.argv[3]    # output path
INPUT_PATH = sys.argv[4]  # input path, e.g. diffusion_lm/improved-diffusion/anlg_results/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long.ema_0.9999_800000.pt.infill_infill
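# Example invocation (script and file names below are placeholders, not from the repo):
#   python this_script.py diff test anlg_results/diff_mbr.json anlg_results/diff_decoded.json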
def load_results_simple(path):
    with open(path, 'r') as f:
        full_result_dict = json.load(f)
    return full_result_dict
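# post_process: any sample that still contains the literal token 'UNK' is
# re-embedded in its context (obs1 + sample + obs2), the UNKs are replaced
# with [MASK], and BERT's argmax predictions fill them in before the infill
# is re-extracted with the spaCy tokenizer.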
def post_process(filename, fileout, tokenizer_spacy):
    bert_model = 'bert-base-cased'
    tokenizer = BertTokenizer.from_pretrained(bert_model)
    model = BertForMaskedLM.from_pretrained(bert_model).cuda()
    fileout_handle = open(fileout, 'w')
    full_lst = []
    with open(filename, 'r') as f:
        for line in f:
            full_lst.append(json.loads(line))
    for example in full_lst:
        sent = example['sample']
        obs1 = example['obs1']
        obs2 = example['obs2']
        if 'UNK' in sent:
            sent = obs1 + sent.replace('UNK', tokenizer.mask_token) + obs2
            print(sent)
            model_inputs = tokenizer(sent, return_tensors="pt")
            model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
            with torch.no_grad():
                model_out = model(**model_inputs)
            mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
            masked_logits = model_out.logits[mask_words].view(-1, model_out.logits.size(-1))
            # Fill each masked position with its argmax token.
            indices = torch.max(masked_logits, dim=-1).indices
            model_inputs['input_ids'][mask_words] = indices
            out = tokenizer.batch_decode(model_inputs['input_ids'].tolist(),
                                         skip_special_tokens=True)[0]
            print(out)
            # Strip the observed context back off, keeping only the infill.
            word_lstout = [x.text for x in tokenizer_spacy(out)]
            word_lst1 = [x.text for x in tokenizer_spacy(example['obs1'])]
            word_lst2 = [x.text for x in tokenizer_spacy(example['obs2'])]
            example['sample'] = " ".join(word_lstout[len(word_lst1):-len(word_lst2)])
            print(example['sample'])
            print()
        else:
            print('No UNK token; nothing to fix.')
        print(json.dumps(example), file=fileout_handle)
    fileout_handle.close()
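# load_results: gather the per-example (or batched) sample files written by
# the infilling run, strip the observed context from each sample, and cache
# everything in one dict keyed by example index.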
def load_results(sent_lst, tokenizer):
    full_result_dict = {}
    failed_instances = []
    found_idx = []
    sent_lst_lst = list(sent_lst.items())
    for idx, (key, val) in enumerate(sent_lst_lst):
        if idx in full_result_dict.keys():
            continue
        word_lst1 = [x.text for x in tokenizer(val['obs1'])]
        word_lst2 = [x.text for x in tokenizer(val['obs2'])]
        target_file = f"{INPUT_PATH}_*_{SPLIT}_{idx}.json"
        file_lst = glob.glob(target_file)
        if len(file_lst) != 1:
            # No unique per-example file: fall back to the batched naming scheme.
            print('no per-example file found; assuming a batched version exists')
            target_file = f"{INPUT_PATH}_*_{idx}.json"
            file_lst = glob.glob(target_file)
        print(file_lst, target_file)
        target_file = file_lst[0]
        if "x128" in target_file:
            # Per-example file: each line is a JSON list with one generated sample.
            infill_lst = []
            with open(target_file, 'r') as f:
                for line in f:
                    example = json.loads(line)[0]
                    infill_ = example.split()[len(word_lst1):-len(word_lst2)]
                    infill_lst.append(' '.join(infill_))
            full_result_dict[idx] = {
                "pred_samples": infill_lst,
                "sample": None,
                "obs1": val['obs1'],
                "obs2": val['obs2'],
            }
        else:
            # Batched file: each line is a dict keyed by (index, template).
            with open(target_file, 'r') as f:
                for line in f:
                    example = ast.literal_eval(line.strip())
                    index, template = list(example.keys())[0]
                    print(index, idx)
                    if int(index) < int(idx):
                        continue
                    assert int(index) == int(idx)
                    found_idx.append(idx)
                    example = list(example.values())[0]
                    kk, val = sent_lst_lst[idx]
                    word_lst1 = [x.text for x in tokenizer(val['obs1'])]
                    word_lst2 = [x.text for x in tokenizer(val['obs2'])]
                    infill_lst = [" ".join(xx.split()[len(word_lst1):-len(word_lst2)]) for xx in example]
                    full_result_dict[idx] = {
                        "pred_samples": infill_lst,
                        "sample": None,
                        "obs1": val['obs1'],
                        "obs2": val['obs2'],
                    }
                    idx += 1  # the batched file covers several consecutive indices
    with open('full_diff_test_outputs_aug.json', 'w') as f:
        json.dump(full_result_dict, f)
    return full_result_dict
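# mbr: minimum-Bayes-risk selection. For each consecutive block of
# `sample_size` candidates, score each candidate y by its mean utility
# against the other candidates in the block and keep the argmax.
# Toy example (hypothetical inputs): with candidates ['a b', 'a b c', 'x']
# and a similarity utility, mbr(cands, total_len=1, sample_size=3, utility)
# returns the candidate closest on average to the other two.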
def mbr(result_lst, total_len, sample_size, utility):
    result = []
    for i in range(total_len):
        example_set = result_lst[i * sample_size:(i + 1) * sample_size]
        score_dict = {}
        for idx in range(len(example_set)):
            y = example_set[idx]
            # Expected utility of y against all other candidates in the set.
            utility_lst = []
            for idx_x in range(len(example_set)):
                if idx_x != idx:
                    utility_lst.append(utility(example_set[idx_x], y))
            score_dict[idx] = np.array(utility_lst).mean()
        # Keep the candidate with the highest average utility.
        best_y = sorted(score_dict.items(), key=lambda item: item[1])[-1]
        result.append(example_set[best_y[0]])
    return result
def bleu_score(scorer, sent_sys, sents_ref):
    scorer.reset()
    scorer.append(sent_sys, [sents_ref])
    return scorer.score()

def meteor_score2(pred, ref):
    return meteor_score([ref.split()], pred.split())
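# apply_mbr_func: run MBR (smoothed-BLEU utility) over every example's
# candidate set and write one JSON line per example with the chosen sample.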
def apply_mbr_func(full_result_dict, outpath, sent_lst):
    assert len(sent_lst) == len(full_result_dict)
    out_handle = open(outpath, 'w')
    count = 0
    for idx, val in full_result_dict.items():
        infill_lst = val['pred_samples']
        print(count, idx)
        assert count == int(idx)
        count += 1
        sample_size = len(infill_lst)
        total_len = 1
        mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
        # Smoothed BLEU as the MBR utility.
        result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
        print(infill_lst)
        print(result_lst)
        result_str = result_lst[0]
        result_dict = {
            "pred_samples": infill_lst,
            "sample": result_str,
            "obs1": val['obs1'],
            "obs2": val['obs2'],
        }
        print(json.dumps(result_dict), file=out_handle)
    out_handle.close()
    print(f'written to {outpath}')
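# Driver: choose the reference file for the requested split, then dispatch
# on MODE ('diff' for diffusion outputs, 'ar' for autoregressive samples).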
if SPLIT == 'val':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
elif SPLIT == 'test':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
else:
    assert False, "invalid split"

with open(source_file, 'r') as f:
    sent_lst = json.load(f)
if MODE == 'diff':
    nlp = English()
    tokenizer = nlp.tokenizer
    decoded_dict = load_results_simple(INPUT_PATH)
    outpath = OUT_PATH
    apply_mbr_func(decoded_dict, outpath, sent_lst)
    post_process(outpath, outpath + '.clean.json', tokenizer)
elif MODE == 'ar':
    outpath = OUT_PATH
    out_handle = open(outpath, 'w')
    sample_file = INPUT_PATH
    nlp = English()
    tokenizer = nlp.tokenizer
    print(len(sent_lst))
    sample_lst = []
    with open(sample_file, 'r') as f:
        for line in f:
            sample_lst.append(json.loads(line))
    for idx, (key, val) in enumerate(sent_lst.items()):
        infill_lst = sample_lst[idx]['samples']
        sample_size = len(infill_lst)
        total_len = 1
        mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
        result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
        print(infill_lst)
        print(result_lst)
        result_str = result_lst[0]
        result_dict = {
            "pred_samples": infill_lst,
            "sample": result_str,
            "obs1": val['obs1'],
            "obs2": val['obs2'],
        }
        print(json.dumps(result_dict), file=out_handle)
    out_handle.close()
    print(f'written to {outpath}')
    post_process(outpath, outpath + '.clean.json', tokenizer)
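# Both modes write OUT_PATH (one JSON object per line with 'pred_samples',
# 'sample', 'obs1', 'obs2') and OUT_PATH + '.clean.json' after UNK cleanup.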