user-agent commited on
Commit
2abf116
1 Parent(s): f2666b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -24
app.py CHANGED
@@ -1,41 +1,69 @@
1
- from turtle import title
 
2
  import requests
3
- from io import BytesIO
4
- import gradio as gr
5
- from transformers import pipeline
6
  import numpy as np
 
7
  from PIL import Image
8
- import spaces
9
-
 
 
10
  pipe = pipeline("zero-shot-image-classification", model="patrickjohncyh/fashion-clip")
11
- images="dog.jpg"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  @spaces.GPU
14
- def shot(input, labels_text):
15
- if isinstance(input, str) and (input.startswith("http://") or input.startswith("https://")):
16
- # Input is a URL
17
- response = requests.get(input)
18
- PIL_image = Image.open(BytesIO(response.content)).convert('RGB')
19
- else:
20
- # Input is an uploaded image
21
- PIL_image = Image.fromarray(np.uint8(input)).convert('RGB')
22
-
23
- labels = labels_text.split(",")
24
- res = pipe(images=PIL_image,
25
- candidate_labels=labels,
26
- hypothesis_template="This is a photo of a {}")
27
- return {dic["label"]: dic["score"] for dic in res}
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Define the Gradio interface with the updated components
30
  iface = gr.Interface(
31
  fn=shot,
32
  inputs=[
33
- gr.Textbox(label="Image URL (starting with http/https) or Upload Image"),
34
- gr.Textbox(label="Labels (comma-separated)")
35
  ],
36
  outputs=gr.Label(),
37
  description="Add an image URL (starting with http/https) or upload a picture, and provide a list of labels separated by commas.",
38
- title="Zero-shot Image Classification"
39
  )
40
 
41
  # Launch the interface
 
1
+ import json
2
+ import spaces
3
  import requests
 
 
 
4
  import numpy as np
5
+ import gradio as gr
6
  from PIL import Image
7
+ from io import BytesIO
8
+ from turtle import title
9
+ from transformers import pipeline
10
+ import ast
11
  pipe = pipeline("zero-shot-image-classification", model="patrickjohncyh/fashion-clip")
12
+
13
+ file_path = 'config.json'
14
+
15
+ # Open and read the JSON file
16
+ with open(file_path, 'r') as file:
17
+ data = json.load(file)
18
+
19
+ COLOURS_DICT = data['color_mapping']
20
+
21
+
22
+
23
+ def shot(input, category):
24
+ subColour,mainColour,score = get_colour(ast.literal_eval(str(input)),category)
25
+ return subColour,mainColour,score
26
+
27
+
28
 
29
  @spaces.GPU
30
+ def get_colour(image_urls, category):
31
+ colourLabels = list(COLOURS_DICT.keys())
32
+ for i in range(len(colourLabels)):
33
+ colourLabels[i] = colourLabels[i] + " clothing: " + category
34
+
35
+ responses = pipe(image_urls, candidate_labels=colourLabels, device=device)
36
+ # Get the most common colour
37
+ mainColour = responses[0][0]['label'].split(" clothing:")[0]
38
+
39
+
40
+ if mainColour not in COLOURS_DICT:
41
+ return None, None, None
42
+
43
+ # Add category to the end of each label
44
+ labels = COLOURS_DICT[mainColour]
45
+ for i in range(len(labels)):
46
+ labels[i] = labels[i] + " clothing: " + category
47
+
48
+ # Run pipeline in one go
49
+ responses = pipe(image_urls, candidate_labels=labels, device=device)
50
+ subColour = responses[0][0]['label'].split(" clothing:")[0]
51
+
52
+ return subColour, mainColour, responses[0][0]['score']
53
+
54
+
55
+
56
 
57
  # Define the Gradio interface with the updated components
58
  iface = gr.Interface(
59
  fn=shot,
60
  inputs=[
61
+ gr.Textbox(label="Image URLs (starting with http/https) comma seperated "),
62
+ gr.Textbox(label="Category")
63
  ],
64
  outputs=gr.Label(),
65
  description="Add an image URL (starting with http/https) or upload a picture, and provide a list of labels separated by commas.",
66
+ title="Full product flow"
67
  )
68
 
69
  # Launch the interface