from transformers import Tool

from remyxai.utils import labeler


class RemyxClassifier(Tool):
    """Agent tool that labels every image in a directory with a custom classifier.

    The work is delegated to ``remyxai.utils.labeler``. According to the tool
    description, it fetches an existing classifier or trains a new one for the
    requested label set.
    """

    name = "custom_classifier"
    description = (
        """
    This is a tool that fetches or trains a custom classifier and returns a json file containing the labels of the images in a directory.
    It takes an image directory and a comma separated string of labels as input, and returns the json file containing the predicted labels for each image as output
    """
    )
    # Both inputs (image directory, comma-separated labels) and the single
    # output are exchanged with the agent as plain text.
    inputs = ["text", "text"]
    outputs = ["text"]

    def __call__(self, image_directory: str, labels: str):
        """Classify the images in ``image_directory`` into the given labels.

        Args:
            image_directory: Path to the directory containing the images to label.
            labels: Comma-separated label names, e.g. ``"cat, dog"``. Surrounding
                whitespace is ignored and empty entries are skipped.

        Returns:
            Whatever ``labeler`` returns. Per the tool description, this is the
            JSON result holding the predicted label for each image.

        Raises:
            ValueError: If ``labels`` contains no non-empty label.
        """
        # Drop empty entries so inputs like "cat, dog," or "cat,,dog" neither
        # pass a blank class to the labeler nor leave a dangling "_" in the
        # derived model name.
        labels_list = [label.strip() for label in labels.split(",") if label.strip()]
        if not labels_list:
            raise ValueError(
                f"No labels found in {labels!r}; expected a comma-separated list such as 'cat, dog'."
            )
        # The model name is derived deterministically from the (ordered) label
        # set. Presumably this lets labeler reuse a previously trained
        # classifier for the same labels. TODO confirm against remyxai.
        model_name = "labeler_{}".format("_".join(labels_list))
        return labeler(labels=labels_list, image_dir=image_directory, model_name=model_name)