arjunanand13 committed on
Commit
321843f
1 Parent(s): 01fb00e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +26 -28
handler.py CHANGED
@@ -39,47 +39,46 @@ class EndpointHandler:
39
  if torch.cuda.is_available():
40
  torch.cuda.empty_cache()
41
 
42
- def process_image(self, image_input):
43
- if isinstance(image_input, str):
44
- # Check if it's a URL
45
- if image_input.startswith('http://') or image_input.startswith('https://'):
46
- image = Image.open(requests.get(image_input, stream=True).raw)
47
- # Check if it's a base64 string
48
- elif image_input.startswith('data:image'):
49
- image_data = base64.b64decode(image_input.split(',')[1])
50
- image = Image.open(BytesIO(image_data))
51
- else:
52
- raise ValueError("Invalid image input")
53
- elif isinstance(image_input, bytes):
54
- image = Image.open(BytesIO(image_input))
55
- else:
56
- raise ValueError("Unsupported image input type")
57
-
58
- return image
59
 
60
  def __call__(self, data):
61
  try:
62
- # Handle different input formats
63
- image_input = data.pop("image", None)
64
- text_input = data.pop("text", "")
65
 
66
- # Process image if provided
67
- image = self.process_image(image_input) if image_input else None
 
 
 
 
 
 
 
 
 
68
 
69
- # Prepare inputs
70
- inputs = self.processor(
71
  images=image if image else None,
72
  text=text_input,
73
  return_tensors="pt"
74
  )
75
 
76
  # Move inputs to device
77
- inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
78
- for k, v in inputs.items()}
79
 
80
  # Generate output
81
  with torch.no_grad():
82
- outputs = self.model.generate(**inputs)
83
 
84
  # Decode outputs
85
  decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
@@ -88,7 +87,6 @@ class EndpointHandler:
88
 
89
  except Exception as e:
90
  return {"error": str(e)}
91
-
92
  # import subprocess
93
  # import sys
94
  # import torch
 
39
  if torch.cuda.is_available():
40
  torch.cuda.empty_cache()
41
 
42
+ def process_image(self, image_path):
43
+ try:
44
+ with open(image_path, 'rb') as image_file:
45
+ image = Image.open(image_file)
46
+ return image
47
+ except Exception as e:
48
+ print(f"Error processing image: {str(e)}")
49
+ return None
 
 
 
 
 
 
 
 
 
50
 
51
  def __call__(self, data):
52
  try:
53
+ # Extract inputs from the expected Hugging Face format
54
+ inputs = data.pop("inputs", data)
 
55
 
56
+ # Check if inputs is a dict or string
57
+ if isinstance(inputs, dict):
58
+ image_path = inputs.get("image", None)
59
+ text_input = inputs.get("text", "")
60
+ else:
61
+ # If inputs is not a dict, assume it's the image path
62
+ image_path = inputs
63
+ text_input = "What is in this image?"
64
+
65
+ # Process image
66
+ image = self.process_image(image_path) if image_path else None
67
 
68
+ # Prepare inputs for the model
69
+ model_inputs = self.processor(
70
  images=image if image else None,
71
  text=text_input,
72
  return_tensors="pt"
73
  )
74
 
75
  # Move inputs to device
76
+ model_inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
77
+ for k, v in model_inputs.items()}
78
 
79
  # Generate output
80
  with torch.no_grad():
81
+ outputs = self.model.generate(**model_inputs)
82
 
83
  # Decode outputs
84
  decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
 
87
 
88
  except Exception as e:
89
  return {"error": str(e)}
 
90
  # import subprocess
91
  # import sys
92
  # import torch