"""Gradio demo for multi-label dress feature prediction.

Builds a timm ``swin_s3_base_224`` model whose head has one output per entry
in ``class_names``, loads weights from a local safetensors file, and serves
an image -> predicted-labels text UI.
"""
import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_model
from timm import create_model, list_models

# Define class names; index i corresponds to model output i.
class_names = ["3/4 Sleeve", "Accessory", "Babydoll", "Closed Back", "Corset", "Crochet",
               "Cutouts", "Draped", "Floral", "Gloves", "Halter", "Lace", "Long",
               "Long Sleeve", "Midi", "No Slit", "Off The Shoulder", "One Shoulder",
               "Open Back", "Pockets", "Print", "Puff Sleeve", "Ruched", "Satin",
               "Sequins", "Shimmer", "Short", "Short Sleeve", "Side Slit",
               "Square Neck", "Strapless", "Sweetheart Neck", "Tight", "V-Neck",
               "Velvet", "Wrap"]
label2id = {c: idx for idx, c in enumerate(class_names)}
id2label = {idx: c for idx, c in enumerate(class_names)}

# Initialize the model. The head size is derived from the label list so the
# two cannot drift apart (it is 36, as before).
model_name = 'swin_s3_base_224'
model = create_model(
    model_name,
    num_classes=len(class_names),
)
load_model(model, f'./{model_name}/model.safetensors')
# nn.Module instances start in training mode, where dropout-style layers (if
# the architecture has any) behave stochastically; torch.no_grad() does not
# change that. Switch to inference mode so predictions are deterministic.
model.eval()

# Preprocessing for every request: resize to the model's 224x224 input and
# convert to a float tensor in [0, 1]. Built once instead of per call.
_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def predict_features(image_path, threshold=0.5):
    """Predict the feature labels present in an image.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        threshold: Sigmoid probability a class must exceed to be reported.
            Defaults to 0.5, the previously hard-coded value.

    Returns:
        List of label strings from ``class_names``, in class-index order.
    """
    # Decode inside a context manager so the file handle is released;
    # convert() returns a new image, so it stays valid after close.
    with Image.open(image_path) as img:
        pil_image = img.convert('RGB')
    inputs = _preprocess(pil_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        logits = model(inputs)

    # Multi-label task: each class gets an independent sigmoid probability,
    # and every class above the threshold is reported.
    probs = logits.sigmoid().flatten()
    indices = torch.nonzero(probs > threshold, as_tuple=True)[0].tolist()
    pred_labels = [id2label[i] for i in indices]
    print(pred_labels)
    return pred_labels


def analyze(image):
    """Gradio callback: return the predicted labels as a string.

    Raises:
        gr.Error: if no image was supplied (Gradio passes ``None``).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    return str(predict_features(image))


demo = gr.Interface(fn=analyze,
                    title='Feature Prediction',
                    description=""" [Model](https://huggingface.co/LucyintheSky/lucy-feature-prediction) """,
                    inputs=gr.Image(type='filepath'),
                    outputs="text",
                    examples=[['./1.jpg'], ['./2.jpg'], ['./3.jpg'], ['./4.jpg']])
demo.launch()