arjunanand13 committed
Commit 01fb00e · verified · 1 Parent(s): f6f8735

Update handler.py

Files changed (1): handler.py +106 -15
handler.py CHANGED
@@ -1,6 +1,10 @@
 import subprocess
 import sys
 import torch
+import base64
+from io import BytesIO
+from PIL import Image
+import requests
 from transformers import AutoModelForCausalLM, AutoProcessor
 
 def install(package):
@@ -8,8 +12,7 @@ def install(package):
 
 class EndpointHandler:
     def __init__(self, path=""):
-
-        required_packages = ['timm', 'einops', 'flash-attn']
+        required_packages = ['timm', 'einops', 'flash-attn', 'Pillow']
         for package in required_packages:
             try:
                 install(package)
@@ -17,11 +20,9 @@ class EndpointHandler:
             except Exception as e:
                 print(f"Failed to install {package}: {str(e)}")
 
-
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"Using device: {self.device}")
 
-
         self.model_name = "microsoft/Florence-2-base-ft"
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -35,28 +36,118 @@ class EndpointHandler:
             revision='refs/pr/6'
         )
 
-
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-
+
+    def process_image(self, image_input):
+        if isinstance(image_input, str):
+            # Check if it's a URL
+            if image_input.startswith('http://') or image_input.startswith('https://'):
+                image = Image.open(requests.get(image_input, stream=True).raw)
+            # Check if it's a base64 string
+            elif image_input.startswith('data:image'):
+                image_data = base64.b64decode(image_input.split(',')[1])
+                image = Image.open(BytesIO(image_data))
+            else:
+                raise ValueError("Invalid image input")
+        elif isinstance(image_input, bytes):
+            image = Image.open(BytesIO(image_input))
+        else:
+            raise ValueError("Unsupported image input type")
+
+        return image
+
     def __call__(self, data):
         try:
+            # Handle different input formats
+            image_input = data.pop("image", None)
+            text_input = data.pop("text", "")
+
+            # Process image if provided
+            image = self.process_image(image_input) if image_input else None
 
-            inputs = data.pop("inputs", data)
+            # Prepare inputs
+            inputs = self.processor(
+                images=image if image else None,
+                text=text_input,
+                return_tensors="pt"
+            )
+
+            # Move inputs to device
+            inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                      for k, v in inputs.items()}
+
+            # Generate output
+            with torch.no_grad():
+                outputs = self.model.generate(**inputs)
+
+            # Decode outputs
+            decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
+
+            return {"generated_text": decoded_outputs[0]}
+
+        except Exception as e:
+            return {"error": str(e)}
+
+# import subprocess
+# import sys
+# import torch
+# from transformers import AutoModelForCausalLM, AutoProcessor
+
+# def install(package):
+#     subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-warn-script-location", package])
+
+# class EndpointHandler:
+#     def __init__(self, path=""):
+
+#         required_packages = ['timm', 'einops', 'flash-attn']
+#         for package in required_packages:
+#             try:
+#                 install(package)
+#                 print(f"Successfully installed {package}")
+#             except Exception as e:
+#                 print(f"Failed to install {package}: {str(e)}")
+
+
+#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         print(f"Using device: {self.device}")
+
+
+#         self.model_name = "microsoft/Florence-2-base-ft"
+#         self.model = AutoModelForCausalLM.from_pretrained(
+#             self.model_name,
+#             trust_remote_code=True,
+#             revision='refs/pr/6'
+#         ).to(self.device)
+
+#         self.processor = AutoProcessor.from_pretrained(
+#             self.model_name,
+#             trust_remote_code=True,
+#             revision='refs/pr/6'
+#         )
+
+
+#         if torch.cuda.is_available():
+#             torch.cuda.empty_cache()
+
+#     def __call__(self, data):
+#         try:
+
+#             inputs = data.pop("inputs", data)
 
 
-            processed_inputs = self.processor(inputs, return_tensors="pt")
+#             processed_inputs = self.processor(inputs, return_tensors="pt")
 
 
-            processed_inputs = {k: v.to(self.device) for k, v in processed_inputs.items()}
+#             processed_inputs = {k: v.to(self.device) for k, v in processed_inputs.items()}
 
 
-            with torch.no_grad():
-                outputs = self.model.generate(**processed_inputs)
+#             with torch.no_grad():
+#                 outputs = self.model.generate(**processed_inputs)
 
 
-            decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
+#             decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
 
-            return {"outputs": decoded_outputs}
-        except Exception as e:
-            return {"error": str(e)}
+#             return {"outputs": decoded_outputs}
+#         except Exception as e:
+#             return {"error": str(e)}
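
Note on usage: after this commit, __call__ reads an "image" key (an http(s) URL, a data:image base64 string, or raw bytes) and a "text" key, and returns {"generated_text": ...} rather than the old {"outputs": [...]}. A minimal local smoke test is sketched below, assuming handler.py is importable from the working directory; the image URL and the "<CAPTION>" Florence-2 task prompt are placeholder assumptions, not part of the commit.

# Local smoke test for the updated EndpointHandler (illustrative sketch).
from handler import EndpointHandler

handler = EndpointHandler()
result = handler({
    "image": "https://example.com/sample.jpg",  # placeholder URL (assumption)
    "text": "<CAPTION>",  # assumed Florence-2 task prompt
})
print(result)  # {"generated_text": "..."} on success, {"error": "..."} on failure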
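
Against a deployed Inference Endpoint, the request body would carry the same two keys: a base64 data URI exercises the startswith('data:image') branch of process_image, while an http(s) URL exercises the URL branch. A hedged client sketch follows, in which the endpoint URL and token are placeholders.

# Illustrative HTTP client for the deployed endpoint (all identifiers below
# are placeholders, not values from this commit).
import base64
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder token

with open("sample.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    },
    json={
        "image": f"data:image/jpeg;base64,{encoded}",  # data:image branch
        "text": "<CAPTION>",  # assumed task prompt
    },
)
print(response.json())  # {"generated_text": "..."} or {"error": "..."}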