metric / inference.py
Elron's picture
Upload inference.py with huggingface_hub
fe8dcb5 verified
raw
history blame
No virus
775 Bytes
import abc
from .artifact import Artifact
class InferenceEngine(abc.ABC, Artifact):
    """Common interface for artifacts that run inference over a dataset."""

    @abc.abstractmethod
    def infer(self, dataset):
        """Perform inference on ``dataset``.

        This method is abstract: subclasses must override it. The base
        implementation has no body and therefore returns ``None``.
        """
class HFPipelineBasedInferenceEngine(InferenceEngine):
    """Inference engine that generates text with a ``transformers`` pipeline.

    ``prepare`` builds the pipeline from ``model_name``; ``infer`` feeds it the
    ``"source"`` field of every dataset instance and collects the generated
    text.
    """

    # Model identifier handed to ``transformers.pipeline(model=...)``.
    model_name: str
    # Forwarded as ``max_new_tokens`` on every pipeline call in ``infer``.
    max_new_tokens: int

    def prepare(self):
        """Create the pipeline for ``model_name`` and store it in ``self.model``."""
        # Imported here rather than at module level, so importing this module
        # does not by itself import ``transformers``.
        from transformers import pipeline

        self.model = pipeline(model=self.model_name)

    def infer(self, dataset):
        """Generate one text prediction per instance of ``dataset``.

        Args:
            dataset: Iterable of mappings, each providing the prompt under
                the ``"source"`` key.

        Returns:
            A list with the ``"generated_text"`` of each instance, in the
            order the pipeline returns its outputs.
        """
        sources = [instance["source"] for instance in dataset]
        if not sources:
            # Nothing to generate; skip the model call entirely.
            return []

        predictions = []
        for output in self.model(sources, max_new_tokens=self.max_new_tokens):
            # Depending on the pipeline task, each per-input result is either
            # a dict or a list of candidate dicts; in the latter case keep
            # the first candidate.
            if isinstance(output, list):
                output = output[0]
            predictions.append(output["generated_text"])
        return predictions