import math import re from collections import defaultdict from datetime import datetime from typing import List import torch import torch.optim as optim from rex import accelerator from rex.data.data_manager import DataManager from rex.data.dataset import CachedDataset, StreamReadDataset from rex.tasks.simple_metric_task import SimpleMetricTask from rex.utils.batch import decompose_batch_into_instances from rex.utils.config import ConfigParser from rex.utils.dict import flatten_dict from rex.utils.io import load_jsonlines from rex.utils.registry import register from torch.utils.tensorboard import SummaryWriter from transformers.optimization import ( get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup, ) from .metric import MrcNERMetric, MrcSpanMetric, MultiPartSpanMetric from .model import ( MrcGlobalPointerModel, MrcPointerMatrixModel, SchemaGuidedInstructBertModel, ) from .transform import ( CachedLabelPointerTransform, CachedPointerMRCTransform, CachedPointerTaggingTransform, ) @register("task") class MrcTaggingTask(SimpleMetricTask): def __init__(self, config, **kwargs) -> None: super().__init__(config, **kwargs) def after_initialization(self): now_string = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") self.tb_logger: SummaryWriter = SummaryWriter( log_dir=self.task_path / "tb_summary" / now_string, comment=self.config.comment, ) def after_whole_train(self): self.tb_logger.close() def get_grad_norm(self): # for name, param in self.model.named_parameters(): # if param.grad is not None: # grads = param.grad.detach().data # grad_norm = (grads.norm(p=2) / grads.numel()).item() total_norm = 0.0 for p in self.model.parameters(): if p.grad is not None: param_norm = p.grad.detach().data.norm(2) total_norm += param_norm.item() ** 2 total_norm = total_norm ** (1.0 / 2) return total_norm def log_loss( self, idx: int, loss_item: float, step_or_epoch: str, dataset_name: str ): self.tb_logger.add_scalar( f"loss/{dataset_name}/{step_or_epoch}", loss_item, idx ) # self.tb_logger.add_scalars( # "lr", # { # str(i): self.optimizer.param_groups[i]["lr"] # for i in range(len(self.optimizer.param_groups)) # }, # idx, # ) self.tb_logger.add_scalar("lr", self.optimizer.param_groups[0]["lr"], idx) self.tb_logger.add_scalar("grad_norm_total", self.get_grad_norm(), idx) def log_metrics( self, idx: int, metrics: dict, step_or_epoch: str, dataset_name: str ): metrics = flatten_dict(metrics) self.tb_logger.add_scalars(f"{dataset_name}/{step_or_epoch}", metrics, idx) def init_transform(self): return CachedPointerTaggingTransform( self.config.max_seq_len, self.config.plm_dir, self.config.ent_type2query_filepath, mode=self.config.mode, negative_sample_prob=self.config.negative_sample_prob, ) def init_data_manager(self): return DataManager( self.config.train_filepath, self.config.dev_filepath, self.config.test_filepath, CachedDataset, self.transform, load_jsonlines, self.config.train_batch_size, self.config.eval_batch_size, self.transform.collate_fn, use_stream_transform=False, debug_mode=self.config.debug_mode, dump_cache_dir=self.config.dump_cache_dir, regenerate_cache=self.config.regenerate_cache, ) def init_model(self): # m = MrcPointerMatrixModel( m = MrcGlobalPointerModel( self.config.plm_dir, biaffine_size=self.config.biaffine_size, dropout=self.config.dropout, mode=self.config.mode, ) return m def init_metric(self): return MrcNERMetric() def init_optimizer(self): no_decay = r"(embedding|LayerNorm|\.bias$)" plm_lr = r"^plm\." non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])" param_groups = [] for name, param in self.model.named_parameters(): lr = self.config.learning_rate weight_decay = self.config.weight_decay if re.search(non_trainable, name): param.requires_grad = False if not re.search(plm_lr, name): lr = self.config.other_learning_rate if re.search(no_decay, name): weight_decay = 0.0 param_groups.append( {"params": param, "lr": lr, "weight_decay": weight_decay} ) return optim.AdamW( param_groups, lr=self.config.learning_rate, betas=(0.9, 0.98), eps=1e-6, ) def init_lr_scheduler(self): num_training_steps = int( len(self.data_manager.train_loader) * self.config.num_epochs * accelerator.num_processes ) num_warmup_steps = math.floor( num_training_steps * self.config.warmup_proportion ) # return get_linear_schedule_with_warmup( return get_cosine_schedule_with_warmup( self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, ) def predict_api(self, texts: List[str], **kwargs): raw_dataset = self.transform.predict_transform(texts) text_ids = sorted(list({ins["id"] for ins in raw_dataset})) loader = self.data_manager.prepare_loader(raw_dataset) # to prepare input device loader = accelerator.prepare_data_loader(loader) id2ents = defaultdict(set) for batch in loader: batch_out = self.model(**batch, is_eval=True) for _id, _pred in zip(batch["id"], batch_out["pred"]): id2ents[_id].update(_pred) results = [id2ents[_id] for _id in text_ids] return results @register("task") class MrcQaTask(MrcTaggingTask): def init_transform(self): return CachedPointerMRCTransform( self.config.max_seq_len, self.config.plm_dir, mode=self.config.mode, ) def init_model(self): # m = MrcPointerMatrixModel( m = MrcGlobalPointerModel( self.config.plm_dir, biaffine_size=self.config.biaffine_size, dropout=self.config.dropout, mode=self.config.mode, ) return m def init_metric(self): return MrcSpanMetric() def predict_api(self, data: list[dict], **kwargs): """ Args: data: a list of dict with query, context, and background strings """ raw_dataset = self.transform.predict_transform(data) loader = self.data_manager.prepare_loader(raw_dataset) results = [] for batch in loader: batch_out = self.model(**batch, is_eval=True) batch["pred"] = batch_out["pred"] instances = decompose_batch_into_instances(batch) for ins in instances: preds = ins["pred"] ins_results = [] for index_list in preds: ins_result = [] for i in index_list: ins_result.append(ins["raw_tokens"][i]) ins_results.append(("".join(ins_result), tuple(index_list))) results.append(ins_results) return results class StreamReadDatasetWithLen(StreamReadDataset): def __len__(self): return 631346 @register("task") class SchemaGuidedInstructBertTask(MrcTaggingTask): # def __init__(self, config, **kwargs) -> None: # super().__init__(config, **kwargs) # from watchmen import ClientMode, WatchClient # client = WatchClient( # id=config.task_name, # gpus=[4], # req_gpu_num=1, # mode=ClientMode.SCHEDULE, # server_host="127.0.0.1", # server_port=62333, # ) # client.wait() # def init_lr_scheduler(self): # num_training_steps = int( # 631346 / self.config.train_batch_size # * self.config.num_epochs # * accelerator.num_processes # ) # num_warmup_steps = math.floor( # num_training_steps * self.config.warmup_proportion # ) # # return get_linear_schedule_with_warmup( # return get_cosine_schedule_with_warmup( # self.optimizer, # num_warmup_steps=num_warmup_steps, # num_training_steps=num_training_steps, # ) def init_transform(self): self.transform: CachedLabelPointerTransform return CachedLabelPointerTransform( self.config.max_seq_len, self.config.plm_dir, mode=self.config.mode, label_span=self.config.label_span, include_instructions=self.config.get("include_instructions", True), ) def init_data_manager(self): if self.config.get("stream_mode", False): DatasetClass = StreamReadDatasetWithLen transform = self.transform.transform else: DatasetClass = CachedDataset transform = self.transform return DataManager( self.config.train_filepath, self.config.dev_filepath, self.config.test_filepath, DatasetClass, transform, load_jsonlines, self.config.train_batch_size, self.config.eval_batch_size, self.transform.collate_fn, use_stream_transform=self.config.get("stream_mode", False), debug_mode=self.config.debug_mode, dump_cache_dir=self.config.dump_cache_dir, regenerate_cache=self.config.regenerate_cache, ) def init_model(self): self.model = SchemaGuidedInstructBertModel( self.config.plm_dir, vocab_size=len(self.transform.tokenizer), use_rope=self.config.use_rope, biaffine_size=self.config.biaffine_size, dropout=self.config.dropout, ) if self.config.get("base_model_path"): self.load( self.config.base_model_path, load_config=False, load_model=True, load_optimizer=False, load_history=False, ) return self.model def init_optimizer(self): no_decay = r"(embedding|LayerNorm|\.bias$)" plm_lr = r"^plm\." # non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])" non_trainable = "no_non_trainable" param_groups = [] for name, param in self.model.named_parameters(): lr = self.config.learning_rate weight_decay = self.config.weight_decay if re.search(non_trainable, name): param.requires_grad = False if not re.search(plm_lr, name): lr = self.config.other_learning_rate if re.search(no_decay, name): weight_decay = 0.0 param_groups.append( {"params": param, "lr": lr, "weight_decay": weight_decay} ) return optim.AdamW( param_groups, lr=self.config.learning_rate, betas=(0.9, 0.98), eps=1e-6, ) def init_metric(self): return MultiPartSpanMetric() def _convert_span_to_string(self, span, token_ids, tokenizer): string = "" if len(span) == 0 or len(span) > 2: pass elif len(span) == 1: string = tokenizer.decode(token_ids[span[0]]) elif len(span) == 2: string = tokenizer.decode(token_ids[span[0] : span[1] + 1]) return (string, self.reset_position(token_ids, span)) def reset_position(self, token_ids: list[int], span: list[int]) -> list[int]: if isinstance(token_ids, torch.Tensor): input_ids = token_ids.cpu().tolist() if len(span) < 1: return span tp_token_id, tl_token_id = self.transform.tokenizer.convert_tokens_to_ids( [self.transform.tp_token, self.transform.tl_token] ) offset = 0 if tp_token_id in input_ids: offset = input_ids.index(tp_token_id) + 1 elif tl_token_id in input_ids: offset = input_ids.index(tl_token_id) + 1 return [i - offset for i in span] def predict_api(self, data: list[dict], **kwargs): """ Args: data: a list of dict in UDI: { "id": str, "instruction": str, "schema": { "ent": list, "rel": list, "event": dict, "cls": list, "discontinuous_ent": list, "hyper_rel": dict }, "text": str, "bg": str, "ans": {}, # empty dict } """ raw_dataset = [self.transform.transform(d) for d in data] loader = self.data_manager.prepare_loader(raw_dataset) results = [] for batch in loader: batch_out = self.model(**batch, is_eval=True) batch["pred"] = batch_out["pred"] instances = decompose_batch_into_instances(batch) for ins in instances: pred_clses = [] pred_ents = [] pred_rels = [] pred_trigger_to_event = defaultdict( lambda: {"event_type": "", "arguments": []} ) pred_events = [] pred_spans = [] pred_discon_ents = [] pred_hyper_rels = [] raw_schema = ins["raw"]["schema"] for multi_part_span in ins["pred"]: span = tuple(multi_part_span) span_to_label = ins["span_to_label"] if span[0] in span_to_label: label = span_to_label[span[0]] if label["task"] == "cls" and len(span) == 1: pred_clses.append(label["string"]) elif label["task"] == "ent" and len(span) == 2: string = self._convert_span_to_string( span[1], ins["input_ids"], self.transform.tokenizer ) pred_ents.append((label["string"], string)) elif label["task"] == "rel" and len(span) == 3: head = self._convert_span_to_string( span[1], ins["input_ids"], self.transform.tokenizer ) tail = self._convert_span_to_string( span[2], ins["input_ids"], self.transform.tokenizer ) pred_rels.append((label["string"], head, tail)) elif label["task"] == "event": if label["type"] == "lm" and len(span) == 2: pred_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip elif label["type"] == "lr" and len(span) == 3: arg = self._convert_span_to_string( span[2], ins["input_ids"], self.transform.tokenizer ) pred_trigger_to_event[span[1]]["arguments"].append( {"argument": arg, "role": label["string"]} ) elif label["task"] == "discontinuous_ent" and len(span) > 1: parts = [ self._convert_span_to_string( part, ins["input_ids"], self.transform.tokenizer ) for part in span[1:] ] string = " ".join([part[0] for part in parts]) position = [] for part in parts: position.append(part[1]) pred_discon_ents.append( (label["string"], string, self.reset_position(position)) ) elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip q_label = span_to_label[span[3]] span_1 = self._convert_span_to_string( span[1], ins["input_ids"], self.transform.tokenizer ) span_2 = self._convert_span_to_string( span[2], ins["input_ids"], self.transform.tokenizer ) span_4 = self._convert_span_to_string( span[4], ins["input_ids"], self.transform.tokenizer ) pred_hyper_rels.append((label["string"], span_1, span_2, q_label["string"], span_4)) # fmt: skip else: # span task has no labels pred_token_ids = [] for part in span: _pred_token_ids = [ins["input_ids"][i] for i in part] pred_token_ids.extend(_pred_token_ids) span_string = self.transform.tokenizer.decode(pred_token_ids) pred_spans.append( ( span_string, tuple( [ tuple( self.reset_position( ins["input_ids"].cpu().tolist(), part ) ) for part in span ] ), ) ) for trigger, item in pred_trigger_to_event.items(): trigger = self._convert_span_to_string( trigger, ins["input_ids"], self.transform.tokenizer ) if item["event_type"] not in raw_schema["event"]: continue legal_roles = raw_schema["event"][item["event_type"]] pred_events.append( { "trigger": trigger, "event_type": item["event_type"], "arguments": [ arg for arg in filter( lambda arg: arg["role"] in legal_roles, item["arguments"], ) ], } ) results.append( { "id": ins["raw"]["id"], "results": { "cls": pred_clses, "ent": pred_ents, "rel": pred_rels, "event": pred_events, "span": pred_spans, "discon_ent": pred_discon_ents, "hyper_rel": pred_hyper_rels, }, } ) return results if __name__ == "__main__": pass # further_finetune() # from rex.utils.config import ConfigParser # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/ner.yaml"]) # config = ConfigParser.parse_cmd(cmd_args=["-dc", "conf/mirror-ace05en.yaml"]) # task = MrcTaggingTask( # config, # initialize=True, # makedirs=True, # dump_configfile=True, # ) # task = SchemaGuidedInstructBertTask.from_taskdir( # "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_Rel", # initialize=True, # load_config=True, # dump_configfile=False, # ) # task = SchemaGuidedInstructBertTask( # config, # initialize=True, # makedirs=True, # dump_configfile=False, # ) # task.load( # "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_NerRelEvent/ckpt/SchemaGuidedInstructBertModel.epoch.0.pth", # load_config=False, # ) # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval") # task.load( # # "outputs/Mirror_RobertaBaseWwm_Cons_MsraMrc/ckpt/MrcGlobalPointerModel.best.pth", # # "outputs/Mirror_RobertaBaseWwm_W2_MsraMrc_HyperParamExp1/ckpt/MrcGlobalPointerModel.best.pth", # config.base_model_path, # load_config=False, # load_model=True, # load_optimizer=False, # load_history=False, # ) # task.train() # task = MrcTaggingTask.from_taskdir( # "outputs/Mirror_W2_MSRAv2_NER", # initialize=True, # dump_configfile=False, # load_config=True, # ) # for name, _ in task.model.named_parameters(): # print(name) # task.eval("test", verbose=True, dump=True, dump_middle=True, postfix="re_eval.0.1") # task = MrcQaTask( # config, # initialize=True, # makedirs=True, # dump_configfile=True, # ) # task.train() # task.eval("dev", verbose=True, dump=True, dump_middle=True, postfix="re_eval")