Spaces:
Runtime error
Runtime error
class FondantInferenceModel:
    """Base class that separates model loading from inference.

    Users implement :meth:`load_model`, :meth:`preprocess` and
    :meth:`postprocess` (and may override :meth:`eval`), then pass the class
    to ``FondantInferenceComponent``, which loads the model and prepares it
    for inference. The examples folder can show implementations for
    PyTorch, Hugging Face, TensorFlow, and other frameworks.

    Calling an instance runs ``postprocess(model(preprocess(...)))``.
    """

    def __init__(self, device: str = "cpu"):
        """Load the model and prepare it for inference.

        Args:
            device: Device identifier handed to the model's ``.to()`` call in
                :meth:`eval`.
        """
        self.device = device
        self.model = self.load_model()
        self.eval()

    def load_model(self):
        """Return the model object used for inference.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: If the subclass does not override it. Raising
                here gives a clear error instead of a later
                ``AttributeError`` on ``None`` inside :meth:`eval`.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement load_model()"
        )

    def eval(self):
        """Put the model in inference mode and move it to ``self.device``.

        The default assumes the model exposes ``.eval()`` and ``.to(device)``,
        both returning the model (as ``torch.nn.Module`` does). Override this
        for frameworks that lack those methods.
        """
        self.model = self.model.eval()
        self.model = self.model.to(self.device)

    def preprocess(self, input):
        """Turn raw call arguments into the model's input.

        May return a tuple/list (unpacked as positional arguments), a mapping
        (unpacked as keyword arguments) or any other single object (passed
        as one argument). Must be implemented by subclasses.

        Raises:
            NotImplementedError: If the subclass does not override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess()"
        )

    def postprocess(self, output):
        """Turn the raw model output into the value returned to callers.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: If the subclass does not override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement postprocess()"
        )

    def __call__(self, *args, **kwargs):
        """Run preprocess -> model -> postprocess and return the result.

        All arguments are forwarded unchanged to :meth:`preprocess`.
        """
        from collections.abc import Mapping

        processed_inputs = self.preprocess(*args, **kwargs)
        # Only explicit sequences are spread into positional arguments; a
        # single tensor/array must not be iterated along its first axis.
        if isinstance(processed_inputs, (tuple, list)):
            outputs = self.model(*processed_inputs)
        elif isinstance(processed_inputs, Mapping):
            # e.g. tokenizer output consumed as model(**inputs)
            outputs = self.model(**processed_inputs)
        else:
            outputs = self.model(processed_inputs)
        return self.postprocess(outputs)

    def infer(self, *args, **kwargs):
        """Run inference; equivalent to calling the instance.

        ``FondantInferenceComponent.transform`` calls ``self.infer(...)``,
        so this delegates to :meth:`__call__`.
        """
        return self(*args, **kwargs)
class FondantInferenceComponent(FondantTransformComponent, FondantInferenceModel): | |
# loads the model and prepares it for inference | |
def transform( | |
self, args: argparse.Namespace, dataframe: dd.DataFrame | |
) -> dd.DataFrame: | |
# by using the InferenceComponent, the model is automatically loaded and prepared for inference | |
# you just need to call the infer method | |
# the self.infer method calls the model.__call__ method of the FondantInferenceModel | |
output = self.infer(args.image) | |