davanstrien HF staff committed on
Commit
d869c4e
1 Parent(s): 6c4cdba

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -16
handler.py CHANGED
@@ -4,6 +4,8 @@ from PIL import Image
4
  import requests
5
  import torch
6
  import gc
 
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
@@ -22,27 +24,30 @@ class EndpointHandler:
22
  )
23
 
24
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
25
- # Clear CUDA cache
26
  torch.cuda.empty_cache()
27
  gc.collect()
28
 
29
- # Extract inputs from the request data
30
  inputs = data.get("inputs", {})
31
  image_url = inputs.get("image_url")
 
32
  text_prompt = inputs.get("text_prompt", "Describe this image.")
33
 
34
- if not image_url:
35
- return [{"error": "No image_url provided in inputs"}]
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Download and process the image
38
- try:
39
- image = Image.open(requests.get(image_url, stream=True).raw)
40
- if image.mode != "RGB":
41
- image = image.convert("RGB")
42
- except Exception as e:
43
- return [{"error": f"Failed to load image: {str(e)}"}]
44
 
45
- # Process the image and text
46
  try:
47
  with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
48
  inputs = self.processor.process(
@@ -50,21 +55,17 @@ class EndpointHandler:
50
  text=text_prompt
51
  )
52
 
53
- # Move inputs to the correct device and make a batch of size 1
54
  inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
55
 
56
- # Generate output
57
  output = self.model.generate_from_batch(
58
  inputs,
59
  GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
60
  tokenizer=self.processor.tokenizer
61
  )
62
 
63
- # Decode the generated tokens
64
  generated_tokens = output[0, inputs['input_ids'].size(1):]
65
  generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
66
 
67
- # Clear CUDA cache again
68
  torch.cuda.empty_cache()
69
  gc.collect()
70
 
 
4
  import requests
5
  import torch
6
  import gc
7
+ import base64
8
+ import io
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
 
24
  )
25
 
26
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
27
  torch.cuda.empty_cache()
28
  gc.collect()
29
 
 
30
  inputs = data.get("inputs", {})
31
  image_url = inputs.get("image_url")
32
+ image_data = inputs.get("image")
33
  text_prompt = inputs.get("text_prompt", "Describe this image.")
34
 
35
+ if image_url:
36
+ try:
37
+ image = Image.open(requests.get(image_url, stream=True).raw)
38
+ except Exception as e:
39
+ return [{"error": f"Failed to load image from URL: {str(e)}"}]
40
+ elif image_data:
41
+ try:
42
+ image = Image.open(io.BytesIO(base64.b64decode(image_data)))
43
+ except Exception as e:
44
+ return [{"error": f"Failed to decode image data: {str(e)}"}]
45
+ else:
46
+ return [{"error": "No image_url or image data provided in inputs"}]
47
 
48
+ if image.mode != "RGB":
49
+ image = image.convert("RGB")
 
 
 
 
 
50
 
 
51
  try:
52
  with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
53
  inputs = self.processor.process(
 
55
  text=text_prompt
56
  )
57
 
 
58
  inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
59
 
 
60
  output = self.model.generate_from_batch(
61
  inputs,
62
  GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
63
  tokenizer=self.processor.tokenizer
64
  )
65
 
 
66
  generated_tokens = output[0, inputs['input_ids'].size(1):]
67
  generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
68
 
 
69
  torch.cuda.empty_cache()
70
  gc.collect()
71