import json
import logging
import os
import sys
import time
import torch
import warnings

import pandas as pd

from collections import defaultdict
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer,
    AutoConfig,
)

from tevatron.arguments import (
    GLENP2ModelArguments as ModelArguments,
    GLENP2DataArguments as DataArguments,
    GLENP2TrainingArguments as TrainingArguments,
)
from tevatron.datasets import GLENP2EncodeDataset
from tevatron.metrics import compute_recall, compute_mrr, evaluate_beir
from tevatron.modeling import GLENP2Model
from tevatron.tree import TreeBuilder

logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings(action="ignore")


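# Query-side inference for GLEN P2: load a trained checkpoint, decode document
# identifiers for each query, and report retrieval metrics.
#
# A hypothetical invocation (flag names are generated by HfArgumentParser from
# the GLENP2*Arguments dataclasses, so the exact set depends on those
# definitions):
#   python <this_script>.py args.json
#   python <this_script>.py --infer_dir <run_dir> --dataset_name nq320k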
def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

    # Parse from a single JSON file when one is given, otherwise from the CLI.
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    # This script runs single-process, single-GPU inference only.
    if training_args.local_rank > 0 or training_args.n_gpu > 1:
        raise NotImplementedError("Multi-GPU is not supported.")

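    # If the run directory contains saved model arguments, reload them so that
    # inference uses the same architecture settings as training, while keeping
    # the inference-specific values passed on the command line.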
    if os.path.exists(os.path.join(model_args.infer_dir, "model_args.json")):
        print(
            f"> Load model arguments from {os.path.join(model_args.infer_dir, 'model_args.json')}"
        )

        # Preserve CLI overrides before rebuilding model_args from the file.
        cli_infer_dir = model_args.infer_dir
        cli_infer_ckpt = model_args.infer_ckpt
        cli_model_name_or_path = model_args.model_name_or_path
        cli_logs_dir = model_args.logs_dir
        cli_docid_file_name = model_args.docid_file_name

        with open(os.path.join(model_args.infer_dir, "model_args.json"), "r") as f:
            model_args_dict = json.load(f)

        # Drop any saved keys this version of ModelArguments no longer accepts.
        import inspect

        model_args_signature = inspect.signature(ModelArguments.__init__)
        valid_args = set(model_args_signature.parameters.keys()) - {"self"}
        filtered_args = {k: v for k, v in model_args_dict.items() if k in valid_args}

        model_args = ModelArguments(**filtered_args)

        # Re-apply the CLI overrides on top of the reloaded arguments.
        model_args.infer_dir = cli_infer_dir
        model_args.infer_ckpt = cli_infer_ckpt
        model_args.model_name_or_path = cli_model_name_or_path
        model_args.logs_dir = cli_logs_dir
        if cli_docid_file_name:
            model_args.docid_file_name = cli_docid_file_name
    else:
        print(f"> No saved model arguments found in {model_args.infer_dir}")

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    set_seed(training_args.seed)

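    # t5-large is bigger than the t5-base defaults, so set its structural
    # hyperparameters explicitly.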
    if model_args.model_name_or_path == "t5-large":
        model_args.num_layers = 24
        model_args.num_decoder_layers = 24
        model_args.d_ff = 4096
        model_args.d_model = 1024
        model_args.num_heads = 16
        model_args.d_kv = 64

    if not hasattr(model_args, "max_output_length"):
        model_args.max_output_length = model_args.num_multi_vectors + 1

    data_args.max_output_length = model_args.max_output_length

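    # A local run directory cannot be used to fetch tokenizer/config files from
    # the hub, so fall back to the t5-base identifier in that case.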
    base_model_name = model_args.model_name_or_path
    if os.path.isdir(model_args.model_name_or_path):
        base_model_name = "t5-base"
        print(f"> Using base model '{base_model_name}' for model loading")

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else base_model_name,
        cache_dir=model_args.cache_dir,
        use_fast=True,
    )
    # The stock T5 tokenizer reports 32100 tokens, but T5's embedding matrix is
    # padded to 32128; keep the decoding vocabulary in sync with the embeddings.
    decode_vocab_size = 32128 if len(tokenizer) == 32100 else len(tokenizer)

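    # Build the T5 config, extended with non-standard fields used by the GLEN
    # model (decode_vocab_size, tie_decode_embeddings).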
    if model_args.config_name:
        config_path = model_args.config_name
    else:
        config_path = base_model_name
        print(f"> Using config from base model: {config_path}")

    config = AutoConfig.from_pretrained(
        config_path,
        num_layers=model_args.num_layers,
        num_decoder_layers=model_args.num_decoder_layers,
        d_ff=model_args.d_ff,
        d_model=model_args.d_model,
        num_heads=model_args.num_heads,
        decoder_start_token_id=0,
        output_past=True,
        d_kv=model_args.d_kv,
        dropout_rate=model_args.dropout_rate,
        decode_vocab_size=decode_vocab_size,
        tie_word_embeddings=model_args.tie_word_embeddings,
        tie_decode_embeddings=model_args.tie_decode_embeddings,
        num_labels=1,
        cache_dir=model_args.cache_dir,
    )

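    # GLENP2Model.load resolves pretrained weights by model name, so point it
    # at the hub identifier while loading and restore the original afterwards.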
    original_model_path = model_args.model_name_or_path
    model_args.model_name_or_path = base_model_name

    model = GLENP2Model.load(
        model_args=model_args,
        tokenizer=tokenizer,
        config=config,
        cache_dir=model_args.cache_dir,
    )

    model_args.model_name_or_path = original_model_path

    if not os.path.exists(model_args.logs_dir):
        os.makedirs(model_args.logs_dir)

    infer_ckpt_info = "_".join(model_args.infer_dir.split("/"))
    training_args.res1_save_path = os.path.join(
        model_args.logs_dir,
        f'{time.strftime("%Y%m%d-%H%M%S")}_res1_recall{model_args.num_return_sequences}_{data_args.dataset_name}_{str(model.__class__.__name__)}_{infer_ckpt_info}.tsv',
    )

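    # Resolve which checkpoint to load: an explicit infer_ckpt wins, then
    # weights at the root of infer_dir, then the latest checkpoint-* directory.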
    if model_args.infer_ckpt:
        ckpt_path = model_args.infer_ckpt
    else:
        root_model_bin = os.path.join(model_args.infer_dir, "pytorch_model.bin")
        root_model_safetensors = os.path.join(model_args.infer_dir, "model.safetensors")

        if os.path.exists(root_model_bin):
            ckpt_path = root_model_bin
        elif os.path.exists(root_model_safetensors):
            ckpt_path = root_model_safetensors
        else:
            checkpoint_dirs = [
                d
                for d in os.listdir(model_args.infer_dir)
                if d.startswith("checkpoint-")
                and os.path.isdir(os.path.join(model_args.infer_dir, d))
            ]
            if checkpoint_dirs:
                # Sort by global step so the last entry is the latest checkpoint.
                checkpoint_dirs.sort(key=lambda x: int(x.split("-")[1]))
                latest_checkpoint = checkpoint_dirs[-1]

                safetensors_path = os.path.join(
                    model_args.infer_dir, latest_checkpoint, "model.safetensors"
                )
                bin_path = os.path.join(
                    model_args.infer_dir, latest_checkpoint, "pytorch_model.bin"
                )

                if os.path.exists(safetensors_path):
                    ckpt_path = safetensors_path
                elif os.path.exists(bin_path):
                    ckpt_path = bin_path
                else:
                    raise FileNotFoundError(
                        f"No model checkpoint found in {model_args.infer_dir}"
                    )

                print(f"> Using latest checkpoint: {latest_checkpoint}")
            else:
                raise FileNotFoundError(
                    f"No model checkpoint found in {model_args.infer_dir}"
                )

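    # Load the weights (safetensors or a pickled state dict, possibly nested
    # under a "state_dict" key) into the query/passage towers.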
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        state_dict = load_file(ckpt_path, device="cpu")
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

    if model_args.untie_encoder:
        model.lm_q.load_state_dict(state_dict, strict=False)
        model.lm_p.load_state_dict(state_dict, strict=False)
    else:
        # With a tied encoder, load once and share the module.
        model.lm_q.load_state_dict(state_dict, strict=False)
        model.lm_p = model.lm_q

    print(f"> Restored parameters from checkpoint {ckpt_path}")

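    # Re-tie each tower's LM head to its shared embedding table so decoding
    # uses the same weights as the input embeddings.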
if "lm_head.weight" in model.lm_p.state_dict() and model_args.untie_encoder: |
|
state_dict = model.lm_p.state_dict() |
|
model.lm_p.shared.weight.data.copy_(state_dict["shared.weight"]) |
|
model.lm_p.lm_head.weight.data.copy_(model.lm_p.shared.weight.data) |
|
state_dict = model.lm_q.state_dict() |
|
model.lm_q.shared.weight.data.copy_(state_dict["shared.weight"]) |
|
model.lm_q.lm_head.weight.data.copy_(model.lm_q.shared.weight.data) |
|
elif "lm_head.weight" in model.lm_p.state_dict() and not model_args.untie_encoder: |
|
state_dict = model.lm_p.state_dict() |
|
model.lm_p.shared.weight.data.copy_(state_dict["shared.weight"]) |
|
model.lm_p.lm_head.weight.data.copy_(model.lm_p.shared.weight.data) |
|
model.lm_q = model.lm_p |
|
|
|
del state_dict |
|
|
|
|
|
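    # Build the query encoding dataset for the supported benchmarks.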
    if data_args.dataset_name in ["nq320k", "marco_passage", "nfcorpus", "arguana", "the_vault"]:
        encode_dataset = GLENP2EncodeDataset(
            data_args=data_args,
            tokenizer=tokenizer,
            max_len=data_args.max_input_length,
            task="infer_qry",
        )
    else:
        raise NotImplementedError(f"{data_args.dataset_name} is not supported")

    encode_loader = DataLoader(
        encode_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        shuffle=False,
        drop_last=False,
    )
    model = model.to(training_args.device)
    model.eval()

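    # Optionally keep special tokens other than BOS/EOS/PAD out of decoding.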
    model.tokenizer = tokenizer
    if model_args.mask_special_tokens_for_decoding:
        special_token_ids = tokenizer.all_special_ids
        model_args.special_token_ids = [
            x
            for x in special_token_ids
            if x
            not in [
                tokenizer.bos_token_id,
                tokenizer.eos_token_id,
                tokenizer.pad_token_id,
            ]
        ]
    max_output_length = data_args.max_output_length

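    # The docid file sits in the parent of infer_dir and maps each original
    # document id ("oldid") to its generated identifier ("docid", tokens joined
    # by "<->") and the corresponding per-token logits.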
    docid_file_name = (
        "/".join(model_args.infer_dir.split("/")[:-1])
        + "/"
        + model_args.docid_file_name
        + ".tsv"
    )
    docid_df = pd.read_csv(
        docid_file_name,
        sep="\t",
        names=["oldid", "docid", "docid_logit", "text"],
        dtype={"oldid": str, "docid": str, "docid_logit": str, "text": str},
    ).loc[:, ["oldid", "docid", "docid_logit"]]

    # Report identifier collisions: how many docids are unique, which collide
    # most often, and how many distinct tokens each identifier uses.
    num_uniques = len(set(docid_df.docid))
    unique_ratio = num_uniques / len(docid_df)
    num_unique_tokens = [0] * 10
    docid2num_docs = dict(docid_df.docid.value_counts())
    for docid in docid_df.docid.unique():
        tokens = docid.split("<->")
        num_unique_token = len(set(tokens))
        num_unique_tokens[num_unique_token - 1] += 1
    print(f"num_uniques: {num_uniques}/{len(docid_df)} ({unique_ratio*100:.2f}%)")
    print("[Frequent Collision]", docid_df.docid.value_counts()[:5], sep="\n")
    print(f"distribution of number of unique tokens: {num_unique_tokens}")

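    # Invert the table: docid -> oldids that share it (collisions), and
    # oldid -> tensor of per-token logits. For illustration, a hypothetical row
    # ("d42", "▁apple<->▁pie", "1.3<->0.7") yields docid2oldids["▁apple<->▁pie"]
    # containing "d42" and oldid2docid_logit["d42"] == tensor([1.3, 0.7]).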
    docid2oldids = defaultdict(list)
    oldid2docid_logit = dict()
    for docid, oldid, docid_logit in docid_df[["docid", "oldid", "docid_logit"]].values:
        docid2oldids[docid].append(oldid)
        oldid2docid_logit[oldid] = torch.tensor(
            [float(x) for x in docid_logit.split("<->")]
        )

    oldid2docid = dict(zip(docid_df.oldid, docid_df.docid))

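    # Build a prefix tree (trie) over all identifier token sequences; model.root
    # is used to constrain decoding to valid document identifiers.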
    if model_args.tree == 1:
        builder = TreeBuilder()
        all_id = []
        for docid in list(oldid2docid.values()):
            toks = docid.split("<->")
            toks = tokenizer.convert_tokens_to_ids(toks)
            if len(toks) != max_output_length - 1:
                print(f"> {docid} ({toks}) does not match max_output_length; truncating")
                toks = toks[: max_output_length - 1]
            all_id.append(toks)
            builder.add(toks)
        model.root = builder.build()

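    # Hand the identifier maps to the model for scoring and collision handling.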
    model.docid2num_docs = docid2num_docs
    model.docid2oldids = docid2oldids
    model.oldid2docid_logit = oldid2docid_logit
    model.oldid2docid = oldid2docid

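    # Decode identifiers for every query batch and collect ranked predictions.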
    texts, preds, labels, ranks = [], [], [], []
    for batch in tqdm(encode_loader, dynamic_ncols=True, desc="query processing"):
        text, pred, label, rank = model.evaluation_step(batch, predefined_id=False)
        pred = [",".join(p) for p in pred]

        texts += text
        preds += pred
        labels += label
        ranks += rank

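    # Persist rank-1 predictions; res1_save_path is later consumed by the
    # metric helpers via training_args.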
    res = pd.DataFrame(
        list(zip(texts, preds, labels, ranks)), columns=["query", "pred", "gt", "rank"]
    )
    res["rank"] = res["rank"].astype(int)
    res.sort_values(by=["query", "rank"], ascending=True, inplace=True)
    res1 = res.loc[res["rank"] == 1]
    res1.to_csv(
        training_args.res1_save_path, mode="w", sep="\t", header=False, index=False
    )

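    # Metric cutoffs, clipped to the number of sequences the model returns.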
    training_args.recall_num = [1, 10, 100]
    training_args.ndcg_num = [1, 10, 100]
    training_args.mrr_num = [10, 100]

    training_args.recall_num = [
        x for x in training_args.recall_num if x <= model_args.num_return_sequences
    ]
    training_args.ndcg_num = [
        x for x in training_args.ndcg_num if x <= model_args.num_return_sequences
    ]
    training_args.mrr_num = [
        x for x in training_args.mrr_num if x <= model_args.num_return_sequences
    ]

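    # nq320k additionally reports metrics split by seen vs. unseen dev queries.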
    training_args.dataset_name = data_args.dataset_name
    training_args.unseen_query_set, training_args.seen_query_set = None, None

    if data_args.dataset_name == "nq320k":
        seen_query_df = pd.read_csv(
            "data/nq320k/GTQ_NQ_dev_seen.tsv", sep="\t", dtype=str
        )
        unseen_query_df = pd.read_csv(
            "data/nq320k/GTQ_NQ_dev_unseen.tsv", sep="\t", dtype=str
        )
        training_args.unseen_query_set = set(unseen_query_df["query"])
        training_args.seen_query_set = set(seen_query_df["query"])
        print(
            f"> Loaded unseen queries (#:{len(training_args.unseen_query_set)}) and seen queries (#:{len(training_args.seen_query_set)})"
        )

        compute_recall(training_args, cutoff=training_args.recall_num)
        compute_mrr(training_args, cutoff=training_args.mrr_num)
    elif data_args.dataset_name in ["marco_passage", "the_vault"]:
        compute_recall(training_args, cutoff=training_args.recall_num)
        compute_mrr(training_args, cutoff=training_args.mrr_num)
    else:
        evaluate_beir(training_args, tokenizer, encode_dataset)


if __name__ == "__main__":
    main()