# BiomedNLP-KRISSBERT-PubMed-UMLS-EL / usage / generate_prototypes.py
# Commit e3ef0b9 (shengz): "Add the example usage."
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Command-line tool that encodes a large set of entity mentions into prototype
embeddings with the pretrained mention encoder, and writes the accompanying
name-to-CUI mapping.
"""
import logging
import os
import pathlib
import pickle
import hydra
from omegaconf import DictConfig, OmegaConf
from transformers import AutoConfig, AutoTokenizer, AutoModel
from utils import generate_vectors
# Logging setup: every record is rendered with the emitting thread id,
# timestamp, level and logger name, then sent to a StreamHandler (stderr by
# default). The handler is attached to the root logger at INFO level.
log_formatter = logging.Formatter(
    "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
console = logging.StreamHandler()
console.setFormatter(log_formatter)

logger = logging.getLogger()  # no name given -> the root logger
logger.setLevel(logging.INFO)
logger.addHandler(console)
def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` (and its parents) if missing."""
    # Path("file.ext").parent is Path("."), so a bare file name is handled too.
    pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)


@hydra.main(config_path="conf", config_name="generate_prototypes", version_base=None)
def main(cfg: DictConfig):
    """Encode the configured training mentions into prototypes and save them.

    Steps:
      1. Load the config, fast tokenizer and encoder from ``cfg.model_name_or_path``.
      2. Move the encoder to CUDA and switch it to eval mode.
      3. Build the dataset from ``cfg.train_data`` via ``hydra.utils.instantiate``.
      4. Call ``generate_vectors(..., is_prototype=True)``.
      5. Pickle its result to ``cfg.output_prototypes``.
      6. Write ``ds.name_to_cuis`` to ``cfg.output_name_cuis``, one line per name,
         formatted as ``CUI1|CUI2|...||name``.

    Args:
        cfg: Hydra/OmegaConf configuration. Reads ``model_name_or_path``,
            ``train_data``, ``batch_size``, ``max_length``,
            ``output_prototypes`` and ``output_name_cuis``.
    """
    logger.info("Configuration:")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    config = AutoConfig.from_pretrained(cfg.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True)
    encoder = AutoModel.from_pretrained(cfg.model_name_or_path, config=config)
    # A CUDA device is required: the encoder is moved there unconditionally.
    encoder.cuda()
    encoder.eval()
    # NOTE(review): gradient tracking is not disabled here; presumably
    # generate_vectors runs inference under torch.no_grad() -- confirm.

    ds = hydra.utils.instantiate(cfg.train_data)
    data = generate_vectors(
        encoder, tokenizer, ds, cfg.batch_size, cfg.max_length, is_prototype=True
    )

    _ensure_parent_dir(cfg.output_prototypes)
    logger.info("Writing results to %s", cfg.output_prototypes)
    with open(cfg.output_prototypes, mode="wb") as f:
        pickle.dump(data, f)

    # The mapping file may live in a different directory than the prototypes,
    # so make sure its parent exists too. Names can contain non-ASCII
    # characters, so write UTF-8 explicitly instead of the locale default.
    # NOTE(review): a name containing "\n" or "||" would corrupt this
    # line-oriented format -- confirm the dataset never produces such names.
    _ensure_parent_dir(cfg.output_name_cuis)
    with open(cfg.output_name_cuis, "w", encoding="utf-8") as f:
        for name, cuis in ds.name_to_cuis.items():
            f.write("|".join(cuis) + "||" + name + "\n")

    logger.info("Total data processed %d. Written to %s", len(data), cfg.output_prototypes)
    logger.info("Name-to-CUI mapping written to %s", cfg.output_name_cuis)


if __name__ == "__main__":
    # Hydra's decorator parses the command line and supplies ``cfg``.
    main()