arjunanand13
committed on
Commit
•
321843f
1
Parent(s):
01fb00e
Update handler.py
Browse files- handler.py +26 -28
handler.py
CHANGED
@@ -39,47 +39,46 @@ class EndpointHandler:
|
|
39 |
if torch.cuda.is_available():
|
40 |
torch.cuda.empty_cache()
|
41 |
|
42 |
-
def process_image(self,
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
image
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
image = Image.open(BytesIO(image_data))
|
51 |
-
else:
|
52 |
-
raise ValueError("Invalid image input")
|
53 |
-
elif isinstance(image_input, bytes):
|
54 |
-
image = Image.open(BytesIO(image_input))
|
55 |
-
else:
|
56 |
-
raise ValueError("Unsupported image input type")
|
57 |
-
|
58 |
-
return image
|
59 |
|
60 |
def __call__(self, data):
|
61 |
try:
|
62 |
-
#
|
63 |
-
|
64 |
-
text_input = data.pop("text", "")
|
65 |
|
66 |
-
#
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
-
# Prepare inputs
|
70 |
-
|
71 |
images=image if image else None,
|
72 |
text=text_input,
|
73 |
return_tensors="pt"
|
74 |
)
|
75 |
|
76 |
# Move inputs to device
|
77 |
-
|
78 |
-
|
79 |
|
80 |
# Generate output
|
81 |
with torch.no_grad():
|
82 |
-
outputs = self.model.generate(**
|
83 |
|
84 |
# Decode outputs
|
85 |
decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
|
@@ -88,7 +87,6 @@ class EndpointHandler:
|
|
88 |
|
89 |
except Exception as e:
|
90 |
return {"error": str(e)}
|
91 |
-
|
92 |
# import subprocess
|
93 |
# import sys
|
94 |
# import torch
|
|
|
39 |
if torch.cuda.is_available():
|
40 |
torch.cuda.empty_cache()
|
41 |
|
42 |
+
def process_image(self, image_path):
    """Open the image at *image_path* and return it fully loaded.

    Args:
        image_path: Filesystem path of the image to read.

    Returns:
        The opened image object with its pixel data already read into
        memory, or ``None`` if the file could not be opened or decoded
        (the error is printed rather than raised).
    """
    try:
        with open(image_path, 'rb') as image_file:
            image = Image.open(image_file)
            # Image.open only reads the header; the pixel data is fetched
            # lazily from the underlying file object. Force the read now,
            # while the file is still open -- otherwise the first real use
            # of the image after this method returns hits a closed file.
            image.load()
        return image
    except Exception as e:
        # Best-effort by design: report the problem and let the caller
        # decide how to proceed without an image.
        print(f"Error processing image: {str(e)}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def __call__(self, data):
|
52 |
try:
|
53 |
+
# Extract inputs from the expected Hugging Face format
|
54 |
+
inputs = data.pop("inputs", data)
|
|
|
55 |
|
56 |
+
# Check if inputs is a dict or string
|
57 |
+
if isinstance(inputs, dict):
|
58 |
+
image_path = inputs.get("image", None)
|
59 |
+
text_input = inputs.get("text", "")
|
60 |
+
else:
|
61 |
+
# If inputs is not a dict, assume it's the image path
|
62 |
+
image_path = inputs
|
63 |
+
text_input = "What is in this image?"
|
64 |
+
|
65 |
+
# Process image
|
66 |
+
image = self.process_image(image_path) if image_path else None
|
67 |
|
68 |
+
# Prepare inputs for the model
|
69 |
+
model_inputs = self.processor(
|
70 |
images=image if image else None,
|
71 |
text=text_input,
|
72 |
return_tensors="pt"
|
73 |
)
|
74 |
|
75 |
# Move inputs to device
|
76 |
+
model_inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
77 |
+
for k, v in model_inputs.items()}
|
78 |
|
79 |
# Generate output
|
80 |
with torch.no_grad():
|
81 |
+
outputs = self.model.generate(**model_inputs)
|
82 |
|
83 |
# Decode outputs
|
84 |
decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
|
|
|
87 |
|
88 |
except Exception as e:
|
89 |
return {"error": str(e)}
|
|
|
90 |
# import subprocess
|
91 |
# import sys
|
92 |
# import torch
|