import os, sys, json
import glob
import ast
from functools import partial

sys.path.insert(0, 'e2e-metrics')

import numpy as np
import torch
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from metrics.pymteval import BLEUScore, NISTScore
from nltk.translate.meteor_score import meteor_score
from parse import *
from spacy.lang.en import English
from transformers import BertForMaskedLM, BertTokenizer
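
# Minimum-Bayes-risk (MBR) selection over candidate infills between obs1 and
# obs2 (aNLG-style story infilling): per example, keep the candidate with the
# highest average BLEU against the other candidates, then run a BERT-based
# cleanup pass that fills any leftover 'UNK' tokens in the selected samples.
#
# Usage: python <this script> MODE SPLIT OUT_PATH INPUT_PATH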
MODE = sys.argv[1] # ar or diff
SPLIT = sys.argv[2] # val or test
OUT_PATH = sys.argv[3] # output path.
INPUT_PATH = sys.argv[4] # input path. e.g. diffusion_lm/improved-diffusion/anlg_results/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long.ema_0.9999_800000.pt.infill_infill
def load_results_simple(path):
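    # Load the full {index: result_dict} mapping from a single JSON file.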
    with open(path, 'r') as f:
        full_result_dict = json.load(f)
    return full_result_dict

def post_process(filename, fileout, tokenizer_spacy):
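    # Re-read the MBR output file and, for every selected sample that still
    # contains an 'UNK' placeholder, use BERT masked-LM predictions to fill
    # the gap, then write the cleaned examples to `fileout`.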
    # filename = 'diffusion_lm/improved-diffusion/anlg_results/diff_roc_mbr.json2'
    bert_model = 'bert-base-cased'
    tokenizer = BertTokenizer.from_pretrained(bert_model)
    model = BertForMaskedLM.from_pretrained(bert_model).cuda()
    fileout_handle = open(fileout, 'w')

    full_lst = []
    with open(filename, 'r') as f:
        for line in f:
            line = json.loads(line)
            full_lst.append(line)
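
    # Wrap each remaining 'UNK' between obs1 and obs2, swap it for BERT's
    # mask token, and take the argmax prediction at every masked position.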
    for example in full_lst:
        sent = example['sample']
        obs1 = example['obs1']
        obs2 = example['obs2']
        if 'UNK' in sent:
            sent = obs1 + sent.replace('UNK', tokenizer.mask_token) + obs2
            print(sent)
            model_inputs = tokenizer(sent, return_tensors="pt")
            model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
            model_out = model(**model_inputs)
            mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
            masked_logits = model_out.logits[mask_words].view(-1, model_out.logits.size(-1))
            # take argmax from this.
            max_cands = torch.max(masked_logits, dim=-1)
            indices = max_cands.indices
            model_inputs['input_ids'][mask_words] = indices
            out = tokenizer.batch_decode(model_inputs['input_ids'].tolist(),
                                         skip_special_tokens=True)[0]
            print(out)
            word_lstout = [x.text for x in tokenizer_spacy(out)]
            word_lst1 = [x.text for x in tokenizer_spacy(example['obs1'])]
            word_lst2 = [x.text for x in tokenizer_spacy(example['obs2'])]
            example['sample'] = " ".join(word_lstout[len(word_lst1):-len(word_lst2)])
            print(example['sample'])
            print()
        else:
            print('NO NEED THIS FIX. ')
        print(json.dumps(example), file=fileout_handle)
    fileout_handle.close()

def load_results(sent_lst, tokenizer):
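    # Gather per-example decoded files (one glob pattern per index) into a
    # single {index: {"pred_samples", "sample", "obs1", "obs2"}} dict and
    # cache it to full_diff_test_outputs_aug.json. Kept for the glob-based
    # workflow; the active 'diff' path below uses load_results_simple instead.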
    # target_file = f"{INPUT_PATH}_*.json"
    # target_file = glob.glob(target_file)
    # print([x for x in target_file if 'val' not in x and 'test' not in x])
    # 10/0
    full_result_dict = {}
    failed_instances = []
    found_idx = []
    sent_lst_lst = list(sent_lst.items())
    for idx, (key, val) in enumerate(sent_lst_lst):
        # if idx < 2500: continue
        if idx in full_result_dict.keys(): continue
        word_lst1 = [x.text for x in tokenizer(val['obs1'])]
        word_lst2 = [x.text for x in tokenizer(val['obs2'])]
        # target_file = f"diffusion_lm/improved-diffusion/anlg_results/diff_roc_pad_rand128_" \
        #               f"transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long.ema" \
        #               f"_0.9999_800000.pt.infill_infill_*_{SPLIT}_{idx}.json"
        target_file = f"{INPUT_PATH}_*_{SPLIT}_{idx}.json"
        file_lst = glob.glob(target_file)
        # print(file_lst, target_file)
        try:
            assert len(file_lst) == 1
        except:
            print('the file must have existed in a batched version')
            # if SPLIT == 'val': assert False
            # if idx % 100 == 1: idx = idx-1
            target_file = f"{INPUT_PATH}_*_{idx}.json"
            file_lst = glob.glob(target_file)
            print(file_lst, target_file)
        print(file_lst)
        target_file = file_lst[0]
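
        # Two on-disk formats are handled: "x128" files hold one JSON list per
        # line (a single sample each), while the batched fallback format holds
        # a literal-eval'able dict keyed by (index, template) whose value is a
        # list of candidate sequences.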
if "x128" in target_file:
infill_lst = []
with open(target_file, 'r') as f:
for line in f:
example = json.loads(line)[0]
infill_ = example.split()[len(word_lst1):-len(word_lst2)]
# print(len(infill_))
# print(infill_, example)
# assert len(infill_) == 10
infill_=' '.join(infill_)
# print(infill_)
infill_lst.append(infill_)
result_dict = {
"pred_samples": infill_lst,
"sample": None,
"obs1": val['obs1'],
"obs2": val['obs2']
}
full_result_dict[idx] = result_dict
else:
with open(target_file, 'r') as f:
for line in f:
example = ast.literal_eval(line.strip())
index, template = list(example.keys())[0]
print(index, idx)
if int(index) < int(idx):
continue
assert int(index) == int(idx)
found_idx.append(idx)
example = list(example.values())[0]
kk, val = sent_lst_lst[idx]
word_lst1 = [x.text for x in tokenizer(val['obs1'])]
word_lst2 = [x.text for x in tokenizer(val['obs2'])]
infill_lst = [" ".join(xx.split()[len(word_lst1):-len(word_lst2)]) for xx in example]
result_dict = {
"pred_samples": infill_lst,
"sample": None,
"obs1": val['obs1'],
"obs2": val['obs2']
}
full_result_dict[idx] = result_dict
idx += 1
with open('full_diff_test_outputs_aug.json', 'w') as f:
json.dump(full_result_dict, f)
return full_result_dict
# read files.
def mbr(result_lst, total_len, sample_size, utility):
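    # Minimum-Bayes-risk selection: for each candidate y in a sample set,
    # average the utility (e.g. BLEU) of y against every other candidate and
    # keep the candidate with the highest mean utility.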
    result = []
    for i in range(total_len):
        example_set = result_lst[i * sample_size:(i + 1) * sample_size]
        # print(example_set)
        score_dict = {}
        for idx in range(len(example_set)):
            y = example_set[idx]
            utility_lst = []
            for idx_x in range(len(example_set)):
                if idx_x != idx:
                    utility_lst.append(utility(example_set[idx_x], y))
            score_dict[idx] = np.array(utility_lst).mean()
        # print(score_dict)
        best_y = sorted(score_dict.items(), key=lambda item: item[1])[-1]
        result.append(example_set[best_y[0]])
        # print(best_y)
    return result

def bleu_score(scorer, sent_sys, sents_ref):
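    # Sentence-level BLEU via pymteval: score one system output against one reference.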
    scorer.reset()
    scorer.append(sent_sys, [sents_ref])
    return scorer.score()

def meteor_score2(pred, ref):
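    # Sentence-level METEOR between a single prediction and a single reference.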
    meteor = meteor_score([ref.split()], pred.split())
    return meteor

def apply_mbr_func(full_result_dict, outpath, sent_lst):
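    # For every example, run MBR (smoothed BLEU as the utility) over its
    # candidate infills and write the winning sample plus the full candidate
    # list to `outpath`, one JSON object per line.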
    assert len(sent_lst) == len(full_result_dict)
    out_handle = open(outpath, 'w')
    count = 0
    for idx, val in full_result_dict.items():
        infill_lst = val['pred_samples']
        print(count, idx)
        assert count == int(idx)
        count += 1
        sample_size = len(infill_lst)
        total_len = 1
        mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
        result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
        print(infill_lst)
        print(result_lst)
        result_str = result_lst[0]
        result_dict = {
            "pred_samples": infill_lst,
            "sample": result_str,
            "obs1": val['obs1'],
            "obs2": val['obs2']
        }
        print(json.dumps(result_dict), file=out_handle)
    out_handle.close()
    print(f'written to {outpath}')
    return

if SPLIT == 'val':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
elif SPLIT == 'test':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
else:
    assert False, "invalid split"

with open(source_file, 'r') as f:
    sent_lst = json.load(f)

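# 'diff' mode post-processes diffusion-model infills (MBR selection followed
# by the BERT cleanup pass); 'ar' mode does the same for autoregressive
# samples read from a JSON-lines file.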
if MODE == 'diff':
    nlp = English()
    tokenizer = nlp.tokenizer
    # load_results(sent_lst, tokenizer)
    # 10/0
    decoded_dict = load_results_simple(INPUT_PATH)
    ############
    # small_decoded_dict = {}
    # for i in range(10):
    #     small_decoded_dict[i] = decoded_dict[str(i)]
    # decoded_dict = small_decoded_dict
    # small_sent_lst = {}
    # for k, v in sent_lst.items():
    #     if len(small_sent_lst) > 9: break
    #     small_sent_lst[k] = v
    # sent_lst = small_sent_lst
    ############
    outpath = OUT_PATH
    apply_mbr_func(decoded_dict, outpath, sent_lst)
    post_process(outpath, outpath + '.clean.json', tokenizer)
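
# The commented-out block below is an earlier per-example variant of the same
# MBR pipeline (reading one glob-matched file per index); it is kept here for
# reference only.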
#
# # load_results(sent_lst, tokenizer)
# # 10/0
# print(len(sent_lst))
# for idx, (key, val) in enumerate(sent_lst.items()):
# # if idx < 518: continue
# if idx > 900:
# break
# # change the matching method.
# word_lst1 = [x.text for x in tokenizer(val['obs1'])]
# word_lst2 = [x.text for x in tokenizer(val['obs2'])]
# # partial_seq = f"{val['obs1']} " + "PAD " + f"{val['obs2']}"
# # word_lst = [x.text for x in tokenizer(partial_seq)]
# # partial_seq = " ".join(word_lst)
# # partial_seq = partial_seq.replace('PAD', '{}')
# # print(partial_seq, idx)
#
# # target_file = f"diffusion_lm/improved-diffusion/anlg_results/diff_roc_pad_rand128_" \
# # f"transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long.ema" \
# # f"_0.9999_800000.pt.infill_infill_*_{SPLIT}_{idx}.json"
# target_file = f"{INPUT_PATH}_*_{SPLIT}_{idx}.json"
#
# file_lst = glob.glob(target_file)
# print(file_lst, target_file)
# assert len(file_lst) == 1
# target_file = file_lst[0]
# # print(target_file)
# infill_lst = []
# with open(target_file, 'r') as f:
# for line in f:
# example = json.loads(line)[0]
# # print(example, partial_seq)
# # infill_ = parse(partial_seq, example)
# # print(example)
# infill_ = example.split()[len(word_lst1):-len(word_lst2)]
# # print(len(infill_))
# # print(infill_, example)
# # assert len(infill_) == 10
# infill_=' '.join(infill_)
# # print(infill_)
# infill_lst.append(infill_)
# infill_lst = infill_lst
# sample_size = len(infill_lst)
# total_len = 1
# mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
# result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
# print(infill_lst)
# print(result_lst)
# result_str = result_lst[0]
# result_dict = {
# "pred_samples": infill_lst,
# "sample":result_str,
# "obs1": val['obs1'],
# "obs2": val['obs2']
# }
# print(json.dumps(result_dict), file=out_handle)
#
# out_handle.close()
# print(f'written to {outpath}')
elif MODE == 'ar':
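    # MBR over autoregressive samples: INPUT_PATH is a JSON-lines file where
    # each line holds a dict with a 'samples' list for the corresponding
    # example in sent_lst.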
    outpath = OUT_PATH  # 'diffusion_lm/improved-diffusion/anlg_results/ar_full_mbr.json'
    out_handle = open(outpath, 'w')
    sample_file = INPUT_PATH  # 'diffusion_lm/improved-diffusion/anlg_results/ar_sample_500_v2.json'
    nlp = English()
    tokenizer = nlp.tokenizer
    print(len(sent_lst))
    sample_lst = []
    with open(sample_file, 'r') as f:
        for line in f:
            sample_dict = json.loads(line)
            sample_lst.append(sample_dict)
    for idx, (key, val) in enumerate(sent_lst.items()):
        # if idx < 109: continue
        # if idx > 499:
        #     break
        infill_lst = sample_lst[idx]['samples']
        sample_size = len(infill_lst)
        total_len = 1
        mteval_scorers = [BLEUScore(), BLEUScore(smoothing=1.0), NISTScore()]
        result_lst = mbr(infill_lst, total_len, sample_size, partial(bleu_score, mteval_scorers[1]))
        print(infill_lst)
        print(result_lst)
        result_str = result_lst[0]
        result_dict = {
            "pred_samples": infill_lst,
            "sample": result_str,
            "obs1": val['obs1'],
            "obs2": val['obs2']
        }
        print(json.dumps(result_dict), file=out_handle)
    out_handle.close()
    print(f'written to {outpath}')
    post_process(outpath, outpath + '.clean.json', tokenizer)
# print(file+'.clean')
# with open(file+'.clean', 'w') as f:
# for line in result_lst:
# print(line, file=f)