remyxai-classifier-labeler / classifier_labeler.py
smellslikeml
first commit
7f9e237
raw
history blame contribute delete
895 Bytes
from transformers import Tool
from remyxai.utils import labeler
class RemyxClassifier(Tool):
    """Agent tool that labels a directory of images with a custom classifier.

    Parsing of the tool inputs and derivation of a model name happen here;
    fetching/training the classifier and labeling the images is delegated to
    ``remyxai.utils.labeler``.
    """

    name = "custom_classifier"
    description = (
        """
This is a tool that fetches or trains a custom classifier
and returns a json file containing the labels of the images in a directory.
It takes an image directory and a comma separated string of labels
as input, and returns the json file containing the predicted labels
for each image as output
"""
    )
    # Two text inputs (image directory, comma-separated labels) and one
    # text output, matching the parameters of __call__ below.
    inputs = ["text", "text"]
    outputs = ["text"]

    def __call__(self, image_directory: str, labels: str):
        """Label every image in ``image_directory`` using ``labels`` as classes.

        Args:
            image_directory: Path to the directory containing the images.
            labels: Comma-separated class names, e.g. ``"cat, dog"``.
                Surrounding whitespace is stripped and blank entries
                (from ``",,"`` or a trailing comma) are ignored.

        Returns:
            Whatever ``labeler`` returns -- per the tool description, the
            JSON output with the predicted label for each image (exact type
            not visible here; confirm against ``remyxai.utils.labeler``).

        Raises:
            ValueError: If ``labels`` contains no non-blank label.
        """
        # Drop blank entries so stray commas don't create empty class names
        # or malformed model names such as "labeler_cat__dog".
        labels_list = [
            label.strip() for label in labels.split(",") if label.strip()
        ]
        if not labels_list:
            raise ValueError(
                "labels must contain at least one non-empty, "
                "comma-separated label"
            )
        # The model name is derived deterministically from the label set;
        # presumably this lets `labeler` reuse ("fetch") a previously trained
        # model for the same labels -- verify against remyxai.
        model_name = "labeler_{}".format("_".join(labels_list))
        return labeler(
            labels=labels_list,
            image_dir=image_directory,
            model_name=model_name,
        )