# This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import List, Dict import os import time import logging import json import gzip from dataclasses import dataclass, field import torch from torch import Tensor as T from transformers import PreTrainedTokenizer logger = logging.getLogger() @dataclass class Mention: cui: str start: int end: int text: str types: str @dataclass class ContextualMention: mention: str cuis: List[str] ctx_l: str ctx_r: str def to_tensor(self, tokenizer: PreTrainedTokenizer, max_length: int) -> T: ctx_l_ids = tokenizer.encode( text=self.ctx_l, add_special_tokens=False, max_length=max_length, truncation=True, ) ctx_r_ids = tokenizer.encode( text=self.ctx_r, add_special_tokens=False, max_length=max_length, truncation=True, ) mention_ids = tokenizer.encode( text=self.mention, add_special_tokens=False, max_length=max_length, truncation=True, ) # Concatenate context and mention to the max length. token_ids = tokenizer.convert_tokens_to_ids(['']) + mention_ids \ + tokenizer.convert_tokens_to_ids(['']) max_ctx_len = max_length - len(token_ids) - 2 # Exclude [CLS] and [SEP] max_ctx_l_len = max_ctx_len // 2 max_ctx_r_len = max_ctx_len - max_ctx_l_len if len(ctx_l_ids) < max_ctx_l_len and len(ctx_r_ids) < max_ctx_r_len: token_ids = ctx_l_ids + token_ids + ctx_r_ids elif len(ctx_l_ids) >= max_ctx_l_len and len(ctx_r_ids) >= max_ctx_r_len: token_ids = ctx_l_ids[-max_ctx_l_len:] + token_ids \ + ctx_r_ids[:max_ctx_r_len] elif len(ctx_l_ids) >= max_ctx_l_len: ctx_l_len = max_ctx_len - len(ctx_r_ids) token_ids = ctx_l_ids[-ctx_l_len:] + token_ids + ctx_r_ids else: ctx_r_len = max_ctx_len - len(ctx_l_ids) token_ids = ctx_l_ids + token_ids + ctx_r_ids[:ctx_r_len] token_ids = [tokenizer.cls_token_id] + token_ids # The above snippet doesn't guarantee the max length limit. token_ids = token_ids[:max_length - 1] + [tokenizer.sep_token_id] if len(token_ids) < max_length: token_ids = token_ids + [tokenizer.pad_token_id] * (max_length - len(token_ids)) return torch.tensor(token_ids) @dataclass class Document: id: str = None title: str = None abstract: str = None mentions: List[Mention] = field(default_factory=list) def concatenate_text(self) -> str: return ' '.join([self.title, self.abstract]) @classmethod def from_PubTator(cls, path: str, split_path_prefix: str) -> Dict[str, List]: docs = [] with gzip.open(path, 'rb') as f: for b in f.read().decode().strip().split('\n\n'): d = cls() s = '' for i, ln in enumerate(b.split('\n')): if i == 0: id, type, text = ln.strip().split('|', 2) assert type == 't' d.id, d.title = id, text elif i == 1: id, type, text = ln.strip().split('|', 2) assert type == 'a' assert d.id == id d.abstract = text s = d.concatenate_text() else: items = ln.strip().split('\t') assert d.id == items[0] cui = items[5].split('UMLS:')[-1] assert len(cui) == 8, breakpoint() m = Mention( cui=cui, start=int(items[1]), end=int(items[2]), text=items[3], types=items[4].split(',') ) assert m.text == s[m.start: m.end] d.mentions.append(m) docs.append(d) dataset = split_dataset(docs, split_path_prefix) print_dataset_stats(dataset) return dataset def to_contextual_mentions(self, max_length: int = 64) -> List[ContextualMention]: text = self.concatenate_text() mentions = [] for m in self.mentions: assert m.text == text[m.start:m.end] # Context ctx_l, ctx_r = text[:m.start].strip().split(), text[m.end:].strip().split() ctx_l, ctx_r = ' '.join(ctx_l[-max_length:]), ' '.join(ctx_r[:max_length]) cm = ContextualMention( mention=m.text, cuis=[m.cui], ctx_l=ctx_l, ctx_r=ctx_r, ) mentions.append(cm) return mentions def split_dataset(docs: List, split_path_prefix: str) -> Dict[str, List]: split_kv = {'train': 'trng', 'dev': 'dev', 'test': 'test'} id_to_split = {} dataset = {} for k, v in split_kv.items(): dataset[k] = [] path = split_path_prefix + v + '.txt' for i in open(path, encoding='utf-8').read().strip().split('\n'): assert i not in id_to_split, breakpoint() id_to_split[i] = k for doc in docs: split = id_to_split[doc.id] dataset[split].append(doc) return dataset def print_dataset_stats(dataset: Dict[str, List[Document]]) -> None: all_docs = [] for v in dataset.values(): all_docs.extend(v) for split, docs in {'all': all_docs, **dataset}.items(): logger.info(f"***** {split} *****") logger.info(f"Documents: {len(docs)}") logger.info(f"Mentions: {sum(len(d.mentions) for d in docs)}") cuis = set() for d in docs: for m in d.mentions: cuis.add(m.cui) logger.info(f"Mentioned concepts: {len(cuis)}") class MedMentionsDataset(torch.utils.data.Dataset): def __init__(self, dataset_path: str, split: str) -> None: super().__init__() self.dataset_path = dataset_path self.docs = Document.from_PubTator( path=os.path.join(self.dataset_path, 'corpus_pubtator.txt.gz'), split_path_prefix=os.path.join(self.dataset_path, 'corpus_pubtator_pmids_') )[split] self.mentions = [] self.name_to_cuis = {} self._post_init() def _post_init(self): for d in self.docs: self.mentions.extend(d.to_contextual_mentions()) for m in self.mentions: if m.mention not in self.name_to_cuis: self.name_to_cuis[m.mention] = set() self.name_to_cuis[m.mention].update(m.cuis) def __getitem__(self, index: int) -> ContextualMention: return self.mentions[index] def __len__(self) -> int: return len(self.mentions) class PreprocessedDataset(torch.utils.data.Dataset): def __init__(self, dataset_path: str) -> None: super().__init__() self.file = dataset_path self.data = [] self.load_data() def load_data(self) -> None: with open(self.file, encoding='utf-8') as f: logger.info("Reading file %s" % self.file) for ln in f: if ln.strip(): self.data.append(json.loads(ln)) logger.info("Loaded data size: {}".format(len(self.data))) def __getitem__(self, index: int) -> ContextualMention: d = self.data[index] return ContextualMention( ctx_l=d['context_left'], ctx_r=d['context_right'], mention=d['mention'], cuis=d['cuis'], ) def __len__(self) -> int: return len(self.data) def generate_vectors( encoder: torch.nn.Module, tokenizer: PreTrainedTokenizer, dataset: torch.utils.data.Dataset, batch_size: int, max_length: int, is_prototype: bool = False, ): n = len(dataset) total = 0 results = [] start_time = time.time() logger.info("Start encoding...") for i, batch_start in enumerate(range(0, n, batch_size)): batch = [dataset[i] for i in range(batch_start, min(n, batch_start + batch_size))] batch_token_tensors = [m.to_tensor(tokenizer, max_length) for m in batch] ids_batch = torch.stack(batch_token_tensors, dim=0).cuda() seg_batch = torch.zeros_like(ids_batch) attn_mask = (ids_batch != tokenizer.pad_token_id) with torch.inference_mode(): out = encoder( input_ids=ids_batch, token_type_ids=seg_batch, attention_mask=attn_mask ) out = out[0][:, 0, :] out = out.cpu() num_mentions = out.size(0) total += num_mentions if is_prototype: meta_batch = [{'cuis': m.cuis} for m in batch] assert len(meta_batch) == num_mentions results.extend([(meta_batch[i], out[i].view(-1).numpy()) for i in range(num_mentions)]) else: results.extend(out.cpu().split(1, dim=0)) if (i + 1) % 10 == 0: eta = (n - total) * (time.time() - start_time) / 60 / total logger.info(f"Batch={i + 1}, Encoded mentions={total}, ETA={eta:.1f}m") assert len(results) == n logger.info(f"Total encoded mentions={n}") if not is_prototype: results = torch.cat(results, dim=0) return results