Spaces:
Sleeping
Sleeping
| from img2art_search.models.predict import predict | |
| from img2art_search.models.train import fine_tune_vit | |
| from img2art_search.models.compute_embeddings import create_gallery_embeddings | |
| import gradio as gr | |
| import argparse | |
def make_interface():
    """Build the Gradio search UI around ``predict`` and launch it.

    The interface takes a single image, handed to ``predict`` as a PIL
    object, and displays the results in a gallery. Flagging is disabled and
    no public share link is requested.
    """
    query_image = gr.Image(type="pil")
    results_gallery = gr.Gallery(label="Most similar images", height=256 * 3)

    demo = gr.Interface(
        fn=predict,
        inputs=query_image,
        outputs=results_gallery,
        allow_flagging="never",
    )
    demo.launch(share=False)
def train(epochs, batch_size):
    """Run fine-tuning by delegating to ``fine_tune_vit``.

    Args:
        epochs: Number of training epochs (parsed as ``int`` by the CLI).
        batch_size: Training batch size (parsed as ``int`` by the CLI).

    Both values are forwarded positionally and unchanged; whatever
    ``fine_tune_vit`` returns is discarded.
    """
    fine_tune_vit(epochs, batch_size)
def create_gallery(gallery_path):
    """Create gallery embeddings by delegating to ``create_gallery_embeddings``.

    Args:
        gallery_path: Path forwarded unchanged to
            ``create_gallery_embeddings`` (the CLI passes a string,
            defaulting to ``"data/wikiart"``).

    Whatever ``create_gallery_embeddings`` returns is discarded.
    """
    create_gallery_embeddings(gallery_path)
def main(argv=None):
    """Parse command-line arguments and run the selected sub-command.

    Sub-commands:
        train      -- calls ``train(epochs, batch_size)``
                      (``--epochs`` default 50, ``--batch_size`` default 32).
        interface  -- calls ``make_interface()``.
        gallery    -- calls ``create_gallery(gallery_path)``
                      (``--gallery_path`` default ``"data/wikiart"``).

    When no sub-command is given, the help text is printed instead.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script did before this parameter existed; passing a list lets
            the CLI be driven programmatically.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Train or infer the ViT model for image-to-art search."
    )
    subparsers = parser.add_subparsers(dest="command")

    # Sub-command: fine-tuning.
    train_parser = subparsers.add_parser("train", help="Fine-tune the ViT model")
    train_parser.add_argument(
        "--epochs", type=int, default=50, help="Number of training epochs"
    )
    train_parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )

    # Sub-command: interactive search UI (takes no options of its own).
    subparsers.add_parser(
        "interface",
        help="Perform image-to-art search using the fine-tuned model",
    )

    # Sub-command: (re)build the gallery embeddings from a directory.
    create_gallery_parser = subparsers.add_parser(
        "gallery", help="Create new gallery from a path"
    )
    create_gallery_parser.add_argument(
        "--gallery_path", type=str, default="data/wikiart"
    )

    args = parser.parse_args(argv)

    # Kept as an if/elif chain so each handler name is only resolved on the
    # branch that is actually taken.
    if args.command == "train":
        train(args.epochs, args.batch_size)
    elif args.command == "interface":
        make_interface()
    elif args.command == "gallery":
        create_gallery(args.gallery_path)
    else:
        # No sub-command supplied: show usage rather than failing.
        parser.print_help()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()