|
import gradio as gr |
|
from Models import VisionModel |
|
import huggingface_hub |
|
from PIL import Image |
|
import torch.amp.autocast_mode |
|
from pathlib import Path |
|
|
|
|
|
# Hugging Face Hub repository id that the model snapshot is downloaded from.
MODEL_REPO = "fancyfeast/joytag"
|
|
|
|
|
@torch.no_grad()
def predict(image: Image.Image):
    """Run the tagger on one image and return a confidence for every tag.

    Args:
        image: Image supplied by the Gradio input component.

    Returns:
        A dict mapping each entry of the module-level ``top_tags`` list to
        its sigmoid-activated score as a plain Python number, i.e. the
        ``{label: confidence}`` form that ``gr.Label`` accepts.

    Raises:
        IndexError: If the model yields fewer scores than there are tags.
    """
    # NOTE(review): the PIL image is handed to the model unmodified; confirm
    # that VisionModel accepts it without resizing/normalisation.
    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
        preds = model(image)
        tag_preds = preds['tags'].sigmoid().cpu()

    # Convert once in bulk so the UI receives Python floats rather than one
    # 0-dim tensor per tag.
    # Assumes tag_preds is 1-D and ordered like top_tags -- TODO confirm.
    scores = tag_preds.tolist()
    return {tag: scores[i] for i, tag in enumerate(top_tags)}
|
|
|
|
|
# Module-level start-up: fetch the snapshot of MODEL_REPO, load the model from
# the downloaded directory and switch it to evaluation mode.
print("Downloading model...")
path = huggingface_hub.snapshot_download(MODEL_REPO)
print("Loading model...")
model = VisionModel.load_model(path)
model.eval()

# Tag vocabulary: one tag per line, blank lines skipped. The encoding is given
# explicitly so decoding does not depend on the host's locale default.
with open(Path(path) / 'top_tags.txt', 'r', encoding='utf-8') as f:
    stripped_lines = (line.strip() for line in f)
    top_tags = [tag for tag in stripped_lines if tag]
|
|
|
print("Starting server...")

# Web UI definition: a single image input (file upload or webcam, delivered to
# `predict` as a PIL image) and a label output limited to the five
# highest-scoring classes.
gradio_app = gr.Interface(
    predict,
    inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
    outputs=[gr.Label(label="Result", num_top_classes=5)],
    title="JoyTag",
)
|
|
|
|
|
# Launch the web server only when executed as a script, not when imported.
if __name__ == '__main__':
    gradio_app.launch()
|
|