import subprocess
import sys
import torch
import base64
from io import BytesIO
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, AutoProcessor


def install(package):
    """Install a Python package into the running environment with pip."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-warn-script-location", package])


class EndpointHandler:
    def __init__(self, path=""):
        # Install runtime dependencies that the base image may not ship with.
        required_packages = ['timm', 'einops', 'flash-attn', 'Pillow']
        for package in required_packages:
            try:
                install(package)
                print(f"Successfully installed {package}")
            except Exception as e:
                print(f"Failed to install {package}: {str(e)}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load Florence-2 with its custom modeling code from the pinned PR revision.
        self.model_name = "microsoft/Florence-2-base-ft"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            revision='refs/pr/6'
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            revision='refs/pr/6'
        )

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def process_image(self, image_path):
        """Open an image from a local path, returning None on failure."""
        try:
            with open(image_path, 'rb') as image_file:
                image = Image.open(image_file)
                # Force the lazily-loaded PIL image to read its data before the
                # file handle is closed by the `with` block.
                image.load()
            return image
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            return None

    def __call__(self, data):
        try:
            # The request payload may wrap its content under "inputs" or arrive bare.
            inputs = data.pop("inputs", data)

            if isinstance(inputs, dict):
                image_path = inputs.get("image", None)
                text_input = inputs.get("text", "")
            else:
                # A bare value is treated as an image path with a default prompt.
                image_path = inputs
                text_input = "What is in this image?"

            image = self.process_image(image_path) if image_path else None

            model_inputs = self.processor(
                images=image if image else None,
                text=text_input,
                return_tensors="pt"
            )

            # Move all tensors to the same device as the model.
            model_inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                            for k, v in model_inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(**model_inputs)

            decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)

            return {"generated_text": decoded_outputs[0]}

        except Exception as e:
            return {"error": str(e)}