Spaces:
Runtime error
Runtime error
File size: 1,772 Bytes
1cdf8e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
class FondantInferenceModel:
    """Abstract model loading and the preprocess -> model -> postprocess pipeline.

    Users subclass this and implement :meth:`load_model`, :meth:`preprocess`
    and :meth:`postprocess`, then pass the class to the
    FondantInferenceComponent, which loads the model and prepares it for
    inference. The examples folder can then show examples for a
    pytorch / huggingface / tensorflow / ... model.

    Args:
        device: Device identifier handed to ``model.to(...)`` once the model
            has been loaded. Defaults to ``"cpu"``.
    """

    def __init__(self, device: str = "cpu") -> None:
        self.device = device
        # Load first, then switch to inference mode: `eval` needs `self.model`.
        self.model = self.load_model()
        self.eval()

    def load_model(self):
        """Build and return the model object. Must be overridden.

        Raises:
            NotImplementedError: Always, in the base class. Returning ``None``
                here would only surface later as an opaque ``AttributeError``
                inside :meth:`eval`.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement load_model()"
        )

    def eval(self) -> None:
        """Prepare the loaded model for inference and move it to ``self.device``.

        The results of ``model.eval()`` and ``model.to(device)`` are assigned
        back to ``self.model``, so both calls must return the model (as
        ``torch.nn.Module`` does).
        """
        self.model = self.model.eval()
        self.model = self.model.to(self.device)

    def preprocess(self, input):
        """Turn raw input into the model's positional arguments. Must be overridden.

        The return value is unpacked with ``*`` in :meth:`__call__`, so it has
        to be an iterable of positional arguments (e.g. a tuple), even when
        the model takes a single argument.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess()"
        )

    def postprocess(self, output):
        """Turn the raw model output into the final result. Must be overridden.

        Raises:
            NotImplementedError: Always, in the base class. A silent ``None``
                here would discard the model output without any error.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement postprocess()"
        )

    def __call__(self, *args, **kwargs):
        """Run the full pipeline: preprocess, call the model, postprocess.

        All arguments are forwarded unchanged to :meth:`preprocess`.

        Returns:
            Whatever :meth:`postprocess` returns for the model output.
        """
        processed_inputs = self.preprocess(*args, **kwargs)
        outputs = self.model(*processed_inputs)
        return self.postprocess(outputs)

    def infer(self, *args, **kwargs):
        """Named alias for calling the instance; see :meth:`__call__`."""
        return self(*args, **kwargs)
class FondantInferenceComponent(FondantTransformComponent, FondantInferenceModel):
    """Transform component that loads a model and prepares it for inference.

    NOTE(review): FondantInferenceModel.__init__ is what loads the model; it
    only runs if FondantTransformComponent's initialiser cooperates with
    ``super().__init__()`` - confirm against that base class.
    """

    def transform(
        self, args: argparse.Namespace, dataframe: dd.DataFrame
    ) -> dd.DataFrame:
        """Run inference on ``args.image`` and return the result.

        By using the InferenceComponent, the model is automatically loaded and
        prepared for inference; this method only needs to call ``self.infer``,
        which is expected to invoke ``FondantInferenceModel.__call__``.

        Args:
            args: Parsed component arguments; only ``args.image`` is read.
            dataframe: Input dataframe. Currently unused by this method.

        Returns:
            The value produced by ``self.infer``.
            NOTE(review): the signature promises a dask DataFrame, so this
            assumes the model's postprocess step yields one - confirm.
        """
        # NOTE(review): `infer` must be supplied by one of the base classes.
        output = self.infer(args.image)
        # Previously the result was computed and dropped, so callers got None.
        return output
|