import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import (
    MegatronGPTPromptLearningModel,
)
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.core.config import hydra_runner

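# Force the "spawn" start method: the default "fork" is not safe to combine with CUDA in subprocesses.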
mp.set_start_method("spawn", force=True)

""" |
|
This is the script to run GPT text generation. |
|
a. run greedy inference from a p-tuned/prompt-tuned model's nemo file: |
|
python megatron_gpt_prompt_learning_eval.py \ |
|
virtual_prompt_model_file=PATH_TO_NEMO_PROMPT_LEARNING_MODEL_FILE \ |
|
gpt_model_file=PATH_TO_FROZEN_GPT_MODEL_FILE \ |
|
inference.greedy=True \ |
|
inference.add_BOS=False \ |
|
trainer.devices=1 \ |
|
trainer.num_nodes=1 \ |
|
tensor_model_parallel_size=1 \ |
|
pipeline_model_parallel_size=1 \ |
|
pred_file_path=PATH_WHERE_PRED_TEXT_FILE_WILL_BE_SAVED \ |
|
data_paths=[path/to/dataset1.jsonl, path/to/dataset2.jsonl] |
|
|
|
virtual_prompt_model_file should be a path to a .nemo file saved after p-tuning/prompt tuning and model file |
|
is still the path to the gpt model's .nemo file. |
|
|
|
data_paths should be a list of .json or .jsonl files containing json objects similar to the ones |
|
used during prompt learning. They should have keys that match the fields specified in the prompt template. |
|
Fields can be dropped from the prompt dict and their corresponding section of the prompt template will |
|
be automatically removed. |
|
|
|
For example, say the prompt template during p-tuning/prompt-tuning looked like: |
|
|
|
'<|VIRTUAL_PROMPT_0|> Context: {context} Question: {question} Answer: {answer}' |
|
|
|
but you don't want to include the answer field during inference. Just don't |
|
include the answer field in the prompt dict like below: |
|
|
|
{"taskname": "squad", "context": "some paragraph", "question": "question related to paragraph"} |
|
{"taskname": "squad", "context": "another paragraph", "question": "a different question related to paragraph"} |
|
|
|
And the dataset class will automatically format your input to have the form: |
|
|
|
[ |
|
'<|VIRTUAL_PROMPT_0|> Context: some paragraph Question: question related to paragraph Answer:', |
|
'<|VIRTUAL_PROMPT_0|> Context: another paragraph Question: a different question related to paragraph Answer:' |
|
] |
|
|
|
Similarly for other senarios, just add virtual_prompt_model_file=PATH_TO_NEMO_PROMPT_LEARNING_MODEL_FILE if you're using a |
|
p-tuned/prompt-tuned model. |
|
""" |


@hydra_runner(config_path="conf", config_name="megatron_gpt_prompt_learning_inference")
def main(cfg) -> None:
    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is needed for the inference")

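    # Trainer is required for restoring model-parallel .nemo checkpoints; its world size
    # must match the model parallel configuration, which the assert below enforces.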
    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
    assert (
        cfg.trainer.devices * cfg.trainer.num_nodes
        == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
    ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size"

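    # Restore only the prompt learning config first, so the frozen GPT model path and the
    # parallelism/checkpointing settings can be overridden before the full model restore.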
    prompt_learning_cfg = MegatronGPTPromptLearningModel.restore_from(
        cfg.virtual_prompt_model_file, trainer=trainer, return_config=True,
    )
    if cfg.get("gpt_model_file"):
        with open_dict(prompt_learning_cfg):
            prompt_learning_cfg.language_model_path = cfg.gpt_model_file
            prompt_learning_cfg.sequence_parallel = False
            prompt_learning_cfg.activations_checkpoint_method = None
            prompt_learning_cfg.activations_checkpoint_granularity = None
            prompt_learning_cfg.activations_checkpoint_num_layers = None

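    # Restore the full prompt learning model with the (possibly updated) config and freeze it for inference.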
    model = MegatronGPTPromptLearningModel.restore_from(
        restore_path=cfg.virtual_prompt_model_file, trainer=trainer, override_config_path=prompt_learning_cfg,
    )
    model.freeze()

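    # Activation checkpointing has to be turned off on the frozen model's encoder for inference.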
    try:
        model.frozen_model.model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

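    # If the model parallel state has not been initialized yet, launch a dummy function through
    # the trainer strategy so that the distributed environment gets set up.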
    if parallel_state.is_unitialized():

        def placeholder():
            return

        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(placeholder, trainer=model.trainer)
        model.trainer.strategy.setup_environment()

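    # Generation length and sampling parameters passed to text generation.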
    length_params: LengthParam = {
        "max_length": cfg.inference.tokens_to_generate,
        "min_length": cfg.inference.min_tokens_to_generate,
    }

    sampling_params: SamplingParam = {
        "use_greedy": cfg.inference.greedy,
        "temperature": cfg.inference.temperature,
        "top_k": cfg.inference.top_k,
        "top_p": cfg.inference.top_p,
        "repetition_penalty": cfg.inference.repetition_penalty,
        "add_BOS": cfg.inference.add_BOS,
        "all_probs": cfg.inference.all_probs,
        "compute_logprob": cfg.inference.compute_logprob,
    }

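    # Reserve room for the generated tokens within the frozen model's maximum sequence length.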
    max_input_length = model.frozen_model.cfg.encoder_seq_length - length_params["max_length"]

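    # Build an inference dataloader over the prompt-formatted examples.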
    _, dataloader = model.build_virtual_prompt_dataset(
        data=cfg.data_paths,
        batch_size=cfg.inference.get('batch_size', 1),
        max_seq_length=max_input_length,
        min_seq_length=model.cfg.data.get('min_seq_length', 1),
        add_bos=sampling_params["add_BOS"],
        add_eos=False,
        for_train=False,
        tokens_to_generate=length_params["max_length"],
        drop_last=False,
        shuffle=False,
        num_workers=cfg.get("num_workers", 1),
    )

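    # Attach the inference config to the model and run prediction over the dataloader.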
    config = OmegaConf.to_container(cfg.inference)
    model.set_inference_config(config)
    response = trainer.predict(model, dataloader)

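    # Write one generated sentence per line to the prediction file.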
print("***************************") |
|
with open(cfg.pred_file_path, "w", encoding="utf-8") as pred_file: |
|
for i in range(len(response)): |
|
for sent in response[i]["sentences"]: |
|
sent = sent.strip() |
|
sent = sent.replace("\n", " ") |
|
pred_file.write(sent + "\n") |
|
print(f"Inference Complete, prediction file saved at {cfg.pred_file_path}") |
|
print("***************************") |
|
|
|
|
|


if __name__ == '__main__':
    main()