Leigh Jewell committed
Commit 5bb7068
1 Parent(s): 0894749

First commit

Files changed (2)
  1. app.py +60 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,60 @@
+ import gradio as gr
+ from transformers import AutoModel, AutoProcessor
+ import torch
+ import requests
+ from PIL import Image
+ from io import BytesIO
+
+ fashion_items = ['top', 'trousers', 'jumper']
+
+ # Load model and processor
+ model_name = 'Marqo/marqo-fashionSigLIP'
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+
+ # Preprocess and normalize text data
+ with torch.no_grad():
+     # Ensure truncation and padding are activated
+     processed_texts = processor(
+         text=fashion_items,
+         return_tensors="pt",
+         truncation=True,  # Ensure text is truncated to fit model input size
+         padding=True      # Pad shorter sequences so that all are the same length
+     )['input_ids']
+
+     text_features = model.get_text_features(processed_texts)
+     text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+
+ # Prediction function
+ def predict_from_url(url):
+     # Check if the URL is empty
+     if not url:
+         return {"Error": "Please input a URL"}
+
+     try:
+         image = Image.open(BytesIO(requests.get(url).content))
+     except Exception as e:
+         return {"Error": f"Failed to load image: {str(e)}"}
+
+     processed_image = processor(images=image, return_tensors="pt")['pixel_values']
+
+     with torch.no_grad():
+         image_features = model.get_image_features(processed_image)
+         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+         text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
+
+     return {fashion_items[i]: float(text_probs[0, i]) for i in range(len(fashion_items))}
+
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=predict_from_url,
+     inputs=gr.Textbox(label="Enter Image URL"),
+     outputs=gr.Label(label="Classification Results"),
+     title="Fashion Item Classifier",
+     allow_flagging="never"
+ )
+
+ # Launch the interface
+ demo.launch()
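
Not part of the commit, but a minimal sketch of how one might sanity-check predict_from_url outside the Gradio UI once the definitions in app.py are in scope (the image URL below is a placeholder assumption, not taken from the repo):

# Hypothetical smoke test: run in the same Python session after the code above,
# or paste it just above demo.launch(). Note app.py has no __main__ guard, so a
# plain "from app import predict_from_url" would also start the Gradio server.
sample_url = "https://example.com/jumper.jpg"  # placeholder; any reachable image URL
scores = predict_from_url(sample_url)
print(scores)  # dict mapping 'top', 'trousers' and 'jumper' to softmax probabilities
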
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ transformers
+ torch
+ requests
+ Pillow
+ open_clip_torch
+ ftfy
+
+ # This is only needed for local deployment
+ gradio