# img2art-search / main.py
# brunorosilva
# deploy: hf
# d612317
from img2art_search.models.predict import predict
from img2art_search.models.train import fine_tune_vit
from img2art_search.models.compute_embeddings import create_gallery_embeddings
import gradio as gr
import argparse
def make_interface():
    """Build the Gradio search interface and launch it.

    The interface hands the uploaded image to ``predict`` as a PIL image
    (``type="pil"``) and displays whatever ``predict`` returns in a gallery
    labelled "Most similar images". It is created with ``live=True`` and
    launched with ``share=True``; see the Gradio documentation for the exact
    effect of those two flags.
    """
    query_image = gr.Image(type="pil")
    results_gallery = gr.Gallery(label="Most similar images", height=256 * 3)

    demo = gr.Interface(
        fn=predict,
        inputs=query_image,
        outputs=results_gallery,
        live=True,
    )
    demo.launch(share=True)
def train(epochs, batch_size):
    """Run the ``train`` sub-command by delegating to ``fine_tune_vit``.

    Args:
        epochs: Number of training epochs (parsed as ``int`` by the CLI).
        batch_size: Batch size for training (parsed as ``int`` by the CLI).
    """
    # Both values are forwarded positionally, in this order.
    fine_tune_vit(epochs, batch_size)
def create_gallery(gallery_path):
    """Run the ``gallery`` sub-command via ``create_gallery_embeddings``.

    Args:
        gallery_path: Path handed unchanged to ``create_gallery_embeddings``
            (the CLI default is ``"data/wikiart"``).
    """
    create_gallery_embeddings(gallery_path)
def main(argv=None):
    """Parse command-line arguments and run the selected sub-command.

    Sub-commands:
        train      fine-tune the model (``--epochs``, ``--batch_size``)
        interface  launch the Gradio search interface
        gallery    create gallery embeddings (``--gallery_path``)

    When no sub-command is given, the usage help is printed instead.

    Args:
        argv: Optional list of argument strings to parse in place of the
            process command line. The default ``None`` makes ``argparse``
            read ``sys.argv[1:]``, which is the original behavior; passing
            a list allows the CLI to be driven programmatically.
    """
    parser = argparse.ArgumentParser(
        description="Train or infer the ViT model for image-to-art search."
    )
    subparsers = parser.add_subparsers(dest="command")

    # Subparser for training
    train_parser = subparsers.add_parser("train", help="Fine-tune the ViT model")
    train_parser.add_argument(
        "--epochs", type=int, default=50, help="Number of training epochs"
    )
    train_parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )

    # Subparser for inference; it takes no options, so the returned parser
    # object is not kept.
    subparsers.add_parser(
        "interface",
        help="Perform image-to-art search using the fine-tuned model",
    )

    # Subparser for gallery creation
    create_gallery_parser = subparsers.add_parser(
        "gallery", help="Create new gallery from a path"
    )
    create_gallery_parser.add_argument(
        "--gallery_path",
        type=str,
        default="data/wikiart",
        help="Path to create the gallery from",
    )

    args = parser.parse_args(argv)

    if args.command == "train":
        train(args.epochs, args.batch_size)
    elif args.command == "interface":
        make_interface()
    elif args.command == "gallery":
        create_gallery(args.gallery_path)
    else:
        # No sub-command supplied: args.command is None, so show usage.
        parser.print_help()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()