Mirror / src /task.py
Spico's picture
fix assignment before ref bug, add examples in index
959183d
raw
history blame contribute delete
No virus
22.5 kB
import math
import re
from collections import defaultdict
from datetime import datetime
from typing import List
import torch
import torch.optim as optim
from rex import accelerator
from rex.data.data_manager import DataManager
from rex.data.dataset import CachedDataset, StreamReadDataset
from rex.tasks.simple_metric_task import SimpleMetricTask
from rex.utils.batch import decompose_batch_into_instances
from rex.utils.config import ConfigParser
from rex.utils.dict import flatten_dict
from rex.utils.io import load_jsonlines
from rex.utils.registry import register
from torch.utils.tensorboard import SummaryWriter
from transformers.optimization import (
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from .metric import MrcNERMetric, MrcSpanMetric, MultiPartSpanMetric
from .model import (
MrcGlobalPointerModel,
MrcPointerMatrixModel,
SchemaGuidedInstructBertModel,
)
from .transform import (
CachedLabelPointerTransform,
CachedPointerMRCTransform,
CachedPointerTaggingTransform,
)
@register("task")
class MrcTaggingTask(SimpleMetricTask):
    """MRC-style pointer tagging task.

    Wires together the transform, data manager, model, metric, optimizer and
    LR scheduler used by the ``rex`` task runner, and logs losses / metrics
    to TensorBoard.
    """

    def __init__(self, config, **kwargs) -> None:
        super().__init__(config, **kwargs)

    def after_initialization(self):
        """Create a TensorBoard writer under ``<task_path>/tb_summary/<timestamp>``."""
        now_string = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.tb_logger: SummaryWriter = SummaryWriter(
            log_dir=self.task_path / "tb_summary" / now_string,
            # NOTE: per the SummaryWriter docs, ``comment`` has no effect when
            # ``log_dir`` is given explicitly.
            comment=self.config.comment,
        )

    def after_whole_train(self):
        """Flush and close the TensorBoard writer."""
        self.tb_logger.close()

    def get_grad_norm(self):
        """Return the global L2 norm over all parameter gradients.

        Parameters whose ``grad`` is ``None`` are skipped. Returns ``0.0`` when
        no parameter has a gradient.
        """
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1.0 / 2)
        return total_norm

    def log_loss(
        self, idx: int, loss_item: float, step_or_epoch: str, dataset_name: str
    ):
        """Log the loss, the first param group's LR and the grad norm at ``idx``."""
        self.tb_logger.add_scalar(
            f"loss/{dataset_name}/{step_or_epoch}", loss_item, idx
        )
        # Only param group 0 is logged; every parameter has its own group
        # (see ``init_optimizer``), so this is the LR of the first parameter.
        self.tb_logger.add_scalar("lr", self.optimizer.param_groups[0]["lr"], idx)
        self.tb_logger.add_scalar("grad_norm_total", self.get_grad_norm(), idx)

    def log_metrics(
        self, idx: int, metrics: dict, step_or_epoch: str, dataset_name: str
    ):
        """Flatten ``metrics`` and log them as one TensorBoard scalar group."""
        metrics = flatten_dict(metrics)
        self.tb_logger.add_scalars(f"{dataset_name}/{step_or_epoch}", metrics, idx)

    def init_transform(self):
        return CachedPointerTaggingTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            self.config.ent_type2query_filepath,
            mode=self.config.mode,
            negative_sample_prob=self.config.negative_sample_prob,
        )

    def init_data_manager(self):
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            CachedDataset,
            self.transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=False,
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        # Alternative backbone previously used here: MrcPointerMatrixModel.
        m = MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )
        return m

    def init_metric(self):
        return MrcNERMetric()

    def init_optimizer(self):
        """Build an AdamW optimizer with one param group per parameter.

        - Parameters whose name starts with ``plm.`` use ``learning_rate``;
          all others use ``other_learning_rate``.
        - Names matching embedding / LayerNorm / ``.bias`` get no weight decay.
        - PLM embeddings and encoder layers 0-3 are frozen
          (``requires_grad = False``) but are still registered with the
          optimizer.
        """
        no_decay = re.compile(r"(embedding|LayerNorm|\.bias$)")
        plm_lr = re.compile(r"^plm\.")
        # The trailing ``\.`` pins the layer index: without it ``[0-3]`` is a
        # prefix match and also froze ``encoder.layer.10``, ``.11``, ``.2x``...
        non_trainable = re.compile(r"^plm\.(emb|encoder\.layer\.[0-3]\.)")
        param_groups = []
        for name, param in self.model.named_parameters():
            lr = self.config.learning_rate
            weight_decay = self.config.weight_decay
            if non_trainable.search(name):
                param.requires_grad = False
            if not plm_lr.search(name):
                lr = self.config.other_learning_rate
            if no_decay.search(name):
                weight_decay = 0.0
            # One group per parameter; this layout is kept unchanged so that
            # previously saved optimizer state dicts remain loadable.
            param_groups.append(
                {"params": param, "lr": lr, "weight_decay": weight_decay}
            )
        return optim.AdamW(
            param_groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

    def init_lr_scheduler(self):
        """Cosine schedule with linear warmup over the whole training run."""
        # NOTE(review): scaled by ``num_processes``, presumably because the
        # accelerator steps the scheduler once per process per optimizer
        # step — confirm against the accelerate version in use.
        num_training_steps = int(
            len(self.data_manager.train_loader)
            * self.config.num_epochs
            * accelerator.num_processes
        )
        num_warmup_steps = math.floor(
            num_training_steps * self.config.warmup_proportion
        )
        # A linear schedule (get_linear_schedule_with_warmup) is the alternative.
        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def predict_api(self, texts: List[str], **kwargs):
        """Predict for raw input texts.

        Args:
            texts: input strings, passed to ``transform.predict_transform``.

        Returns:
            A list with one set of merged predictions per distinct instance
            id produced by the transform, ordered by sorted id.
        """
        raw_dataset = self.transform.predict_transform(texts)
        text_ids = sorted({ins["id"] for ins in raw_dataset})
        loader = self.data_manager.prepare_loader(raw_dataset)
        # to prepare input device
        loader = accelerator.prepare_data_loader(loader)
        id2ents = defaultdict(set)
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            for batch in loader:
                batch_out = self.model(**batch, is_eval=True)
                for _id, _pred in zip(batch["id"], batch_out["pred"]):
                    id2ents[_id].update(_pred)
        return [id2ents[_id] for _id in text_ids]
@register("task")
class MrcQaTask(MrcTaggingTask):
    """Extractive MRC (span QA) variant of ``MrcTaggingTask``."""

    def init_transform(self):
        return CachedPointerMRCTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
        )

    def init_model(self):
        # Alternative backbone previously used here: MrcPointerMatrixModel.
        m = MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )
        return m

    def init_metric(self):
        return MrcSpanMetric()

    def predict_api(self, data: list[dict], **kwargs):
        """Predict answer spans.

        Args:
            data: a list of dict with query, context, and background strings

        Returns:
            One list per decoded instance. Each entry is a
            ``(answer_string, index_tuple)`` pair, where the string joins
            ``raw_tokens`` at the predicted indices with no separator.
        """
        raw_dataset = self.transform.predict_transform(data)
        loader = self.data_manager.prepare_loader(raw_dataset)
        results = []
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            for batch in loader:
                batch_out = self.model(**batch, is_eval=True)
                batch["pred"] = batch_out["pred"]
                for ins in decompose_batch_into_instances(batch):
                    tokens = ins["raw_tokens"]
                    results.append(
                        [
                            (
                                "".join(tokens[i] for i in index_list),
                                tuple(index_list),
                            )
                            for index_list in ins["pred"]
                        ]
                    )
        return results
class StreamReadDatasetWithLen(StreamReadDataset):
    """``StreamReadDataset`` that reports a fixed, hard-coded length.

    Streaming datasets cannot count their items up front. This length lets
    code that calls ``len()`` on the dataset (or its loader) keep working.
    """

    def __len__(self):
        # NOTE(review): hard-coded constant, presumably the number of training
        # instances in the streamed file. Update it if that file changes.
        return 631_346
@register("task")
class SchemaGuidedInstructBertTask(MrcTaggingTask):
    """Schema-guided, instruction-based multi-task extraction task.

    The model predicts multi-part spans. When the first part of a span is a
    key of the instance's ``span_to_label``, that label's ``task`` decides how
    the span is decoded (cls / ent / rel / event / discontinuous_ent /
    hyper_rel). Otherwise the span is treated as an unlabeled span answer.
    """

    # def __init__(self, config, **kwargs) -> None:
    #     super().__init__(config, **kwargs)
    #     from watchmen import ClientMode, WatchClient
    #     client = WatchClient(
    #         id=config.task_name,
    #         gpus=[4],
    #         req_gpu_num=1,
    #         mode=ClientMode.SCHEDULE,
    #         server_host="127.0.0.1",
    #         server_port=62333,
    #     )
    #     client.wait()

    # def init_lr_scheduler(self):
    #     num_training_steps = int(
    #         631346 / self.config.train_batch_size
    #         * self.config.num_epochs
    #         * accelerator.num_processes
    #     )
    #     num_warmup_steps = math.floor(
    #         num_training_steps * self.config.warmup_proportion
    #     )
    #     # return get_linear_schedule_with_warmup(
    #     return get_cosine_schedule_with_warmup(
    #         self.optimizer,
    #         num_warmup_steps=num_warmup_steps,
    #         num_training_steps=num_training_steps,
    #     )

    def init_transform(self):
        self.transform: CachedLabelPointerTransform
        return CachedLabelPointerTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
            label_span=self.config.label_span,
            include_instructions=self.config.get("include_instructions", True),
        )

    def init_data_manager(self):
        """Build the data manager, optionally in streaming mode.

        With ``stream_mode`` enabled, instances are read with
        ``StreamReadDatasetWithLen`` and transformed on the fly by the
        per-instance ``transform.transform``. Otherwise the cached dataset is
        used with the full transform.
        """
        stream_mode = self.config.get("stream_mode", False)
        if stream_mode:
            DatasetClass = StreamReadDatasetWithLen
            transform = self.transform.transform
        else:
            DatasetClass = CachedDataset
            transform = self.transform
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            DatasetClass,
            transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=stream_mode,
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        """Build the model, optionally loading weights from ``base_model_path``."""
        self.model = SchemaGuidedInstructBertModel(
            self.config.plm_dir,
            vocab_size=len(self.transform.tokenizer),
            use_rope=self.config.use_rope,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
        )
        if self.config.get("base_model_path"):
            self.load(
                self.config.base_model_path,
                load_config=False,
                load_model=True,
                load_optimizer=False,
                load_history=False,
            )
        return self.model

    def init_optimizer(self):
        """AdamW with one param group per parameter (see parent class).

        Unlike the parent, no PLM layers are frozen here.
        """
        no_decay = r"(embedding|LayerNorm|\.bias$)"
        plm_lr = r"^plm\."
        # Freezing is disabled: this pattern is not expected to match any
        # parameter name. Parent used r"^plm\.(emb|encoder\.layer\.[0-3])".
        non_trainable = "no_non_trainable"
        param_groups = []
        for name, param in self.model.named_parameters():
            lr = self.config.learning_rate
            weight_decay = self.config.weight_decay
            if re.search(non_trainable, name):
                param.requires_grad = False
            if not re.search(plm_lr, name):
                lr = self.config.other_learning_rate
            if re.search(no_decay, name):
                weight_decay = 0.0
            param_groups.append(
                {"params": param, "lr": lr, "weight_decay": weight_decay}
            )
        return optim.AdamW(
            param_groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

    def init_metric(self):
        return MultiPartSpanMetric()

    def _convert_span_to_string(self, span, token_ids, tokenizer):
        """Decode one span part to text.

        A 1-element span is a single token index. A 2-element span is an
        inclusive ``[start, end]`` range. Any other length decodes to ``""``.

        Returns:
            ``(string, positions)``, where positions are the span indices
            re-based by ``reset_position``.
        """
        string = ""
        if len(span) == 0 or len(span) > 2:
            pass
        elif len(span) == 1:
            string = tokenizer.decode(token_ids[span[0]])
        elif len(span) == 2:
            string = tokenizer.decode(token_ids[span[0] : span[1] + 1])
        return (string, self.reset_position(token_ids, span))

    def reset_position(self, input_ids: list[int], span: list[int]) -> list[int]:
        """Shift ``span`` indices to be relative to a marker token.

        The offset is one past the first ``tp_token`` in ``input_ids``, or, if
        that is absent, one past the first ``tl_token``. If neither token is
        present the offset is 0. Empty spans are returned unchanged.
        Tensors are converted to Python lists first.
        """
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().tolist()
        if len(span) < 1:
            return span
        tp_token_id, tl_token_id = self.transform.tokenizer.convert_tokens_to_ids(
            [self.transform.tp_token, self.transform.tl_token]
        )
        offset = 0
        if tp_token_id in input_ids:
            offset = input_ids.index(tp_token_id) + 1
        elif tl_token_id in input_ids:
            offset = input_ids.index(tl_token_id) + 1
        return [i - offset for i in span]

    def predict_api(self, data: list[dict], **kwargs):
        """
        Args:
            data: a list of dict in UDI:
                {
                    "id": str,
                    "instruction": str,
                    "schema": {
                        "ent": list,
                        "rel": list,
                        "event": dict,
                        "cls": list,
                        "discontinuous_ent": list,
                        "hyper_rel": dict
                    },
                    "text": str,
                    "bg": str,
                    "ans": {},  # empty dict
                }

        Returns:
            One ``{"id": ..., "results": {...}}`` dict per decoded instance.
            ``results`` has the keys cls / ent / rel / event / span /
            discon_ent / hyper_rel.
        """
        raw_dataset = [self.transform.transform(d) for d in data]
        loader = self.data_manager.prepare_loader(raw_dataset)
        results = []
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            for batch in loader:
                batch_out = self.model(**batch, is_eval=True)
                batch["pred"] = batch_out["pred"]
                for ins in decompose_batch_into_instances(batch):
                    results.append(self._decode_instance(ins))
        return results

    def _decode_instance(self, ins: dict) -> dict:
        """Turn one instance's multi-part span predictions into a results dict."""
        tokenizer = self.transform.tokenizer
        input_ids = ins["input_ids"]
        raw_schema = ins["raw"]["schema"]

        def to_string(part):
            return self._convert_span_to_string(part, input_ids, tokenizer)

        pred_clses = []
        pred_ents = []
        pred_rels = []
        pred_trigger_to_event = defaultdict(
            lambda: {"event_type": "", "arguments": []}
        )
        pred_events = []
        pred_spans = []
        pred_discon_ents = []
        pred_hyper_rels = []

        for multi_part_span in ins["pred"]:
            span = tuple(multi_part_span)
            span_to_label = ins["span_to_label"]
            if span[0] not in span_to_label:
                # span task has no labels
                pred_token_ids = []
                for part in span:
                    pred_token_ids.extend(input_ids[i] for i in part)
                span_string = tokenizer.decode(pred_token_ids)
                # reset_position converts tensors itself, so list or tensor
                # input_ids both work here.
                positions = tuple(
                    tuple(self.reset_position(input_ids, part)) for part in span
                )
                pred_spans.append((span_string, positions))
                continue

            label = span_to_label[span[0]]
            task = label["task"]
            if task == "cls" and len(span) == 1:
                pred_clses.append(label["string"])
            elif task == "ent" and len(span) == 2:
                pred_ents.append((label["string"], to_string(span[1])))
            elif task == "rel" and len(span) == 3:
                head = to_string(span[1])
                tail = to_string(span[2])
                pred_rels.append((label["string"], head, tail))
            elif task == "event":
                if label["type"] == "lm" and len(span) == 2:
                    pred_trigger_to_event[span[1]]["event_type"] = label["string"]
                elif label["type"] == "lr" and len(span) == 3:
                    arg = to_string(span[2])
                    pred_trigger_to_event[span[1]]["arguments"].append(
                        {"argument": arg, "role": label["string"]}
                    )
            elif task == "discontinuous_ent" and len(span) > 1:
                parts = [to_string(part) for part in span[1:]]
                string = " ".join(part[0] for part in parts)
                # Each part's positions were already re-based by
                # _convert_span_to_string. The old call
                # ``self.reset_position(position)`` omitted a required
                # argument and always raised TypeError.
                position = [part[1] for part in parts]
                pred_discon_ents.append((label["string"], string, position))
            elif (
                task == "hyper_rel"
                and len(span) == 5
                and span[3] in span_to_label
            ):
                q_label = span_to_label[span[3]]
                span_1 = to_string(span[1])
                span_2 = to_string(span[2])
                span_4 = to_string(span[4])
                pred_hyper_rels.append(
                    (label["string"], span_1, span_2, q_label["string"], span_4)
                )

        for trigger_span, item in pred_trigger_to_event.items():
            # Drop events whose type is not in the schema, including argument
            # predictions that never got a trigger type ("" event_type).
            if item["event_type"] not in raw_schema["event"]:
                continue
            legal_roles = raw_schema["event"][item["event_type"]]
            pred_events.append(
                {
                    "trigger": to_string(trigger_span),
                    "event_type": item["event_type"],
                    "arguments": [
                        arg for arg in item["arguments"] if arg["role"] in legal_roles
                    ],
                }
            )

        return {
            "id": ins["raw"]["id"],
            "results": {
                "cls": pred_clses,
                "ent": pred_ents,
                "rel": pred_rels,
                "event": pred_events,
                "span": pred_spans,
                "discon_ent": pred_discon_ents,
                "hyper_rel": pred_hyper_rels,
            },
        }
if __name__ == "__main__":
    # This module only defines and registers task classes, so running it
    # directly does nothing. The commented snippets below are kept as usage
    # examples: building a task from a config file or an existing task
    # directory, loading a checkpoint, then training or evaluating.
    pass
    # further_finetune()
    # from rex.utils.config import ConfigParser
    # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/ner.yaml"])
    # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/mirror-ace05en.yaml"])
    # task = MrcTaggingTask(
    #     config,
    #     initialize=True,
    #     makedirs=True,
    #     dump_configfile=True,
    # )
    # task = SchemaGuidedInstructBertTask.from_taskdir(
    #     "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_Rel",
    #     initialize=True,
    #     load_config=True,
    #     dump_configfile=False,
    # )
    # task = SchemaGuidedInstructBertTask(
    #     config,
    #     initialize=True,
    #     makedirs=True,
    #     dump_configfile=False,
    # )
    # task.load(
    #     "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_NerRelEvent/ckpt/SchemaGuidedInstructBertModel.epoch.0.pth",
    #     load_config=False,
    # )
    # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval")
    # task.load(
    #     # "outputs/Mirror_RobertaBaseWwm_Cons_MsraMrc/ckpt/MrcGlobalPointerModel.best.pth",
    #     # "outputs/Mirror_RobertaBaseWwm_W2_MsraMrc_HyperParamExp1/ckpt/MrcGlobalPointerModel.best.pth",
    #     config.base_model_path,
    #     load_config=False,
    #     load_model=True,
    #     load_optimizer=False,
    #     load_history=False,
    # )
    # task.train()
    # task = MrcTaggingTask.from_taskdir(
    #     "outputs/Mirror_W2_MSRAv2_NER",
    #     initialize=True,
    #     dump_configfile=False,
    #     load_config=True,
    # )
    # for name, _ in task.model.named_parameters():
    #     print(name)
    # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval.0.1")
    # task = MrcQaTask(
    #     config,
    #     initialize=True,
    #     makedirs=True,
    #     dump_configfile=True,
    # )
    # task.train()
    # task.eval("dev", verbose=True, dump=True, dump_middle=True, postfix="re_eval")