tcm03
committed on
Commit
·
8cda892
1
Parent(s):
941ce80
Add image encoding
Browse files- inference.py +29 -18
inference.py
CHANGED
@@ -33,14 +33,12 @@ def load_model():
|
|
33 |
model.load_state_dict(sd, strict=False)
|
34 |
model = model.to(device).eval()
|
35 |
|
36 |
-
# Initialize transformer
|
37 |
global transformer
|
38 |
transformer = _transform(model.visual.input_resolution, is_train=False)
|
39 |
print("Model loaded successfully.")
|
40 |
|
41 |
-
# Preprocessing Functions
|
42 |
def preprocess_image(image_base64):
|
43 |
-
"""Convert base64 encoded
|
44 |
image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
|
45 |
image = transformer(image).unsqueeze(0).to(device)
|
46 |
return image
|
@@ -49,39 +47,52 @@ def preprocess_text(text):
|
|
49 |
"""Tokenize text query."""
|
50 |
return tokenize([str(text)])[0].unsqueeze(0).to(device)
|
51 |
|
52 |
-
def get_fused_embedding(
|
53 |
"""Fuse sketch and text features into a single embedding."""
|
54 |
with torch.no_grad():
|
55 |
-
|
56 |
-
image_tensor = preprocess_image(image_base64)
|
57 |
text_tensor = preprocess_text(text)
|
58 |
|
59 |
-
|
60 |
-
sketch_feature = model.encode_sketch(image_tensor)
|
61 |
text_feature = model.encode_text(text_tensor)
|
62 |
|
63 |
-
# Normalize Features
|
64 |
sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
|
65 |
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
|
66 |
|
67 |
-
# Fuse Features
|
68 |
fused_embedding = model.feature_fuse(sketch_feature, text_feature)
|
69 |
return fused_embedding.cpu().numpy().tolist()
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# Hugging Face Inference API Entry Point
|
72 |
def infer(inputs):
|
73 |
"""
|
74 |
Inference API entry point.
|
75 |
Inputs:
|
76 |
-
- '
|
77 |
- 'text': Text query.
|
78 |
"""
|
79 |
load_model() # Ensure the model is loaded once
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
model.load_state_dict(sd, strict=False)
|
34 |
model = model.to(device).eval()
|
35 |
|
|
|
36 |
global transformer
|
37 |
transformer = _transform(model.visual.input_resolution, is_train=False)
|
38 |
print("Model loaded successfully.")
|
39 |
|
|
|
40 |
def preprocess_image(image_base64):
    """Decode a base64-encoded sketch/image into a batched model input tensor.

    The string is base64-decoded, opened with PIL, forced to RGB, run through
    the module-level CLIP `transformer` preprocessing pipeline, and given a
    leading batch dimension before being moved to `device`.
    """
    raw_bytes = base64.b64decode(image_base64)
    pil_image = Image.open(BytesIO(raw_bytes)).convert("RGB")
    # unsqueeze(0): model expects a batch axis even for a single image.
    return transformer(pil_image).unsqueeze(0).to(device)
|
|
47 |
"""Tokenize text query."""
|
48 |
return tokenize([str(text)])[0].unsqueeze(0).to(device)
|
49 |
|
50 |
+
def get_fused_embedding(sketch_base64, text):
    """Fuse sketch and text features into a single embedding.

    Both modalities are encoded, L2-normalized, then combined with the
    model's `feature_fuse` head. Returns the fused embedding as a nested
    Python list (JSON-serializable).
    """
    with torch.no_grad():  # inference only — no autograd bookkeeping
        sketch_input = preprocess_image(sketch_base64)
        text_input = preprocess_text(text)

        sketch_vec = model.encode_sketch(sketch_input)
        text_vec = model.encode_text(text_input)

        # Normalize each modality to unit length before fusing.
        sketch_vec = sketch_vec / sketch_vec.norm(dim=-1, keepdim=True)
        text_vec = text_vec / text_vec.norm(dim=-1, keepdim=True)

        fused = model.feature_fuse(sketch_vec, text_vec)
        return fused.cpu().numpy().tolist()
|
64 |
|
65 |
+
def get_image_embedding(image_base64):
    """Return the L2-normalized embedding of a base64-encoded image.

    The image is decoded and preprocessed via `preprocess_image`, encoded
    with `model.encode_image`, normalized to unit length, and returned as a
    nested Python list (JSON-serializable).

    NOTE: the previous docstring ("Convert base64 encoded image to tensor")
    described `preprocess_image`, not this function — corrected here.
    """
    image_tensor = preprocess_image(image_base64)
    with torch.no_grad():  # inference only — no autograd bookkeeping
        image_feature = model.encode_image(image_tensor)
        # Unit-normalize so embeddings are directly comparable by cosine/dot.
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
        return image_feature.cpu().numpy().tolist()
|
72 |
+
|
73 |
# Hugging Face Inference API Entry Point
|
74 |
def infer(inputs):
    """
    Inference API entry point.

    Inputs (dict), two mutually exclusive modes:
      - 'sketch': Base64 encoded sketch image, AND 'text': text query
        -> returns the fused sketch+text embedding.
      - 'image': Base64 encoded image
        -> returns the normalized image embedding.

    Returns {"embedding": [...]} on success or {"error": "..."} on bad input.
    (Previous docstring documented only the sketch+text mode; the 'image'
    mode handled below was undocumented.)
    """
    load_model()  # Ensure the model is loaded once
    if "sketch" in inputs:
        sketch_base64 = inputs.get("sketch", "")
        text_query = inputs.get("text", "")
        # Fused mode requires BOTH a sketch and a text query.
        if not sketch_base64 or not text_query:
            return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}

        # Generate Fused Embedding
        fused_embedding = get_fused_embedding(sketch_base64, text_query)
        return {"embedding": fused_embedding}
    elif "image" in inputs:
        image_base64 = inputs.get("image", "")
        if not image_base64:
            return {"error": "Image 'image' (base64) is required input."}
        embedding = get_image_embedding(image_base64)
        return {"embedding": embedding}
    else:
        return {"error": "Input 'sketch' or 'image' is required."}
|