metadata
tags:
  - model_hub_mixin
  - pytorch_model_hub_mixin

This is a fashion image feature extractor model.

1. Model Architecture

I used microsoft/swin-base-patch4-window7-224 as the base image encoder and added a 128-dimensional fully connected layer on top to reduce the embedding size.

The training data consists of anchor (product regions detected in user posts) and positive (product thumbnail) image pairs. Within each batch, every sample other than an anchor's own positive serves as a negative, and the model is trained to minimize the distance between anchor-positive pairs while maximizing the distance between anchor-negative pairs. This is contrastive learning with in-batch negatives, the training method used by OpenAI's CLIP model. Initially, I built explicit anchor-positive-negative triplets in a 1:1:1 ratio and trained with triplet loss, but in-batch negative sampling with a contrastive loss performed much better because the model learns from many more negative samples.

[Images: image_encoder, contrastive_learning]
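
The in-batch negative scheme described above can be sketched roughly as follows. This is a minimal illustration, not the training code of this model: the function name `in_batch_contrastive_loss`, the symmetric (CLIP-style) form of the loss, and the `temperature` value are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F


def in_batch_contrastive_loss(anchor_emb, positive_emb, temperature=0.07):
    """InfoNCE-style loss over a batch of (anchor, positive) pairs.

    anchor_emb, positive_emb: (batch, dim) tensors; row i of each forms a pair.
    temperature is an assumed hyperparameter, not a value taken from this model.
    """
    # L2-normalize so the dot product equals cosine similarity.
    anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
    positive_emb = F.normalize(positive_emb, p=2, dim=1)

    # logits[i, j] = scaled similarity between anchor i and positive j.
    logits = anchor_emb @ positive_emb.t() / temperature

    # The matching positive sits on the diagonal; every other column is a negative.
    targets = torch.arange(logits.size(0), device=logits.device)

    # Average the anchor->positive and positive->anchor directions.
    loss_a2p = F.cross_entropy(logits, targets)
    loss_p2a = F.cross_entropy(logits.t(), targets)
    return (loss_a2p + loss_p2a) / 2


# Toy check: correctly aligned pairs should give a lower loss than mismatched ones.
torch.manual_seed(0)
anchors = torch.randn(8, 128)
positives = anchors + 0.05 * torch.randn(8, 128)  # positives close to their anchors
aligned_loss = in_batch_contrastive_loss(anchors, positives).item()
shuffled_loss = in_batch_contrastive_loss(anchors, positives.flip(0)).item()
```

With a batch of size B, each anchor is compared against B - 1 negatives at no extra data cost, which is the advantage over 1:1:1 triplets mentioned above.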

2. Training dataset

Images from user posts on onthelook and kream were crawled and preprocessed. First, raw combinations of post images and product thumbnails were collected. Then, object detection was run on the post images and category classification was run on the product thumbnails, so that images of the same category could be paired together. A separately trained category classifier was used for the thumbnails. Finally, about 290,000 anchor-positive image pairs were created across 6 categories: tops, bottoms, outer, shoes, bags, and hats.
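
The pairing step can be pictured with the sketch below. It is only an illustration under stated assumptions: the tuple layouts, the function name `build_pairs`, and the rule of keeping every same-category match within a post are hypothetical, since the actual preprocessing code is not shown here.

```python
from collections import defaultdict

# The six categories named above.
CATEGORIES = {"tops", "bottoms", "outer", "shoes", "bags", "hats"}


def build_pairs(detections, thumbnails):
    """Pair each detected crop with product thumbnails of the same category.

    detections: list of (post_id, crop_id, category) from the object detector.
    thumbnails: list of (post_id, thumbnail_id, category) from the classifier.
    Returns a list of (crop_id, thumbnail_id) anchor-positive pairs.
    """
    # Index thumbnails by post and category for quick lookup.
    thumbs_by_key = defaultdict(list)
    for post_id, thumb_id, category in thumbnails:
        if category in CATEGORIES:
            thumbs_by_key[(post_id, category)].append(thumb_id)

    # A crop becomes an anchor only if its post has a thumbnail of the same category.
    pairs = []
    for post_id, crop_id, category in detections:
        for thumb_id in thumbs_by_key.get((post_id, category), []):
            pairs.append((crop_id, thumb_id))
    return pairs


detections = [("post1", "crop_a", "tops"), ("post1", "crop_b", "shoes")]
thumbnails = [("post1", "thumb_x", "tops"), ("post1", "thumb_y", "bags")]
pairs = build_pairs(detections, thumbnails)
```

In this toy input only the tops crop finds a matching thumbnail; the shoes crop and the bags thumbnail are dropped because nothing of the same category exists on the other side.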

The object detection model is available at https://huggingface.co/yainage90/fashion-object-detection

More details about the model are in the GitHub repo fashion-visual-search.

from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as v2
from transformers import AutoImageProcessor, SwinModel, SwinConfig
from huggingface_hub import PyTorchModelHubMixin

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

ckpt = "yainage90/fashion-image-feature-extractor"
encoder_config = SwinConfig.from_pretrained(ckpt)
encoder_image_processor = AutoImageProcessor.from_pretrained(ckpt)

class ImageEncoder(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.swin = SwinModel(config=encoder_config)
        self.embedding_layer = nn.Linear(encoder_config.hidden_size, 128)

    def forward(self, image_tensor):
        features = self.swin(image_tensor).pooler_output
        embeddings = self.embedding_layer(features)
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings

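# from_pretrained is a classmethod, so call it on the class rather than on an instance
encoder = ImageEncoder.from_pretrained(ckpt).to(device)
# Switch to inference mode so dropout and stochastic depth are disabled
encoder.eval()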

transform = v2.Compose([
    v2.Resize((encoder_config.image_size, encoder_config.image_size)),
    v2.ToTensor(),
    v2.Normalize(mean=encoder_image_processor.image_mean, std=encoder_image_processor.image_std),
])

image = Image.open('<path/to/image>').convert('RGB')
image = transform(image)
with torch.no_grad():
    embedding = encoder(image.unsqueeze(0).to(device)).cpu().numpy()
[Images: detection_image1, result_image1, detection_image2, result_image2]
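
Because the encoder L2-normalizes its output, cosine similarity between two embeddings reduces to a dot product. The sketch below shows one way the embeddings might be used for visual search; the function `top_k_similar` and the random gallery are assumptions for illustration, not part of the released model or its repo.

```python
import numpy as np


def top_k_similar(query_emb, gallery_embs, k=5):
    """Return indices and cosine scores of the k gallery rows closest to the query.

    query_emb: (dim,) array; gallery_embs: (n, dim) array. Both are assumed to be
    L2-normalized already, as the encoder's forward() does.
    """
    scores = gallery_embs @ query_emb   # cosine similarity per gallery item
    k = min(k, len(scores))
    top = np.argsort(-scores)[:k]       # highest similarity first
    return top, scores[top]


# Toy gallery of random unit vectors standing in for product thumbnail embeddings.
rng = np.random.default_rng(0)
gallery = rng.normal(size=(100, 128)).astype(np.float32)
gallery /= np.linalg.norm(gallery, axis=1, keepdims=True)

# A query that is a slightly perturbed copy of gallery item 42.
query = gallery[42] + 0.01 * rng.normal(size=128).astype(np.float32)
query /= np.linalg.norm(query)

indices, scores = top_k_similar(query, gallery, k=3)
```

With the usage code above, `embedding` has shape (1, 128), so `embedding[0]` would be passed as the query. For large galleries an approximate nearest-neighbor index is the usual replacement for the brute-force dot product.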