# File: NeMo/examples/nlp/language_modeling/megatron_retro_eval.py
# Scraped page header (kept verbatim): camenduru's picture | thanks to NVIDIA ❤ | 7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.core.config import hydra_runner
# Optional dependency: apex's model-parallel state, which `main` queries to
# decide whether the distributed environment still has to be set up.
# HAVE_APEX records whether the import succeeded.
HAVE_APEX = False
try:
    from apex.transformer import parallel_state
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    pass
else:
    HAVE_APEX = True
"""
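This script runs text generation with a RETRO model.

Usage:
The example below assumes the model has tensor model parallel size TP=1 and
pipeline model parallel size PP=1. `retro_model_file` may point either to a
.nemo archive or to an already-extracted model directory.
Run greedy inference from a .nemo file: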
python megatron_retro_eval.py \
trainer.devices=1 \
trainer.num_nodes=1 \
trainer.accelerator=gpu \
trainer.precision=16 \
inference.tokens_to_generate=128 \
inference.greedy=True \
retro_model_file=path_to_retro_nemo_file \
tensor_model_parallel_size=1 \
pipeline_model_parallel_size=1 \
retrieval_service.faiss_devices='0' \
retrieval_service.faiss_index=path_to_faiss_index \
retrieval_service.retrieval_index=path_to_retrieval_dataset \
retrieval_service.neighbors=20
"""
def _load_retro_model(cfg, trainer):
    """Restore the RETRO model named by ``cfg.retro_model_file`` for inference.

    ``cfg.retro_model_file`` may be a ``.nemo`` archive or an already-extracted
    model directory. The stored model config is read first and patched:
    precision is taken from ``trainer``, ``sequence_parallel`` is set to False,
    and the activation-checkpointing settings are cleared. The weights are then
    restored using that patched config.

    Args:
        cfg: the script's Hydra config (only ``retro_model_file`` is read).
        trainer: the Lightning ``Trainer`` the model is attached to.

    Returns:
        The restored ``MegatronRetrievalModel``.
    """
    model_path = cfg.retro_model_file

    save_restore_connector = NLPSaveRestoreConnector()
    if os.path.isdir(model_path):
        # Point the connector at the extracted directory rather than an archive.
        save_restore_connector.model_extracted_dir = model_path

    # First pass: fetch only the stored config so it can be overridden.
    model_cfg = MegatronRetrievalModel.restore_from(
        model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector,
    )

    with open_dict(model_cfg):
        model_cfg.precision = trainer.precision
        model_cfg.sequence_parallel = False
        model_cfg.activations_checkpoint_granularity = None
        model_cfg.activations_checkpoint_method = None

    # Second pass: restore the weights with the patched config.
    return MegatronRetrievalModel.restore_from(
        model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg,
    )


def _build_generation_params(inference_cfg):
    """Translate the ``inference`` config section into generation parameters.

    Args:
        inference_cfg: the ``cfg.inference`` config section.

    Returns:
        A ``(length_params, sampling_params)`` tuple of plain dicts.
    """
    length_params: LengthParam = {
        "max_length": inference_cfg.tokens_to_generate,
        "min_length": inference_cfg.min_tokens_to_generate,
    }
    sampling_params: SamplingParam = {
        "use_greedy": inference_cfg.greedy,
        "temperature": inference_cfg.temperature,
        "top_k": inference_cfg.top_k,
        "top_p": inference_cfg.top_p,
        "repetition_penalty": inference_cfg.repetition_penalty,
        "add_BOS": inference_cfg.add_BOS,
        "all_probs": inference_cfg.all_probs,
        "compute_logprob": inference_cfg.compute_logprob,
    }
    return length_params, sampling_params


def _ensure_model_parallel_initialized(model):
    """Set up the trainer's distributed environment if apex reports it uninitialized.

    Args:
        model: a model whose ``trainer.strategy`` provides ``launcher`` and
            ``setup_environment``.
    """
    # NOTE: `is_unitialized` is the name apex itself exports; do not "correct" it.
    if not parallel_state.is_unitialized():
        return

    def dummy():
        return

    # A no-op is run through the strategy's launcher (when there is one) before
    # the environment is set up; presumably this lets the launcher do its
    # process setup without running real work — TODO confirm.
    if model.trainer.strategy.launcher is not None:
        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
    model.trainer.strategy.setup_environment()


@hydra_runner(config_path="conf", config_name="megatron_retro_inference")
def main(cfg) -> None:
    """Run RETRO text generation as configured by ``megatron_retro_inference``.

    Steps:

    1. Build a Lightning ``Trainer`` with ``NLPDDPStrategy``.
    2. Restore the model from ``cfg.retro_model_file``.
    3. Make sure model-parallel state is initialized.
    4. Pass the ``inference`` and ``retrieval_service`` sections to the model.
    5. Generate for ``cfg.prompts``. This uses ``model.generate``, or, when
       ``cfg.use_predict_method`` is true, ``trainer.predict`` over a
       ``DataLoader`` with batch size ``cfg.inference_batch_size``.

    The response is printed to stdout.

    Args:
        cfg: config object supplied by ``hydra_runner``.

    Raises:
        ImportError: if ``apex.transformer.parallel_state`` could not be imported.
    """
    if not HAVE_APEX:
        # Fail fast with a clear message. Otherwise the script would build the
        # trainer and restore the model, then die with a NameError on
        # `parallel_state`.
        raise ImportError(
            "Apex is required for RETRO inference (apex.transformer.parallel_state "
            "could not be imported). Please install NVIDIA Apex."
        )

    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
    model = _load_retro_model(cfg, trainer)
    length_params, sampling_params = _build_generation_params(cfg.inference)

    _ensure_model_parallel_initialized(model)

    config = OmegaConf.to_container(cfg.inference)
    retrieval_service = OmegaConf.to_container(cfg.retrieval_service)
    model.set_inference_config(config, retrieval_service)

    prompts = OmegaConf.to_container(cfg.prompts)
    if cfg.use_predict_method:
        # Second method of running text generation: call trainer.predict.
        ds = RequestDataSet(prompts)
        request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size)
        response = trainer.predict(model, request_dl)
    else:
        # First method of running text generation: call model.generate.
        response = model.generate(
            inputs=prompts,
            length_params=length_params,
            sampling_params=sampling_params,
            strategy=model.inference_strategy,
        )

    print("***************************")
    print(response)
    print("***************************")
if __name__ == "__main__":
    # `cfg` is not passed explicitly: the @hydra_runner decorator on `main` is
    # expected to build it from the config file plus command-line overrides.
    main()