|
|
|
|
|
import gradio as gr |
|
import os |
|
from PIL import Image |
|
|
|
# These were used below without being imported, which raised NameError at startup.
import shutil
import urllib.request

# Download the TorchScript model file that get_predictions later reads from
# ./scripted.model. This runs at import time, on every start.
url = "http://static.okkular.io/scripted.model"
output_file = "./scripted.model"

with urllib.request.urlopen(url) as response, open(output_file, 'wb') as out_file:
    # Stream the response body straight to disk rather than buffering it in memory.
    shutil.copyfileobj(response, out_file)
|
|
|
def get_stl(input_sku):
    """Return the chosen dress image plus the bag and shoes recommended for it.

    ``input_sku`` selects the file ``./data/dress_{input_sku}.jpg``.
    ``shop_the_look`` maps each query image path to a list of neighbour image
    paths. The entry at index 1 of each list is the one shown; index 0 is
    presumably the query itself (TODO confirm against get_neighbours ordering).

    Returns:
        A 3-tuple of PIL images: (dress, bag, shoes).
    """
    dress_path = f'./data/dress_{input_sku}.jpg'
    neighbours_by_path = shop_the_look(dress_path)

    bag_match = neighbours_by_path['./segs/bag.jpg'][1]
    shoe_match = neighbours_by_path['./segs/shoe.jpg'][1]

    dress_img = Image.open(dress_path)
    bag_img = Image.open(bag_match)
    shoe_img = Image.open(shoe_match)
    return dress_img, bag_img, shoe_img
|
|
|
# Dress selector; get_stl turns the choice into ./data/dress_<choice>.jpg.
# (The original trailing comma turned this into an unused 1-tuple.)
sku = gr.Dropdown(
    ["1", "2", "3", "4", "5"], label="Dress Sku",
)

# One dropdown input and three image outputs, matching get_stl's
# (dress, bag, shoes) return value.
demo = gr.Interface(get_stl, sku, ["image", "image", "image"])

# Serve under /<TOKEN> only when TOKEN is set; previously an unset TOKEN
# produced the literal root path "/None".
_token = os.getenv('TOKEN')
_launch_kwargs = {"root_path": f"/{_token}"} if _token else {}

# NOTE(review): launch() blocks the main thread when run as a script, so
# anything defined after this statement in the same module (shop_the_look,
# get_segment, ...) will not exist when get_stl runs. Confirm those helpers
# are imported before this point.
demo.launch(**_launch_kwargs)
|
|
|
from PIL import Image, ImageChops |
|
import numpy as np |
|
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation |
|
from PIL import Image |
|
import requests |
|
import matplotlib.pyplot as plt |
|
import torch |
|
import torch.nn as nn |
|
import os |
|
import nmslib |
|
from fastai.vision.all import * |
|
|
|
|
|
|
|
from functools import lru_cache

_SEGFORMER_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"


@lru_cache(maxsize=1)
def _load_segformer():
    """Load the SegFormer clothes extractor and model once per process and return both."""
    extractor = AutoFeatureExtractor.from_pretrained(_SEGFORMER_CHECKPOINT)
    model = SegformerForSemanticSegmentation.from_pretrained(_SEGFORMER_CHECKPOINT)
    return extractor, model


def get_segment(image, num, ret=False):
    """Cut the pixels of one segmentation class out of ``image``.

    Runs the SegFormer clothes model and keeps only the pixels whose
    predicted class id equals ``num``. Every other pixel, and any pure-black
    pixel, is painted white. The result is then cropped with ``trim``.

    Args:
        image: PIL image to segment. It is converted to RGB for the pixel
            masking, so RGBA/greyscale inputs no longer crash.
        num: class id to keep (shop_the_look passes 16, 9 and 10).
        ret: when truthy, return the cut-out. Otherwise display it with
            matplotlib and return None.

    Returns:
        The cropped RGB PIL image when ``ret`` is truthy, else None.
    """
    # Previously the checkpoint was re-loaded on every call (three times per request).
    extractor, model = _load_segformer()

    inputs = extractor(images=image, return_tensors="pt")
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # PIL's size is (width, height); interpolate expects (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    keep = (pred_seg == num).cpu().numpy()

    np_im = np.array(image.convert('RGB'))
    np_im[~keep] = 0
    # Blanked (pure black) pixels become white, so the cut-out sits on a white background.
    np_im[np.where((np_im == [0, 0, 0]).all(axis=2))] = [255, 255, 255]

    im = trim(Image.fromarray(np.uint8(np_im)).convert('RGB'))

    if ret:
        return im
    plt.imshow(im)
    plt.show()
|
|
|
|
|
def trim(im):
    """Crop ``im`` to the area that differs from the colour of its top-left pixel.

    Pixel (0, 0) is taken as the background colour. Per Pillow's
    ``ImageChops.add`` formula ``(a + b) / scale + offset``, adding ``diff`` to
    itself with scale 2.0 and offset -100 yields ``diff - 100``. Per-channel
    differences of 100 or less therefore count as background (presumably to
    ignore near-white / compression noise).

    Returns:
        The cropped image. If nothing differs from the background, ``im`` is
        returned unchanged; previously this case returned None and callers
        crashed on it.
    """
    bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
    diff = ImageChops.difference(im, bg)
    diff = ImageChops.add(diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    # getbbox() returns None for an all-zero image.
    if bbox is None:
        return im
    return im.crop(bbox)
|
|
|
|
|
def get_pred_seg(image_url):
    """Run the SegFormer clothes model on an image file.

    Args:
        image_url: anything ``PIL.Image.open`` accepts (a local path or a file
            object). Despite the name, an http URL will not work.

    Returns:
        ``(upsampled_logits, pred_seg)``:
        - the model logits bilinearly interpolated to the image's (height, width);
        - the per-pixel argmax class ids for the first image in the batch.
    """
    extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

    # Close the file handle once the extractor has consumed the pixels.
    with Image.open(image_url) as image:
        inputs = extractor(images=image, return_tensors="pt")
        target_size = image.size[::-1]  # PIL gives (w, h); interpolate wants (h, w)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return upsampled_logits, pred_seg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_predictions(feed, model_path='./scripted.model'):
    """Find similar items within each category of ``feed``.

    Every image is embedded with the TorchScript model at ``model_path``.
    Items are grouped by category, and each item is mapped to its
    neighbours (from ``get_neighbours``) among items of the same category.

    Args:
        feed: iterable of dicts with keys ``"sku"`` (image file path) and
            ``"category"``.
        model_path: TorchScript file to load; defaults to ./scripted.model.

    Returns:
        dict mapping each sku to the list of neighbour skus from its own
        category, in the order ``get_neighbours`` returns them. Returns an
        empty dict when ``feed`` is empty.
    """
    pairs = [[x["sku"], x["category"]] for x in feed]
    if not pairs:
        # zip(*[]) cannot be unpacked into two names; there is nothing to embed.
        return {}

    skus, labels = zip(*pairs)
    labels = np.array(labels)
    skus = np.array(skus)
    categories = list(set(labels))

    def get_image_fpath(x):
        return x[0]

    data = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=get_image_fpath,
        get_y=ItemGetter(1),
        item_tfms=[Resize(256)],
        batch_tfms=[Normalize.from_stats(*imagenet_stats)],
        # Empty validation split: every item goes to the train loader.
        splitter=IndexSplitter([]),
    )

    device = default_device()
    dls = data.dataloaders(
        pairs,
        device=device,
        # Identity shuffle keeps batches in feed order, so rows of `preds`
        # line up with `labels` and `skus`.
        shuffle_fn=lambda x: x,
        drop_last=False,
    )

    # Load the model on the same device the batches are placed on. Previously
    # the model was forced to CPU while default_device() could pick CUDA.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    preds_list = []
    with torch.no_grad():
        for x, _ in progress_bar(iter(dls.train), total=len(dls.train)):
            preds_list.append(model(x))
    preds = to_np(torch.cat(preds_list))

    predictions_json = {}
    for cat in categories:
        in_cat = labels == cat
        filtered_skus = skus[in_cat]
        neighbours, _ = get_neighbours(preds[in_cat])
        for i, sku in enumerate(filtered_skus):
            predictions_json[sku] = [filtered_skus[j] for j in neighbours[i]]

    return predictions_json
|
|
|
# nmslib HNSW settings used by get_neighbours.
# Build-time parameters, handed to index.createIndex().
INDEX_TIME_PARAMS = {
    'M': 100,
    'indexThreadQty': 8,
    'efConstruction': 2000,
    'post': 0,
}
# Query-time parameters, handed to index.setQueryTimeParams().
QUERY_TIME_PARAMS = {"efSearch": 2000}
# Neighbours wanted per item. get_neighbours requests one more than this,
# presumably because each item's closest hit is itself.
N_NEIGHBOURS = 4
|
def get_neighbours(embeddings):
    """Return each row's approximate nearest neighbours among ``embeddings``.

    Builds an nmslib HNSW index in L2 space over every row. It then queries
    the index with those same rows, asking for ``N_NEIGHBOURS + 1`` hits,
    capped at the number of rows.

    Args:
        embeddings: array with one embedding per row
            (dtype assumed float — TODO confirm).

    Returns:
        ``(neighbours, dists)``:
        - neighbours: int32 ids, one row per query;
        - dists: float32 distances, one row per query.
        Assumes every query returns the same number of hits.
    """
    k = min(N_NEIGHBOURS + 1, embeddings.shape[0])

    index = nmslib.init(method='hnsw', space='l2')
    index.addDataPointBatch(embeddings)
    index.createIndex(INDEX_TIME_PARAMS)
    index.setQueryTimeParams(QUERY_TIME_PARAMS)

    hits = index.knnQueryBatch(embeddings, k=k, num_threads=8)

    neighbours = np.stack([ids for ids, _ in hits]).astype(np.int32)
    dists = np.array([dist for _, dist in hits]).astype(np.float32)
    return neighbours, dists
|
|
|
|
|
def shop_the_look(prod):
    """Find catalogue bags and shoes that resemble those in the outfit image ``prod``.

    Steps:
    1. Segment the bag (class 16) and the left/right shoes (classes 9 and 10)
       out of ``prod``.
    2. Save them as ./segs/bag.jpg and ./segs/shoe.jpg.
    3. Embed those crops together with every catalogue image in ./data.

    Catalogue files are assumed to be named ``<category>_<rest>``; the text
    before the first underscore is used as the category.

    Returns:
        ``get_predictions``' dict mapping each image path to its
        same-category neighbour paths.
    """
    # The segment crops are written here; don't fail if the folder is missing.
    os.makedirs('./segs', exist_ok=True)

    # Open the outfit image once and close it afterwards (it was opened three
    # times and never closed). get_segment copies the pixels, so sharing is safe.
    with Image.open(prod) as image:
        bag_segment = get_segment(image, 16, ret=True)
        bag_segment.save('./segs/bag.jpg')
        shoe_l = get_segment(image, 9, True)
        shoe_r = get_segment(image, 10, True)

    shoe_segment = concat_h(shoe_l, shoe_r)
    shoe_segment.save('./segs/shoe.jpg')

    # Catalogue items: everything in ./data except entries containing
    # 'checkpoint' (presumably Jupyter's .ipynb_checkpoints).
    feed = [
        {'sku': f'./data/{name}', 'category': name.split('_')[0]}
        for name in os.listdir('./data')
        if 'checkpoint' not in name
    ]
    # The crops only find matches if ./data holds files whose category prefix
    # is 'shoes' / 'bag' — TODO confirm the catalogue naming.
    feed.extend([
        {'sku': './segs/shoe.jpg', 'category': 'shoes'},
        {'sku': './segs/bag.jpg', 'category': 'bag'},
    ])

    return get_predictions(feed)
|
|
|
|
|
|
|
def concat_h(image1, image2, size=(426, 240)):
    """Place two images side by side on a near-white (250, 250, 250) canvas.

    Both images are resized to ``size`` (width, height), so each one exactly
    fills its half of the ``2 * width`` x ``height`` canvas. Previously only
    ``image1`` was resized, so ``image2`` was cropped or left gaps.

    Args:
        image1: PIL image placed on the left.
        image2: PIL image placed on the right.
        size: (width, height) each image is resized to.

    Returns:
        A new RGB PIL image.
    """
    width, height = size
    left = image1.resize((width, height))
    right = image2.resize((width, height))

    new_image = Image.new('RGB', (2 * width, height), (250, 250, 250))
    new_image.paste(left, (0, 0))
    new_image.paste(right, (width, 0))
    return new_image