tcm03 committed
Commit 8cda892 · 1 Parent(s): 941ce80

Add image encoding

Files changed (1):
  1. inference.py (+29 -18)
inference.py CHANGED
@@ -33,14 +33,12 @@ def load_model():
     model.load_state_dict(sd, strict=False)
     model = model.to(device).eval()
 
-    # Initialize transformer
     global transformer
     transformer = _transform(model.visual.input_resolution, is_train=False)
     print("Model loaded successfully.")
 
-# Preprocessing Functions
 def preprocess_image(image_base64):
-    """Convert base64 encoded image to tensor."""
+    """Convert base64 encoded sketch to tensor."""
     image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
     image = transformer(image).unsqueeze(0).to(device)
     return image
@@ -49,39 +47,52 @@ def preprocess_text(text):
     """Tokenize text query."""
     return tokenize([str(text)])[0].unsqueeze(0).to(device)
 
-def get_fused_embedding(image_base64, text):
+def get_fused_embedding(sketch_base64, text):
     """Fuse sketch and text features into a single embedding."""
     with torch.no_grad():
-        # Preprocess Inputs
-        image_tensor = preprocess_image(image_base64)
+        sketch_tensor = preprocess_image(sketch_base64)
         text_tensor = preprocess_text(text)
 
-        # Extract Features
-        sketch_feature = model.encode_sketch(image_tensor)
+        sketch_feature = model.encode_sketch(sketch_tensor)
         text_feature = model.encode_text(text_tensor)
 
-        # Normalize Features
         sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
         text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
 
-        # Fuse Features
         fused_embedding = model.feature_fuse(sketch_feature, text_feature)
         return fused_embedding.cpu().numpy().tolist()
 
+def get_image_embedding(image_base64):
+    """Encode a base64 image into a normalized image embedding."""
+    image_tensor = preprocess_image(image_base64)
+    with torch.no_grad():
+        image_feature = model.encode_image(image_tensor)
+        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
+    return image_feature.cpu().numpy().tolist()
+
 # Hugging Face Inference API Entry Point
 def infer(inputs):
     """
     Inference API entry point.
     Inputs:
-    - 'image': Base64 encoded sketch image.
+    - 'sketch': Base64 encoded sketch image.
     - 'text': Text query.
     """
     load_model()  # Ensure the model is loaded once
-    image_base64 = inputs.get("image", "")
-    text_query = inputs.get("text", "")
-    if not image_base64 or not text_query:
-        return {"error": "Both 'image' (base64) and 'text' are required inputs."}
+    if "sketch" in inputs:
+        sketch_base64 = inputs.get("sketch", "")
+        text_query = inputs.get("text", "")
+        if not sketch_base64 or not text_query:
+            return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
 
-    # Generate Fused Embedding
-    fused_embedding = get_fused_embedding(image_base64, text_query)
-    return {"fused_embedding": fused_embedding}
+        # Generate Fused Embedding
+        fused_embedding = get_fused_embedding(sketch_base64, text_query)
+        return {"embedding": fused_embedding}
+    elif "image" in inputs:
+        image_base64 = inputs.get("image", "")
+        if not image_base64:
+            return {"error": "Input 'image' (base64) is required."}
+        embedding = get_image_embedding(image_base64)
+        return {"embedding": embedding}
+    else:
+        return {"error": "Input 'sketch' or 'image' is required."}