|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import base64 |
|
|
import io |
|
|
import logging |
|
|
from PIL import Image |
|
|
from pydantic import BaseModel |
|
|
from fastapi import Request, HTTPException |
|
|
import json |
|
|
from typing import Optional, Union, Dict, Any |
|
|
from transformers import AutoProcessor, AutoModel |
|
|
|
|
|
|
|
|
class EmbeddingRequest(BaseModel):
    """Request body accepted by the embedding endpoints.

    Attributes:
        inputs: Payload to embed. The text service embeds it as text; the
            image service decodes it as a base64 image string (a
            ``data:image/...`` URI prefix is accepted).
        parameters: Optional free-form options. Not read by the services
            visible in this module; defaults to ``None`` when omitted.
    """

    # Raw text, or a base64-encoded image for the image service.
    inputs: str
    # Optional extra options; None when the client sends none.
    parameters: Union[dict, None] = None
|
|
|
|
|
|
|
|
class BaseEmbeddingTaskService:
    """Base class for embedding services with common functionality.

    Keeps per-instance caches of processors, models and configs loaded via
    the ``transformers`` Auto* classes, parses incoming request bodies into
    ``EmbeddingRequest``, selects the torch device, and extracts an embedding
    tensor from the various output shapes a model can return.
    """

    # Config attributes probed, in priority order, to report the embedding
    # width. Dotted entries walk into nested sub-configs.
    _VECTOR_SIZE_ATTRIBUTES = (
        "hidden_size",
        "projection_dim",
        "d_model",
        "text_config.hidden_size",
        "vision_config.hidden_size",
    )

    def __init__(self, logger: logging.Logger):
        """Create the service with empty caches.

        Args:
            logger: Logger used for all diagnostic output of the service.
        """
        self._logger = logger
        # Keyed by model name (plus an optional suffix, see _load_model).
        self._model_cache: Dict[str, Any] = {}
        # Keyed by model name.
        self._processor_cache: Dict[str, Any] = {}
        # Keyed by model name; configs only, no weights.
        self._config_cache: Dict[str, Any] = {}
        # Chosen lazily on first use by _get_device.
        self._device: Optional[torch.device] = None

    async def get_embedding_request(self, request: Request) -> EmbeddingRequest:
        """Parse the request body into an EmbeddingRequest.

        Both ``application/json`` and ``application/x-www-form-urlencoded``
        requests are expected to carry a JSON object as their body.

        Raises:
            HTTPException: 400 if the content type is unsupported, or if the
                body is not a JSON object that validates as EmbeddingRequest.
        """
        # Media types are case-insensitive, so normalise before matching.
        content_type = request.headers.get("content-type", "").lower()
        if not content_type.startswith(
            ("application/json", "application/x-www-form-urlencoded")
        ):
            raise HTTPException(status_code=400, detail="Unsupported content type")

        raw = await request.body()
        try:
            # json.loads accepts bytes directly (UTF-8/16/32 are auto-detected).
            data = json.loads(raw)
            return EmbeddingRequest(**data)
        except (ValueError, TypeError) as e:
            # ValueError: malformed JSON, undecodable bytes, or a pydantic
            # validation error (a ValueError subclass).
            # TypeError: the JSON value is not an object, so ** unpacking fails.
            raise HTTPException(status_code=400, detail="Invalid request body") from e

    def _get_device(self) -> torch.device:
        """Return the CUDA device when available, otherwise the CPU.

        The choice is made once per service instance and logged once, rather
        than on every request.
        """
        if self._device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self._logger.info(f"Using device: {self._device}")
        return self._device

    def _load_processor(self, model_name: str):
        """Load and cache the processor for ``model_name`` via AutoProcessor.

        Raises:
            HTTPException: 404 if the processor cannot be loaded.
        """
        if model_name not in self._processor_cache:
            try:
                self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name)
                self._logger.info(f"Loaded processor for model: {model_name}")
            except Exception as e:
                self._logger.error(f"Failed to load processor for model '{model_name}': {str(e)}")
                raise HTTPException(
                    status_code=404,
                    detail=f"Processor for model '{model_name}' could not be loaded: {str(e)}",
                ) from e
        return self._processor_cache[model_name]

    def _load_model(self, model_name: str, cache_suffix: str = ""):
        """Load ``model_name`` via AutoModel, move it to the device, and cache it.

        Args:
            model_name: Model identifier passed to ``AutoModel.from_pretrained``.
            cache_suffix: Appended to the cache key so callers can keep
                separate cached instances of the same model.

        Raises:
            HTTPException: 404 if the model cannot be loaded.
        """
        cache_key = f"{model_name}{cache_suffix}"
        if cache_key not in self._model_cache:
            try:
                device = self._get_device()
                model = AutoModel.from_pretrained(model_name)
                model.to(device)
                # Inference only; from_pretrained normally already does this,
                # the explicit call documents the intent.
                model.eval()
                self._model_cache[cache_key] = model
                self._logger.info(f"Loaded model: {model_name} on {device}")
            except Exception as e:
                self._logger.error(f"Failed to load model '{model_name}': {str(e)}")
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{model_name}' could not be loaded: {str(e)}",
                ) from e
        return self._model_cache[cache_key]

    def _load_config(self, model_name: str):
        """Load and cache only the configuration of ``model_name`` (no weights).

        Raises:
            HTTPException: 404 if the configuration cannot be loaded.
        """
        if model_name not in self._config_cache:
            # Local import keeps the module's top-level imports unchanged.
            from transformers import AutoConfig

            try:
                self._config_cache[model_name] = AutoConfig.from_pretrained(model_name)
            except Exception as e:
                self._logger.error(f"Failed to get vector size for model '{model_name}': {str(e)}")
                raise HTTPException(
                    status_code=404,
                    detail=f"Could not determine vector size for model '{model_name}': {str(e)}",
                ) from e
        return self._config_cache[model_name]

    @staticmethod
    def _resolve_config_attribute(config, attribute_path: str) -> tuple:
        """Follow a dotted attribute path on ``config``.

        Returns:
            ``(True, value)`` when every segment exists, else ``(False, None)``.
        """
        value = config
        for part in attribute_path.split("."):
            if not hasattr(value, part):
                return False, None
            value = getattr(value, part)
        return True, value

    async def get_embedding_vector_size(self, model_name: str) -> dict:
        """Get the embedding vector size for ``model_name`` from its config.

        Only the configuration is loaded; the model weights are not. The
        first attribute in ``_VECTOR_SIZE_ATTRIBUTES`` present on the config
        determines the size.

        Returns:
            Dict with ``model_name``, ``vector_size`` and
            ``config_attribute_used`` (the attribute path that was read).

        Raises:
            HTTPException: 404 if the config cannot be loaded or exposes none
                of the probed attributes.
        """
        config = self._load_config(model_name)

        for attribute_path in self._VECTOR_SIZE_ATTRIBUTES:
            found, vector_size = self._resolve_config_attribute(config, attribute_path)
            if found:
                self._logger.info(f"Model {model_name} has embedding vector size: {vector_size}")
                return {
                    "model_name": model_name,
                    "vector_size": vector_size,
                    "config_attribute_used": attribute_path,
                }

        message = "Could not determine vector size from model configuration"
        self._logger.error(f"Failed to get vector size for model '{model_name}': {message}")
        raise HTTPException(
            status_code=404,
            detail=f"Could not determine vector size for model '{model_name}': {message}",
        )

    def _extract_embeddings(self, model_output, model_name: str) -> torch.Tensor:
        """Extract an embedding tensor from ``model_output``.

        Tried in order: ``pooler_output``; mean of ``last_hidden_state`` over
        dim 1; ``image_embeds``; ``text_embeds``; the output itself if it is
        a tensor; the first element if it is a non-empty tuple.

        Raises:
            HTTPException: 500 if none of the strategies applies.
        """
        if hasattr(model_output, 'pooler_output') and model_output.pooler_output is not None:
            self._logger.debug(f"Using pooler_output for {model_name}")
            return model_output.pooler_output

        if hasattr(model_output, 'last_hidden_state') and model_output.last_hidden_state is not None:
            self._logger.debug(f"Using pooled last_hidden_state for {model_name}")
            # NOTE: unmasked mean over dim 1 - padding positions, if the batch
            # was padded, are averaged in as well.
            return model_output.last_hidden_state.mean(dim=1)

        if hasattr(model_output, 'image_embeds') and model_output.image_embeds is not None:
            self._logger.debug(f"Using image_embeds for {model_name}")
            return model_output.image_embeds

        if hasattr(model_output, 'text_embeds') and model_output.text_embeds is not None:
            self._logger.debug(f"Using text_embeds for {model_name}")
            return model_output.text_embeds

        if isinstance(model_output, torch.Tensor):
            self._logger.debug(f"Using direct tensor output for {model_name}")
            return model_output

        if isinstance(model_output, tuple) and len(model_output) > 0:
            self._logger.debug(f"Using first element of tuple output for {model_name}")
            return model_output[0]

        # Report only meaningful names: keys for dict-like outputs, public
        # attributes otherwise (plain dir() is dominated by dunder methods).
        if isinstance(model_output, dict):
            available = list(model_output.keys())
        else:
            available = [name for name in dir(model_output) if not name.startswith('_')]
        raise HTTPException(
            status_code=500,
            detail=f"Could not extract embeddings from model output for {model_name}. "
                   f"Output type: {type(model_output).__name__}; available attributes: {available}",
        )
|
|
|
|
|
|
|
|
class ImageEmbeddingTaskService(BaseEmbeddingTaskService):
    """Service for generating image embeddings.

    Accepts an image either as a base64 string in the request body
    (``generate_embedding``) or as an uploaded file
    (``generate_embedding_from_upload``), converts it to RGB, and returns the
    embedding of that single image as a list of floats.
    """

    def _decode_image_bytes(self, image_data: bytes) -> Image.Image:
        """Open raw image bytes as a fully decoded RGB PIL image.

        Raises:
            HTTPException: 400 if PIL cannot identify or decode the data.
        """
        try:
            image = Image.open(io.BytesIO(image_data))
            # Image.open is lazy; load() forces the full decode now so that
            # corrupt or truncated files are reported as client errors (400)
            # instead of failing later during inference.
            image.load()
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") from e

    def _decode_base64_image(self, base64_string: str) -> Image.Image:
        """Decode a base64 string (optionally a ``data:image/...`` URI) to an RGB image.

        Raises:
            HTTPException: 400 if the string is not valid base64 or not an image.
        """
        try:
            # Strip a data-URI header such as "data:image/png;base64,".
            if base64_string.startswith('data:image'):
                base64_string = base64_string.split(',')[1]
            image_data = base64.b64decode(base64_string)
        except Exception as e:
            # Any failure here means a malformed client payload.
            raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") from e
        return self._decode_image_bytes(image_data)

    def _generate_image_embeddings(self, image: Image.Image, model, processor, model_name: str) -> list:
        """Run ``image`` through ``model`` and return its embedding as a list of floats.

        Uses ``model.get_image_features`` when present; otherwise
        ``model.vision_model`` or the full model, pooled by
        ``_extract_embeddings``.
        """
        device = self._get_device()

        inputs = processor(images=image, return_tensors="pt", padding=True)
        # Move tensors to the device the model was loaded on; leave any
        # non-tensor entries untouched.
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

        with torch.no_grad():
            if hasattr(model, 'get_image_features'):
                self._logger.debug(f"Using get_image_features for {model_name}")
                embeddings = model.get_image_features(pixel_values=inputs.get('pixel_values'))
            elif hasattr(model, 'vision_model'):
                self._logger.debug(f"Using vision_model for {model_name}")
                vision_outputs = model.vision_model(**inputs)
                embeddings = self._extract_embeddings(vision_outputs, model_name)
            else:
                self._logger.debug(f"Using full model for {model_name}")
                outputs = model(**inputs)
                embeddings = self._extract_embeddings(outputs, model_name)

        self._logger.info(f"Image embedding shape: {embeddings.shape}")

        # Batch of one image: return row 0. Tensor.tolist() (unlike
        # .numpy()) also handles bfloat16 outputs.
        return embeddings[0].cpu().tolist()

    async def generate_embedding(self, request: Request, model_name: str):
        """Generate the embedding of a base64 image carried in the request body.

        Returns:
            ``{"embeddings": [float, ...]}``.

        Raises:
            HTTPException: 400 for a bad body or image, 404 if the model or
                processor cannot be loaded, 500 if inference fails.
        """
        embedding_request: EmbeddingRequest = await self.get_embedding_request(request)
        self._logger.info(f"Generating image embedding for model: {model_name}")

        # Decode first so a malformed image is rejected before paying for a
        # (potentially expensive) model load.
        image = self._decode_base64_image(embedding_request.inputs)
        processor = self._load_processor(model_name)
        model = self._load_model(model_name, "_image")

        try:
            embeddings = self._generate_image_embeddings(image, model, processor, model_name)
        except HTTPException:
            # Already carries its own status and detail; don't re-wrap it.
            raise
        except Exception as e:
            self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Embedding generation failed: {str(e)}",
            ) from e

        self._logger.info("Image embedding generation completed")
        return {"embeddings": embeddings}

    async def generate_embedding_from_upload(self, uploaded_file, model_name: str):
        """Generate the embedding of an uploaded image file.

        Args:
            uploaded_file: Upload object exposing ``content_type`` and an
                awaitable ``read()``.
            model_name: Model identifier to embed with.

        Returns:
            ``{"embeddings": [float, ...]}``.

        Raises:
            HTTPException: 400 for a non-image content type or undecodable
                image, 404 if the model or processor cannot be loaded, 500 if
                reading the file or inference fails.
        """
        self._logger.info(f"Generating image embedding from uploaded file for model: {model_name}")

        # content_type may be None; treat that as "not an image" instead of
        # crashing on None.startswith.
        if not (uploaded_file.content_type or "").startswith('image/'):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type: {uploaded_file.content_type}. Only image files are supported.",
            )

        try:
            file_content = await uploaded_file.read()
            image = self._decode_image_bytes(file_content)
            processor = self._load_processor(model_name)
            model = self._load_model(model_name, "_image")
            embeddings = self._generate_image_embeddings(image, model, processor, model_name)
        except HTTPException:
            # 400 (bad image) / 404 (model load) / 500 (extraction) already
            # carry the right status; don't re-wrap them as 500.
            raise
        except Exception as e:
            self._logger.error(f"Embedding generation from upload failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Embedding generation from upload failed: {str(e)}",
            ) from e

        self._logger.info("Image embedding generation from upload completed")
        return {"embeddings": embeddings}
|
|
|
|
|
|
|
|
class TextEmbeddingTaskService(BaseEmbeddingTaskService):
    """Service for generating text embeddings.

    Tokenizes a single input string with the model's processor, runs it
    through the model, and returns its embedding as a list of floats.
    """

    def _generate_text_embeddings(self, text: str, model, processor, model_name: str) -> list:
        """Run ``text`` through ``model`` and return its embedding as a list of floats.

        The text is processed as a batch of one with padding and truncation
        enabled. Uses ``model.get_text_features`` when present; otherwise
        ``model.text_model`` or the full model, pooled by
        ``_extract_embeddings``.
        """
        device = self._get_device()

        inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
        # Move tensors to the device the model was loaded on; leave any
        # non-tensor entries untouched.
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

        with torch.no_grad():
            if hasattr(model, 'get_text_features'):
                self._logger.debug(f"Using get_text_features for {model_name}")
                embeddings = model.get_text_features(
                    input_ids=inputs.get('input_ids'),
                    attention_mask=inputs.get('attention_mask'),
                )
            elif hasattr(model, 'text_model'):
                self._logger.debug(f"Using text_model for {model_name}")
                text_outputs = model.text_model(**inputs)
                embeddings = self._extract_embeddings(text_outputs, model_name)
            else:
                self._logger.debug(f"Using full model for {model_name}")
                outputs = model(**inputs)
                embeddings = self._extract_embeddings(outputs, model_name)

        self._logger.info(f"Text embedding shape: {embeddings.shape}")

        # Batch of one text: return row 0. Tensor.tolist() (unlike .numpy())
        # also handles bfloat16 outputs.
        return embeddings[0].cpu().tolist()

    async def generate_embedding(self, request: Request, model_name: str):
        """Generate the embedding of the text carried in the request body.

        Returns:
            ``{"embeddings": [float, ...]}``.

        Raises:
            HTTPException: 400 for a bad body, 404 if the model or processor
                cannot be loaded, 500 if inference fails.
        """
        embedding_request: EmbeddingRequest = await self.get_embedding_request(request)

        # Log size rather than content: request text may be sensitive.
        self._logger.info(
            f"Generating text embedding for model: {model_name} "
            f"({len(embedding_request.inputs)} characters)"
        )

        processor = self._load_processor(model_name)
        model = self._load_model(model_name, "_text")

        try:
            embeddings = self._generate_text_embeddings(embedding_request.inputs, model, processor, model_name)
        except HTTPException:
            # Already carries its own status and detail; don't re-wrap it.
            raise
        except Exception as e:
            self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Embedding generation failed: {str(e)}",
            ) from e

        self._logger.info("Text embedding generation completed")
        return {"embeddings": embeddings}
|
|
|