Florence-enphase2 / handler.py
arjunanand13's picture
Update handler.py
321843f verified
raw
history blame
5.17 kB
import subprocess
import sys
import torch
import base64
from io import BytesIO
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, AutoProcessor
def install(package):
    """Install ``package`` with pip, using the interpreter running this module.

    Args:
        package: Requirement specifier handed to ``pip install`` unchanged.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # Going through ``sys.executable -m pip`` targets the environment of the
    # current interpreter rather than whichever ``pip`` happens to be on PATH.
    pip_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-warn-script-location",
        package,
    ]
    subprocess.check_call(pip_command)
class EndpointHandler:
    """Callable request handler that runs a Florence-2 checkpoint on an image.

    Construction installs a few runtime packages (best effort), then loads the
    model and processor named by ``self.model_name``. Calling the instance
    with a request payload returns either ``{"generated_text": str}`` or
    ``{"error": str}``.
    """

    def __init__(self, path=""):
        """Install runtime dependencies and load the model and processor.

        Args:
            path: Accepted for interface compatibility; this handler does not
                read it and always loads ``self.model_name``.
        """
        # Best-effort installs: a failure is reported but does not abort
        # start-up, so the model load below is still attempted.
        required_packages = ['timm', 'einops', 'flash-attn', 'Pillow']
        for package in required_packages:
            try:
                install(package)
                print(f"Successfully installed {package}")
            except Exception as e:
                print(f"Failed to install {package}: {str(e)}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.model_name = "microsoft/Florence-2-base-ft"
        # Model and processor are pinned to the same revision so their remote
        # code stays in sync.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            revision='refs/pr/6'
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            revision='refs/pr/6'
        )

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def process_image(self, image_path):
        """Open the image file at ``image_path`` and return it fully decoded.

        Args:
            image_path: Filesystem path of the image to read.

        Returns:
            The loaded PIL image, or ``None`` if the file could not be opened
            or decoded (the error is printed, not raised).
        """
        try:
            with open(image_path, 'rb') as image_file:
                image = Image.open(image_file)
                # Image.open only reads the header; the pixel data must be
                # decoded here, while the file is still open, otherwise any
                # later use of the image hits a closed file.
                image.load()
            return image
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            return None

    def __call__(self, data):
        """Run the model on one request.

        Args:
            data: Request payload. Either ``{"inputs": {"image": path,
                "text": prompt}}``, ``{"inputs": path}``, or a dict carrying
                ``"image"``/``"text"`` at the top level. The payload is not
                modified.

        Returns:
            ``{"generated_text": str}`` on success, or ``{"error": str}`` if
            the image is missing/unreadable or any step raises.
        """
        try:
            # Read (rather than pop) so the caller's payload is left intact.
            inputs = data.get("inputs", data)

            if isinstance(inputs, dict):
                image_path = inputs.get("image", None)
                text_input = inputs.get("text", "")
            else:
                # A non-dict "inputs" value is taken to be the image path.
                image_path = inputs
                text_input = "What is in this image?"

            # Fail early with a clear message instead of handing
            # ``images=None`` to the processor.
            if not image_path:
                return {"error": "No image provided in the request"}
            image = self.process_image(image_path)
            if image is None:
                return {"error": "Could not load image from the provided path"}

            model_inputs = self.processor(
                images=image,
                text=text_input,
                return_tensors="pt"
            )

            # Only tensors can be moved to the device; other values pass
            # through unchanged.
            model_inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in model_inputs.items()
            }

            # NOTE(review): generate() runs with the checkpoint's default
            # generation settings; if outputs look truncated, an explicit
            # max_new_tokens may be needed - confirm against the model card.
            with torch.no_grad():
                outputs = self.model.generate(**model_inputs)

            decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
            return {"generated_text": decoded_outputs[0]}
        except Exception as e:
            # Top-level boundary: report the failure to the client as data.
            return {"error": str(e)}
# import subprocess
# import sys
# import torch
# from transformers import AutoModelForCausalLM, AutoProcessor
# def install(package):
# subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-warn-script-location", package])
# class EndpointHandler:
# def __init__(self, path=""):
# required_packages = ['timm', 'einops', 'flash-attn']
# for package in required_packages:
# try:
# install(package)
# print(f"Successfully installed {package}")
# except Exception as e:
# print(f"Failed to install {package}: {str(e)}")
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {self.device}")
# self.model_name = "microsoft/Florence-2-base-ft"
# self.model = AutoModelForCausalLM.from_pretrained(
# self.model_name,
# trust_remote_code=True,
# revision='refs/pr/6'
# ).to(self.device)
# self.processor = AutoProcessor.from_pretrained(
# self.model_name,
# trust_remote_code=True,
# revision='refs/pr/6'
# )
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
# def __call__(self, data):
# try:
# inputs = data.pop("inputs", data)
# processed_inputs = self.processor(inputs, return_tensors="pt")
# processed_inputs = {k: v.to(self.device) for k, v in processed_inputs.items()}
# with torch.no_grad():
# outputs = self.model.generate(**processed_inputs)
# decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
# return {"outputs": decoded_outputs}
# except Exception as e:
# return {"error": str(e)}