File size: 1,991 Bytes
2abf116
 
a4dc223
f30185d
2abf116
f30185d
2abf116
 
 
 
6e0ffb7
2abf116
 
 
 
 
 
 
 
 
 
 
 
 
3bda041
 
 
 
 
 
 
2abf116
 
f30185d
a792718
2abf116
 
 
 
 
64af198
2abf116
 
 
 
 
 
 
 
 
 
 
 
 
64af198
2abf116
 
 
 
 
 
a4dc223
5012113
a4dc223
 
 
2abf116
 
a4dc223
5012113
a4dc223
2abf116
a4dc223
f30185d
a4dc223
5012113
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import json
import spaces
import requests
import numpy as np
import gradio as gr
from PIL import Image
from io import BytesIO
from turtle import title
from transformers import pipeline
import ast
# Zero-shot image classifier shared by every request; built once at import time.
pipe = pipeline("zero-shot-image-classification", model="patrickjohncyh/fashion-clip")

# Resolved relative to the current working directory.
file_path = 'config.json'

# Open and read the JSON file. The encoding is given explicitly so parsing does
# not depend on the platform's locale default.
with open(file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Colour configuration: the keys are used as main-colour labels and each value
# is iterated as the list of sub-colour labels for that main colour.
COLOURS_DICT = data['color_mapping']



def _parse_image_urls(raw):
    """Convert the raw textbox value into the image list handed to the classifier.

    A Python literal (e.g. ``["http://a", "http://b"]``) is evaluated with
    ``ast.literal_eval``; anything that is not a valid literal is treated as
    plain comma-separated URLs.
    """
    text = str(raw).strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Bare URLs are not Python literals; split them on commas instead.
        return [part.strip() for part in text.split(",") if part.strip()]

    if isinstance(parsed, str):
        # A single quoted URL: wrap it so the result is always a batch.
        return [parsed]
    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    # Any other literal is passed through unchanged.
    return parsed


def shot(input, category):
    """Classify the colour of the product shown in the given image URLs.

    Args:
        input: Text containing the image URLs, either as a Python list/tuple
            literal of strings or as plain comma-separated URLs.
        category: Product category text forwarded to ``get_colour``.

    Returns:
        A dict ``{"colors": {"main": ..., "sub": ..., "score": ...}}`` built
        from the three values returned by ``get_colour``.
    """
    subColour, mainColour, score = get_colour(_parse_image_urls(input), category)
    return {
        "colors": {
            "main": mainColour,
            "sub": subColour,
            "score": score,
        }
    }



@spaces.GPU
def get_colour(image_urls, category):
    """Pick a main colour and then a sub-colour for the given images.

    Runs the classifier twice: first over the keys of ``COLOURS_DICT``, then
    over the sub-colours listed under the winning key. Each candidate label
    has ``" clothing: " + category`` appended before classification.

    Args:
        image_urls: Images passed straight to ``pipe``. The result is indexed
            as ``responses[0][0]``, so a batch (list) of images is expected
            and only the first image's result is used.
        category: Product category text appended to every candidate label.

    Returns:
        A tuple ``(subColour, mainColour, score)`` where ``score`` belongs to
        the sub-colour prediction, or ``(None, None, None)`` when the
        predicted main colour is not a key of ``COLOURS_DICT``.
    """
    suffix = " clothing: " + category

    # First pass: choose among the main colours.
    main_labels = [colour + suffix for colour in COLOURS_DICT]
    responses = pipe(image_urls, candidate_labels=main_labels)
    # First entry of the first image's result; assumes the pipeline returns
    # labels ranked best-first (TODO confirm against the pipeline docs).
    mainColour = responses[0][0]['label'].split(" clothing:")[0]

    if mainColour not in COLOURS_DICT:
        return None, None, None

    # Second pass: choose among the sub-colours of the winning main colour.
    # A new list is built here on purpose: suffixing COLOURS_DICT[mainColour]
    # in place would permanently alter the shared config, so suffixes would
    # accumulate across calls.
    sub_labels = [shade + suffix for shade in COLOURS_DICT[mainColour]]

    # Run pipeline in one go
    responses = pipe(image_urls, candidate_labels=sub_labels)
    best = responses[0][0]
    subColour = best['label'].split(" clothing:")[0]

    return subColour, mainColour, best['score']




# Assemble the Gradio UI: two text inputs are passed to `shot`, and whatever
# `shot` returns is rendered by a Label component.
_interface_inputs = [
    gr.Textbox(label="Image URLs (starting with http/https) comma seperated "),
    gr.Textbox(label="Category"),
]
_interface_output = gr.Label()

iface = gr.Interface(
    title="Full product flow",
    description="Add an image URL (starting with http/https) or upload a picture, and provide a list of labels separated by commas.",
    fn=shot,
    inputs=_interface_inputs,
    outputs=_interface_output,
)

# Start serving the interface
iface.launch()