from iSparrow.sparrow_model_base import ModelBase

try:
    import tflite_runtime.interpreter as tflite
except ImportError:
    import tensorflow.lite as tflite

from iSparrow import utils

# NOTE(review): ModelBase is imported twice; this second binding is the one
# the class below actually uses. Kept as-is in case import order matters.
from iSparrow import ModelBase
import numpy as np
from pathlib import Path
from scipy.special import softmax


class Model(ModelBase):
    """
    Model Implementation of an iSparrow model that uses the google perch
    tflite model.

    Args:
        ModelBase (iSparrow.ModelBase): Model base class that provides the
            interface through which to interact with iSparrow.
    """

    def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
        """
        __init__ Create a new model instance that uses the google perch
        tflite converted model.

        Args:
            model_path (str): path to the directory that contains the
                converted model ("model.tflite") and its labels ("labels.txt").
            num_threads (int, optional): number of threads to use. Defaults to 1.
            **kwargs: forwarded unchanged to ModelBase.__init__.
        """
        labels_path = str(Path(model_path) / "labels.txt")
        model_path = str(Path(model_path) / "model.tflite")

        # base class loads the model and labels
        super().__init__(
            "google_perch_lite",
            model_path,
            labels_path,
            num_threads=num_threads,
            **kwargs,
        )

        # Store input and output indices so they need not be looked up on
        # every inference.
        input_details = self.model.get_input_details()
        output_details = self.model.get_output_details()
        self.input_layer_index = input_details[0]["index"]
        # NOTE(review): assumes the second output of the converted model holds
        # the class logits (predict() applies softmax to it) - confirm against
        # the converted model's output ordering.
        self.output_layer_index = output_details[1]["index"]

    def predict(self, sample: np.ndarray) -> np.ndarray:
        """
        predict Make inference about the bird species for a single
        preprocessed data chunk.

        Args:
            sample (np.ndarray): one preprocessed data chunk; a leading batch
                dimension of size 1 is added before inference.

        Returns:
            np.ndarray: per-class probabilities, i.e. the softmax of the model
            output taken over the last (class) axis, presumably of shape
            (1, num_classes).
        """
        # add a batch dimension of size 1 and match the interpreter's dtype
        data = np.array([sample], dtype="float32")

        # The input tensor must be sized to this batch before allocation.
        self.model.resize_tensor_input(self.input_layer_index, list(data.shape))
        self.model.allocate_tensors()

        # Make a prediction
        self.model.set_tensor(self.input_layer_index, data)
        self.model.invoke()
        logits = self.model.get_tensor(self.output_layer_index)

        # Normalise per row (class axis). scipy's default axis=None would
        # normalise across the whole flattened tensor, which is only correct
        # while the batch size is 1.
        return softmax(logits, axis=-1)

    @classmethod
    def from_cfg(cls, sparrow_folder: str, cfg: dict):
        """
        from_cfg Create a new instance from a dictionary containing keyword
        arguments. Usually loaded from a config file.

        The model directory is given either by "model_path" or, if that key
        is absent, by "model_name". It is resolved relative to
        `<sparrow_folder>/models`; an absolute path is used unchanged. The
        caller's dictionary is not modified.

        Args:
            sparrow_folder (str): Installation directory of the Sparrow package.
            cfg (dict): Dictionary containing the keyword arguments.

        Returns:
            Model: New model instance created with the supplied kwargs.

        Raises:
            KeyError: if cfg contains neither "model_path" nor "model_name".
        """
        # shallow copy so the caller's config is not mutated between calls
        kwargs = dict(cfg)

        if "model_path" in kwargs:
            model_dir = kwargs["model_path"]
        else:
            # The config names the model folder; __init__ expects it as model_path.
            model_dir = kwargs.pop("model_name")

        kwargs["model_path"] = str(Path(sparrow_folder) / "models" / model_dir)

        return cls(**kwargs)