|
from transformers import Tool |
|
from remyxai.utils import labeler |
|
|
|
class RemyxClassifier(Tool):
    """Agent tool that labels a directory of images with a custom classifier.

    The classification itself is delegated to ``remyxai.utils.labeler``,
    which is called with a model name derived from the requested labels.
    """

    name = "custom_classifier"

    # Written as one clean paragraph (implicit string concatenation) so the
    # text handed to the agent carries no stray newlines or indentation, as a
    # multi-line triple-quoted literal would.
    description = (
        "This is a tool that fetches or trains a custom classifier and "
        "returns a json file containing the labels of the images in a "
        "directory. It takes an image directory and a comma separated string "
        "of labels as input, and returns the json file containing the "
        "predicted labels for each image as output."
    )

    inputs = ["text", "text"]
    outputs = ["text"]

    def __call__(self, image_directory: str, labels: str):
        """Label the images in *image_directory* using the given *labels*.

        Args:
            image_directory: Path to the directory of images to classify.
            labels: Comma-separated label names, e.g. ``"cat, dog"``.
                Whitespace around each label and empty entries are ignored.

        Returns:
            The value returned by ``labeler`` for this run, passed through
            unchanged as the tool's single text output.

        Raises:
            ValueError: If *labels* contains no non-empty label.
        """
        # Drop blank entries so inputs like "cat, dog," or "cat,,dog" neither
        # pass an empty label to the labeler nor produce "__" / a trailing
        # "_" in the derived model name.
        labels_list = [
            label.strip() for label in labels.split(",") if label.strip()
        ]
        if not labels_list:
            raise ValueError(
                "labels must contain at least one non-empty, "
                "comma-separated label"
            )

        # The model name is derived deterministically from the label list,
        # presumably so the same label set resolves to the same fetched or
        # trained model — TODO confirm against remyxai.utils.labeler.
        # NOTE(review): labels containing spaces end up verbatim in the name.
        model_name = "labeler_{}".format("_".join(labels_list))

        return labeler(
            labels=labels_list,
            image_dir=image_directory,
            model_name=model_name,
        )
|
|