# NOTE: removed file-viewer page metadata that had been pasted into this file
# (uploader "yerang", commit e3af00f, "Upload 1110 files", size 4.89 kB) —
# those lines were not valid Python and broke parsing.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import numpy as np
import torch # pytype: disable=import-error
from nemo.collections.nlp.modules.common.transformer.text_generation import ( # pytype: disable=import-error
LengthParam,
OutputType,
SamplingParam,
)
from pytriton.decorators import ConstantPadder, batch, first_value, group_by_values
from pytriton.exceptions import PyTritonInvalidOperationError, PyTritonUnrecoverableError
from pytriton.model_config import Tensor
from helpers import cast_output, typedict2tensor # pytype: disable=import-error # isort:skip
_INPUT_PARAMETERS_NAMES = list(typing.get_type_hints(LengthParam)) + list(typing.get_type_hints(SamplingParam))
class NemoGptCallable:
    """Adapter exposing a NeMo Megatron GPT model as a PyTriton inference callable.

    Wraps either a plain GPT model or a prompt-learning model (detected via the
    ``virtual_prompt_style`` attribute) and derives the Triton input/output
    tensor specs from the NeMo ``LengthParam``/``SamplingParam``/``OutputType``
    TypedDicts.
    """

    def __init__(self, *, model_name: str, model):
        """Move ``model`` to GPU and build the Triton I/O specification.

        Args:
            model_name: Name under which the model is served.
            model: A NeMo GPT model instance (plain or prompt-learning).
        """
        self.model_name = model_name
        self._model = model.cuda()
        # Prompt-learning models carry a `virtual_prompt_style` attribute and
        # wrap the underlying GPT model in `frozen_model`.
        self._is_prompt_learning_model = hasattr(model, "virtual_prompt_style")
        self._text_generate_fn = (
            self._model.frozen_model.generate if self._is_prompt_learning_model else self._model.generate
        )
        # Task-specific generation is only available on prompt-learning models.
        self._task_generate_fn = self._model.generate if self._is_prompt_learning_model else None
        # Triton inputs: the two string tensors plus every (optional) length
        # and sampling parameter derived from the NeMo TypedDicts.
        self.inputs = (
            (
                Tensor(name="tasks", shape=(1,), dtype=bytes),
                Tensor(name="prompts", shape=(1,), dtype=bytes),
            )
            + typedict2tensor(LengthParam, overwrite_kwargs={"optional": True}, defaults=None)
            + typedict2tensor(SamplingParam, overwrite_kwargs={"optional": True}, defaults=None)
        )
        self.outputs = typedict2tensor(OutputType)
        self._outputs_dict = {output.name: output for output in self.outputs}

    def _format_prompts(
        self, tasks: typing.List[str], prompts: typing.List[str]
    ) -> typing.List[typing.Union[str, typing.Dict[str, str]]]:
        """Build prompt-learning input dicts pairing each task with its prompt.

        Each entry contains "taskname" plus the first field of the task's
        prompt template mapped to the raw prompt text.
        """
        formatted_prompts = []
        for task_name, prompt in zip(tasks, prompts):
            template = self._model.task_templates[task_name]
            entry = {"taskname": task_name}
            # NOTE(review): only the first prompt_template_field is filled;
            # zip silently drops any further template fields — confirm that
            # single-field templates are the intended scope.
            entry.update(zip(template["prompt_template_fields"], [prompt]))
            formatted_prompts.append(entry)
        return formatted_prompts

    @batch
    @group_by_values("tasks", *_INPUT_PARAMETERS_NAMES, pad_fn=ConstantPadder(0))
    @first_value(*_INPUT_PARAMETERS_NAMES, strict=False)
    def infer(self, **inputs: np.ndarray) -> typing.Dict[str, np.ndarray]:
        """Run text/task generation for one homogeneous batch of requests.

        Args:
            inputs: Batched Triton tensors; "tasks"/"prompts" are (batch, 1)
                bytes arrays, the rest are scalar generation parameters
                (already reduced to single values by ``@first_value``).

        Returns:
            Mapping from output name to ndarray, cast to the declared dtypes.

        Raises:
            PyTritonInvalidOperationError: Non-text_generation task requested
                on a model without task generation support.
            PyTritonUnrecoverableError: Generation failed with a RuntimeError;
                the server must be restarted.
        """
        # Tell the other distributed ranks that a generate step follows.
        generate_num = 0
        # `torch.cuda.LongTensor` is a deprecated legacy constructor; build
        # the tensor explicitly on the CUDA device instead.
        choice = torch.tensor([generate_num], dtype=torch.int64, device="cuda")
        torch.distributed.broadcast(choice, 0)

        def _str_ndarray2list(str_ndarray: np.ndarray) -> typing.List[str]:
            """Decode a (batch, 1) bytes ndarray into a list of UTF-8 strings."""
            str_ndarray = str_ndarray.astype("bytes")
            str_ndarray = np.char.decode(str_ndarray, encoding="utf-8")
            str_ndarray = str_ndarray.squeeze(axis=-1)
            return str_ndarray.tolist()

        tasks = _str_ndarray2list(inputs.pop("tasks"))
        prompts = _str_ndarray2list(inputs.pop("prompts"))

        # Split the remaining parameters between the two TypedDicts; resolve
        # each hint set once instead of once per dict item.
        length_keys = typing.get_type_hints(LengthParam)
        sampling_keys = typing.get_type_hints(SamplingParam)
        length_params = LengthParam(**{k: v for k, v in inputs.items() if k in length_keys})
        sampling_params = SamplingParam(**{k: v for k, v in inputs.items() if k in sampling_keys})

        # `group_by_values("tasks", ...)` guarantees a batch with a single
        # task value, so inspecting tasks[0] is sufficient.
        if tasks[0] == "text_generation":
            generate_fn = self._text_generate_fn
        else:
            generate_fn = self._task_generate_fn
            if generate_fn is None:
                raise PyTritonInvalidOperationError(
                    f"Model {self.model_name} does not support task {tasks[0]}. "
                    "Only text_generation task is supported."
                )
            prompts = self._format_prompts(tasks, prompts)

        try:
            output: OutputType = generate_fn(
                inputs=prompts,
                length_params=length_params,
                sampling_params=sampling_params,
            )
        except RuntimeError as e:
            # A RuntimeError here (CUDA/distributed failure) leaves the model
            # unusable; signal Triton that the server must be restarted.
            raise PyTritonUnrecoverableError("Fatal error occurred - no further inferences possible.") from e

        return {
            output_name: cast_output(data, self._outputs_dict[output_name].dtype)
            for output_name, data in output.items()
        }