from model.base_model import SummModel
import argparse
import os
import time
import torch
import gzip
import json
from model.third_party.HMNet.Models.Trainers.HMNetTrainer import HMNetTrainer
from model.third_party.HMNet.Utils.Arguments import Arguments

import spacy
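
# NOTE: the spacy.load call below assumes the small English pipeline has been
# installed first, e.g. via `python -m spacy download en_core_web_sm`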
nlp = spacy.load("en_core_web_sm", disable=["parser"])
# tagger = nlp.get_pipe('tagger')
# ner = nlp.get_pipe('ner')
# POS = {w: i for i, w in enumerate([''] + list(tagger.labels))}
# ENT = {w: i for i, w in enumerate([''] + list(ner.move_names))}

# These two dicts are adapted from spaCy 2.3.1, since HMNet's embeddings for POS and ENT are fixed
POS = {
    "": 0,
    "$": 1,
    "''": 2,
    ",": 3,
    "-LRB-": 4,
    "-RRB-": 5,
    ".": 6,
    ":": 7,
    "ADD": 8,
    "AFX": 9,
    "CC": 10,
    "CD": 11,
    "DT": 12,
    "EX": 13,
    "FW": 14,
    "HYPH": 15,
    "IN": 16,
    "JJ": 17,
    "JJR": 18,
    "JJS": 19,
    "LS": 20,
    "MD": 21,
    "NFP": 22,
    "NN": 23,
    "NNP": 24,
    "NNPS": 25,
    "NNS": 26,
    "PDT": 27,
    "POS": 28,
    "PRP": 29,
    "PRP$": 30,
    "RB": 31,
    "RBR": 32,
    "RBS": 33,
    "RP": 34,
    "SYM": 35,
    "TO": 36,
    "UH": 37,
    "VB": 38,
    "VBD": 39,
    "VBG": 40,
    "VBN": 41,
    "VBP": 42,
    "VBZ": 43,
    "WDT": 44,
    "WP": 45,
    "WP$": 46,
    "WRB": 47,
    "XX": 48,
    "_SP": 49,
    "``": 50,
}
ENT = {
    "": 0,
    "B-ORG": 1,
    "B-DATE": 2,
    "B-PERSON": 3,
    "B-GPE": 4,
    "B-MONEY": 5,
    "B-CARDINAL": 6,
    "B-NORP": 7,
    "B-PERCENT": 8,
    "B-WORK_OF_ART": 9,
    "B-LOC": 10,
    "B-TIME": 11,
    "B-QUANTITY": 12,
    "B-FAC": 13,
    "B-EVENT": 14,
    "B-ORDINAL": 15,
    "B-PRODUCT": 16,
    "B-LAW": 17,
    "B-LANGUAGE": 18,
    "I-ORG": 19,
    "I-DATE": 20,
    "I-PERSON": 21,
    "I-GPE": 22,
    "I-MONEY": 23,
    "I-CARDINAL": 24,
    "I-NORP": 25,
    "I-PERCENT": 26,
    "I-WORK_OF_ART": 27,
    "I-LOC": 28,
    "I-TIME": 29,
    "I-QUANTITY": 30,
    "I-FAC": 31,
    "I-EVENT": 32,
    "I-ORDINAL": 33,
    "I-PRODUCT": 34,
    "I-LAW": 35,
    "I-LANGUAGE": 36,
    "L-ORG": 37,
    "L-DATE": 38,
    "L-PERSON": 39,
    "L-GPE": 40,
    "L-MONEY": 41,
    "L-CARDINAL": 42,
    "L-NORP": 43,
    "L-PERCENT": 44,
    "L-WORK_OF_ART": 45,
    "L-LOC": 46,
    "L-TIME": 47,
    "L-QUANTITY": 48,
    "L-FAC": 49,
    "L-EVENT": 50,
    "L-ORDINAL": 51,
    "L-PRODUCT": 52,
    "L-LAW": 53,
    "L-LANGUAGE": 54,
    "U-ORG": 55,
    "U-DATE": 56,
    "U-PERSON": 57,
    "U-GPE": 58,
    "U-MONEY": 59,
    "U-CARDINAL": 60,
    "U-NORP": 61,
    "U-PERCENT": 62,
    "U-WORK_OF_ART": 63,
    "U-LOC": 64,
    "U-TIME": 65,
    "U-QUANTITY": 66,
    "U-FAC": 67,
    "U-EVENT": 68,
    "U-ORDINAL": 69,
    "U-PRODUCT": 70,
    "U-LAW": 71,
    "U-LANGUAGE": 72,
    "O": 73,
}
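
# Illustrative lookups (these mirror how _preprocess below consumes the fixed
# dicts; unknown labels fall back to the empty-string id at index 0):
#   POS.get("NN", POS[""])     # -> 23
#   ENT.get("B-ORG", ENT[""])  # -> 1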


class HMNetModel(SummModel):
    # static variables
    model_name = "HMNET"
    is_extractive = False
    is_neural = True
    is_dialogue_based = True

    def __init__(
        self,
        min_gen_length: int = 10,
        max_gen_length: int = 300,
        beam_width: int = 6,
        **kwargs,
    ):
""" | |
Create a summarization model with HMNet backbone. In the default setting, the inference speed will be | |
10s/sample (on one GPU), however, if one can tune these three parameters properly, e.g. min_gen_length=10, | |
max_gen_length=100, and beam_width=2, the inference speed will increase to 2s/sample (on one GPU). | |
Args: | |
min_gen_length (int): minimum generation length of the decoder | |
max_gen_length (int): maximum generation length of the decoder | |
beam_width (int): width of the beam when doing beam search in the decoding process | |
kwargs: the other valid parameters. The valid parameters can be found in | |
model/dialogue/hmnet/config/dialogue.conf . You can use either lower case or upper case for parameter | |
name. The valid parameter name is one of the following args, however, we do not encourage you to modify | |
them, since some unexpected, untested errors might be triggered: | |
['MODEL', 'TASK', 'CRITERION', 'SEED', 'MAX_NUM_EPOCHS', 'EVAL_PER_UPDATE_NUM' | |
, 'UPDATES_PER_EPOCH', 'OPTIMIZER', 'START_LEARNING_RATE', 'LR_SCHEDULER', 'WARMUP_STEPS', | |
'WARMUP_INIT_LR', 'WARMUP_END_LR', 'GRADIENT_ACCUMULATE_STEP', 'GRAD_CLIPPING', 'USE_REL_DATA_PATH', | |
'TRAIN_FILE', 'DEV_FILE', 'TEST_FILE', 'ROLE_DICT_FILE', 'MINI_BATCH', 'MAX_PADDING_RATIO', | |
'BATCH_READ_AHEAD', 'DOC_SHUFFLE_BUF_SIZE', 'SAMPLE_SHUFFLE_BUFFER_SIZE', 'BATCH_SHUFFLE_BUFFER_SIZE', | |
'MAX_TRANSCRIPT_WORD', 'MAX_SENT_LEN', 'MAX_SENT_NUM', 'DROPOUT', 'VOCAB_DIM', 'ROLE_SIZE', 'ROLE_DIM', | |
'POS_DIM', 'ENT_DIM', 'USE_ROLE', 'USE_POSENT', 'USE_BOS_TOKEN', 'USE_EOS_TOKEN', | |
'TRANSFORMER_EMBED_DROPOUT', 'TRANSFORMER_RESIDUAL_DROPOUT', 'TRANSFORMER_ATTENTION_DROPOUT', | |
'TRANSFORMER_LAYER', 'TRANSFORMER_HEAD', 'TRANSFORMER_POS_DISCOUNT', 'PRE_TOKENIZER', | |
'PRE_TOKENIZER_PATH', 'PYLEARN_MODEL', 'EXTRA_IDS', 'BEAM_WIDTH', 'EVAL_TOKENIZED', 'EVAL_LOWERCASE', | |
'MAX_GEN_LENGTH', 'MIN_GEN_LENGTH', 'NO_REPEAT_NGRAM_SIZE'] | |
Return an instance of HMNet model for dialogue summarization. | |
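
        Example (a minimal usage sketch; the dialogue below is hypothetical, and each
        turn must be a "speaker: utterance" string):

            model = HMNetModel(min_gen_length=10, max_gen_length=100, beam_width=2)
            corpus = [["Alice: Shall we review the budget?", "Bob: Yes, let's start."]]
            summaries = model.summarize(corpus)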
""" | |
        super(HMNetModel, self).__init__()
        self.root_path = self._get_root()

        # the most influential params are exposed as named arguments; the rest stay hidden kwargs
        kwargs["MIN_GEN_LENGTH"] = min_gen_length
        kwargs["MAX_GEN_LENGTH"] = max_gen_length
        kwargs["BEAM_WIDTH"] = beam_width
        self.opt = self._parse_args(kwargs)
        self.model = HMNetTrainer(self.opt)

    def _get_root(self):
        # walk up from the working directory until a "model" folder is found
        root_path = os.getcwd()
        while "model" not in os.listdir(root_path):
            root_path = os.path.dirname(root_path)
        root_path = os.path.join(root_path, "model/dialogue")
        return root_path

    def _parse_args(self, kwargs):
        parser = argparse.ArgumentParser(
            description="HMNet: pretrain or fine-tune HMNet models."
        )
        parser.add_argument(
            "--command", default="evaluate", help="Command: train/evaluate"
        )
        parser.add_argument(
            "--conf_file",
            default=os.path.join(self.root_path, "hmnet/config/dialogue.conf"),
            help="Path to the BigLearn conf file.",
        )
        parser.add_argument(
            "--PYLEARN_MODEL", help="Overrides this option from the conf file."
        )
        parser.add_argument(
            "--master_port", help="Overrides this option's default", default=None
        )
        parser.add_argument("--cluster", help="local, philly or aml", default="local")
        parser.add_argument(
            "--dist_init_path", help="Distributed init path for AML", default="./tmp"
        )
        parser.add_argument(
            "--fp16",
            action="store_true",
            help="Whether to use 16-bit float precision instead of 32-bit",
        )
        parser.add_argument(
            "--fp16_opt_level",
            type=str,
            default="O1",
            help="For fp16: Apex AMP optimization level selected from ['O0', 'O1', 'O2', 'O3']. "
            "See details at https://nvidia.github.io/apex/amp.html",
        )
        parser.add_argument("--no_cuda", action="store_true", help="Disable cuda.")
        parser.add_argument(
            "--config_overrides",
            help="Override config parameters, as VAR=val;VAR=val;...",
        )

        cmdline_args = parser.parse_args()
        command = cmdline_args.command
        conf_file = cmdline_args.conf_file
        conf_args = Arguments(conf_file)
        opt = conf_args.readArguments()

        if cmdline_args.config_overrides:
            for config_override in cmdline_args.config_overrides.split(";"):
                config_override = config_override.strip()
                if config_override:
                    var_val = config_override.split("=")
                    assert (
                        len(var_val) == 2
                    ), f"Config override '{var_val}' does not have the form 'VAR=val'"
                    conf_args.add_opt(opt, var_val[0], var_val[1], force_override=True)

        opt["cuda"] = torch.cuda.is_available() and not cmdline_args.no_cuda
        opt["confFile"] = conf_file
if "datadir" not in opt: | |
opt["datadir"] = os.path.dirname( | |
conf_file | |
) # conf_file specifies where the data folder is | |
opt["basename"] = os.path.basename( | |
conf_file | |
) # conf_file specifies where the name of save folder is | |
opt["command"] = command | |

        # combine cmdline_args into the opt dictionary
        for key, val in cmdline_args.__dict__.items():
            # if val is not None and key not in ['command', 'conf_file']:
            if val is not None:
                opt[key] = val

        # combine kwargs into the opt dictionary (lower-case keys are allowed)
        for key, val in kwargs.items():
            valid_keys = [x for x in opt.keys() if x.upper() == x]
            if key.upper() not in valid_keys:
                print("WARNING: {} is not a valid key in HMNet.".format(key))
                print("The valid keys are:", valid_keys)
                continue
            if val is not None:
                opt[key.upper()] = val

        return opt

    def summarize(self, corpus, queries=None):
        print(f"HMNet model: processing corpus of {len(corpus)} samples")

        # transform the original dataset into the "dialogue" input format;
        # only the test-set path is used for evaluation
        data_folder = os.path.join(
            os.path.dirname(self.opt["datadir"]),
            "ExampleRawData/meeting_summarization/AMI_proprec/test",
        )

        self._create_datafolder(data_folder)
        self._preprocess(corpus, data_folder)

        # return self.model.eval()
        results = self._evaluate()

        return results

    def _evaluate(self):
        if self.opt["rank"] == 0:
            self.model.log("-----------------------------------------------")
            self.model.log("Evaluating model ... ")

        self.model.set_up_model()

        eval_dataset = "test"
        batch_generator_eval = self.model.get_batch_generator(eval_dataset)
        predictions = self._eval_batches(
            self.model.module, batch_generator_eval, self.model.saveFolder, eval_dataset
        )

        return predictions

    def _eval_batches(self, module, dev_batches, save_folder, label=""):
        max_sent_len = int(self.opt["MAX_GEN_LENGTH"])

        print("Decoding current model ... \nSaving folder is {}".format(save_folder))
        print("Each sample will take about 10 seconds.")
        start_time = time.time()

        predictions = []  # token predictions from the model

        # pick the decoder-side tokenizer
        if not isinstance(module.tokenizer, list):
            decoder_tokenizer = module.tokenizer
        elif len(module.tokenizer) == 1:
            decoder_tokenizer = module.tokenizer[0]
        elif len(module.tokenizer) == 2:
            decoder_tokenizer = module.tokenizer[1]
        else:
            assert False, "len(module.tokenizer) > 2"

        with torch.no_grad():
            for j, dev_batch in enumerate(dev_batches):
                for b in dev_batch:
                    if torch.is_tensor(dev_batch[b]):
                        dev_batch[b] = dev_batch[b].to(self.opt["device"])

                beam_search_res = module(
                    dev_batch, beam_search=True, max_sent_len=max_sent_len
                )
                pred = [
                    [t[0] for t in x] if len(x) > 0 else [[]] for x in beam_search_res
                ]
                predictions.extend(
                    [
                        [
                            self._convert_tokens_to_string(decoder_tokenizer, tt)
                            for tt in t
                        ]
                        for t in pred
                    ]
                )

                if (
                    "DEBUG" in self.opt and j >= 10
                ) or j >= self.model.task.evaluator.eval_batches_num:
                    # in debug mode, decode only the first 10 batches;
                    # otherwise decode the first eval_batches_num batches
                    break

        top1_predictions = [x[0] for x in predictions]

        print("Total time for inference:", time.time() - start_time)
        return top1_predictions

    def _convert_tokens_to_string(self, tokenizer, tokens):
        if "EVAL_TOKENIZED" in self.opt:
            tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
        if "EVAL_LOWERCASE" in self.opt:
            tokens = [t.lower() for t in tokens]
        if "EVAL_TOKENIZED" in self.opt:
            return " ".join(tokens)
        else:
            return tokenizer.decode(
                tokenizer.convert_tokens_to_ids(tokens), skip_special_tokens=True
            )

    def _preprocess(self, corpus, test_path):
        samples = []
        for i, sample in enumerate(corpus):
            new_sample = {"id": i, "meeting": [], "summary": []}

            if isinstance(sample, str):
                raise RuntimeError(
                    "Error: the input to HMNet should be dialogues rather than documents."
                )

            # add the turns one by one; split only on the first colon so that
            # colons inside the utterance are preserved
            for turn in sample:
                turn = [x.strip() for x in turn.split(":", 1)]
                if len(turn) < 2:
                    continue
                tokenized_turn = nlp(turn[1])

                ent_id = []
                pos_id = []
                for token in tokenized_turn:
                    ent = (
                        token.ent_iob_ + "-" + token.ent_type_
                        if token.ent_iob_ != "O"
                        else "O"
                    )
                    # unknown labels fall back to the empty-string id, in case a
                    # label is missing from the fixed ENT/POS dicts above
                    ent_id.append(ENT[ent] if ent in ENT else ENT[""])

                    pos = token.tag_
                    pos_id.append(POS[pos] if pos in POS else POS[""])

                new_sample["meeting"].append(
                    {
                        "speaker": turn[0],
                        "role": "",
                        "utt": {
                            "word": [str(token) for token in tokenized_turn],
                            "pos_id": pos_id,
                            "ent_id": ent_id,
                        },
                    }
                )

            new_sample["summary"].append(
                "This is a dummy summary. HMNet will filter out samples without a summary!"
            )
            samples.append(new_sample)

            # save each sample to a gzipped JSON file
            file_path = os.path.join(test_path, "split_{}.jsonl.gz".format(i))
            with gzip.open(file_path, "wt", encoding="utf-8") as file:
                file.write(json.dumps(new_sample))
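
            # For reference, each serialized sample follows roughly this
            # (abridged, illustrative) schema:
            # {
            #   "id": 0,
            #   "meeting": [{"speaker": "Alice", "role": "",
            #                "utt": {"word": [...], "pos_id": [...], "ent_id": [...]}}],
            #   "summary": ["This is a dummy summary. ..."]
            # }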

    def _clean_datafolder(self, data_folder):
        for name in os.listdir(data_folder):
            name = os.path.join(data_folder, name)
            if ".gz" in name:
                os.remove(name)

    def _create_datafolder(self, data_folder):
        if os.path.exists(data_folder):
            self._clean_datafolder(data_folder)
        else:
            os.makedirs(data_folder)
        with open(
            os.path.join(os.path.dirname(data_folder), "test_ami.json"),
            "w",
            encoding="utf-8",
        ) as file:
            json.dump(
                [
                    {
                        "source": {
                            "dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/"
                        },
                        "task": "meeting",
                        "name": "ami",
                    }
                ],
                file,
            )

        with open(
            os.path.join(
                os.path.dirname(os.path.dirname(data_folder)), "role_dict_ext.json"
            ),
            "w",
        ) as file:
            json.dump({}, file)

    @classmethod
    def show_capability(cls) -> None:
        basic_description = cls.generate_basic_description()
        more_details = (
            "An HMNet model fine-tuned for meeting/dialogue summarization.\n\n"
            "Strengths:\n - High performance on dialogue summarization tasks.\n\n"
            "Weaknesses:\n - Not suitable for datasets other than dialogues.\n\n"
            "Initialization arguments:\n"
            " - `min_gen_length`: minimum generation length of the decoder.\n"
            " - `max_gen_length`: maximum generation length of the decoder.\n"
            " - `beam_width`: width of the beam used for beam search during decoding.\n"
        )
        print(f"{basic_description} \n {'#' * 20} \n {more_details}")