# ------------------------------------------------------------------- # This source file is available under the terms of the # Pimcore Open Core License (POCL) # Full copyright and license information is available in # LICENSE.md which is distributed with this source code. # # @copyright Copyright (c) Pimcore GmbH (https://www.pimcore.com) # @license Pimcore Open Core License (POCL) # ------------------------------------------------------------------- import torch import base64 import io import logging from PIL import Image from pydantic import BaseModel from fastapi import Request, HTTPException import json from typing import Optional, Union, Dict, Any from transformers import AutoProcessor, AutoModel class EmbeddingRequest(BaseModel): inputs: str parameters: Optional[dict] = None class BaseEmbeddingTaskService: """Base class for embedding services with common functionality""" def __init__(self, logger: logging.Logger): self._logger = logger self._model_cache = {} self._processor_cache = {} async def get_embedding_request(self, request: Request) -> EmbeddingRequest: """Parse request body into EmbeddingRequest""" content_type = request.headers.get("content-type", "") if content_type.startswith("application/json"): data = await request.json() return EmbeddingRequest(**data) if content_type.startswith("application/x-www-form-urlencoded"): raw = await request.body() try: data = json.loads(raw) return EmbeddingRequest(**data) except Exception: try: data = json.loads(raw.decode("utf-8")) return EmbeddingRequest(**data) except Exception: raise HTTPException(status_code=400, detail="Invalid request body") raise HTTPException(status_code=400, detail="Unsupported content type") def _get_device(self) -> torch.device: """Get the appropriate device (GPU if available, otherwise CPU)""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._logger.info(f"Using device: {device}") return device def _load_processor(self, model_name: str): """Load and cache processor for the model using AutoProcessor""" if model_name not in self._processor_cache: try: self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) self._logger.info(f"Loaded processor for model: {model_name}") except Exception as e: self._logger.error(f"Failed to load processor for model '{model_name}': {str(e)}") raise HTTPException( status_code=404, detail=f"Processor for model '{model_name}' could not be loaded: {str(e)}" ) else: self._logger.info(f"Using cached processor for model: {model_name}") return self._processor_cache[model_name] def _load_model(self, model_name: str, cache_suffix: str = ""): """Load and cache model using AutoModel""" cache_key = f"{model_name}{cache_suffix}" if cache_key not in self._model_cache: try: device = self._get_device() model = AutoModel.from_pretrained(model_name, trust_remote_code=True) model.to(device) self._model_cache[cache_key] = model self._logger.info(f"Loaded model: {model_name} on {device}") except Exception as e: self._logger.error(f"Failed to load model '{model_name}': {str(e)}") raise HTTPException( status_code=404, detail=f"Model '{model_name}' could not be loaded: {str(e)}" ) else: self._logger.info(f"Using cached model: {model_name} (cache key: {cache_key})") return self._model_cache[cache_key] async def get_embedding_vector_size(self, model_name: str) -> dict: """Get the vector size of embeddings for a given model""" try: # Load the model to get its configuration model = self._load_model(model_name) # Try to get the embedding dimension from the model configuration used_attribute = None if hasattr(model.config, 'hidden_size'): vector_size = model.config.hidden_size used_attribute = "hidden_size" elif hasattr(model.config, 'projection_dim'): vector_size = model.config.projection_dim used_attribute = "projection_dim" elif hasattr(model.config, 'd_model'): vector_size = model.config.d_model used_attribute = "d_model" elif hasattr(model.config, 'text_config') and hasattr(model.config.text_config, 'hidden_size'): vector_size = model.config.text_config.hidden_size used_attribute = "text_config.hidden_size" elif hasattr(model.config, 'vision_config') and hasattr(model.config.vision_config, 'hidden_size'): vector_size = model.config.vision_config.hidden_size used_attribute = "vision_config.hidden_size" else: # If we can't determine from config, we'll need to run a dummy inference raise AttributeError("Could not determine vector size from model configuration") self._logger.info(f"Model {model_name} has embedding vector size: {vector_size}") return { "model_name": model_name, "vector_size": vector_size, "config_attribute_used": used_attribute } except Exception as e: self._logger.error(f"Failed to get vector size for model '{model_name}': {str(e)}") raise HTTPException( status_code=404, detail=f"Could not determine vector size for model '{model_name}': {str(e)}" ) def _extract_embeddings(self, model_output, model_name: str) -> torch.Tensor: """Extract embeddings from model output with fallback strategies""" # Try different embedding extraction methods in order of preference # 1. Check for pooler_output (most common) if hasattr(model_output, 'pooler_output') and model_output.pooler_output is not None: self._logger.debug(f"Using pooler_output for {model_name}") return model_output.pooler_output # 2. Check for last_hidden_state and pool it if hasattr(model_output, 'last_hidden_state') and model_output.last_hidden_state is not None: self._logger.debug(f"Using pooled last_hidden_state for {model_name}") # Mean pooling over sequence dimension return model_output.last_hidden_state.mean(dim=1) # 3. Check for image_embeds (CLIP-style models) if hasattr(model_output, 'image_embeds') and model_output.image_embeds is not None: self._logger.debug(f"Using image_embeds for {model_name}") return model_output.image_embeds # 4. Check for text_embeds (CLIP-style models) if hasattr(model_output, 'text_embeds') and model_output.text_embeds is not None: self._logger.debug(f"Using text_embeds for {model_name}") return model_output.text_embeds # 5. Fallback: try to use the output directly if it's a tensor if isinstance(model_output, torch.Tensor): self._logger.debug(f"Using direct tensor output for {model_name}") return model_output # 6. Last resort: check if output is a tuple and use the first element if isinstance(model_output, tuple) and len(model_output) > 0: self._logger.debug(f"Using first element of tuple output for {model_name}") return model_output[0] # If none of the above work, raise an error raise HTTPException( status_code=500, detail=f"Could not extract embeddings from model output for {model_name}. " f"Available attributes: {dir(model_output) if hasattr(model_output, '__dict__') else 'Unknown'}" ) class ImageEmbeddingTaskService(BaseEmbeddingTaskService): """Service for generating image embeddings""" def _decode_base64_image(self, base64_string: str) -> Image.Image: """Decode base64 string to PIL Image""" try: # Remove data URL prefix if present if base64_string.startswith('data:image'): base64_string = base64_string.split(',')[1] image_data = base64.b64decode(base64_string) image = Image.open(io.BytesIO(image_data)) # Convert to RGB if necessary if image.mode != 'RGB': image = image.convert('RGB') return image except Exception as e: raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") def _generate_image_embeddings(self, image: Image.Image, model, processor, model_name: str) -> list: """Generate embeddings for an image""" device = self._get_device() # Process the image inputs = processor(images=image, return_tensors="pt", padding=True) # Move inputs to the same device as the model inputs = {k: v.to(device) for k, v in inputs.items()} # Get the embeddings with torch.no_grad(): # Try using specialized methods first for CLIP-like models if hasattr(model, 'get_image_features'): self._logger.debug(f"Using get_image_features for {model_name}") embeddings = model.get_image_features(pixel_values=inputs.get('pixel_values')) elif hasattr(model, 'vision_model'): self._logger.debug(f"Using vision_model for {model_name}") vision_outputs = model.vision_model(**inputs) embeddings = self._extract_embeddings(vision_outputs, model_name) else: self._logger.debug(f"Using full model for {model_name}") outputs = model(**inputs) embeddings = self._extract_embeddings(outputs, model_name) self._logger.info(f"Image embedding shape: {embeddings.shape}") # Move back to CPU before converting to numpy embeddings_array = embeddings.cpu().numpy() return embeddings_array[0].tolist() async def generate_embedding(self, request: Request, model_name: str): """Main method to generate image embeddings""" embedding_request: EmbeddingRequest = await self.get_embedding_request(request) self._logger.info(f"Generating image embedding for model: {model_name}") # Load processor and model using auto-detection processor = self._load_processor(model_name) model = self._load_model(model_name, "_image") # Decode image from base64 image = self._decode_base64_image(embedding_request.inputs) try: # Generate embeddings embeddings = self._generate_image_embeddings(image, model, processor, model_name) self._logger.info("Image embedding generation completed") return {"embeddings": embeddings} except Exception as e: self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}") raise HTTPException( status_code=500, detail=f"Embedding generation failed: {str(e)}" ) async def generate_embedding_from_upload(self, uploaded_file, model_name: str): """Generate image embeddings from uploaded file""" from fastapi import UploadFile self._logger.info(f"Generating image embedding from uploaded file for model: {model_name}") # Validate file type if not uploaded_file.content_type.startswith('image/'): raise HTTPException( status_code=400, detail=f"Invalid file type: {uploaded_file.content_type}. Only image files are supported." ) try: # Read file content file_content = await uploaded_file.read() # Convert to PIL Image image = Image.open(io.BytesIO(file_content)).convert('RGB') # Load processor and model using auto-detection processor = self._load_processor(model_name) model = self._load_model(model_name, "_image") # Generate embeddings embeddings = self._generate_image_embeddings(image, model, processor, model_name) self._logger.info("Image embedding generation from upload completed") return {"embeddings": embeddings} except Exception as e: self._logger.error(f"Embedding generation from upload failed for model '{model_name}': {str(e)}") raise HTTPException( status_code=500, detail=f"Embedding generation from upload failed: {str(e)}" ) class TextEmbeddingTaskService(BaseEmbeddingTaskService): """Service for generating text embeddings""" def _generate_text_embeddings(self, text: str, model, processor, model_name: str) -> list: """Generate embeddings for text""" device = self._get_device() # Process the text inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True) # Move inputs to the same device as the model inputs = {k: v.to(device) for k, v in inputs.items()} # Get the embeddings with torch.no_grad(): # Try using specialized methods first for CLIP-like models if hasattr(model, 'get_text_features'): self._logger.debug(f"Using get_text_features for {model_name}") embeddings = model.get_text_features( input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask') ) elif hasattr(model, 'text_model'): self._logger.debug(f"Using text_model for {model_name}") text_outputs = model.text_model(**inputs) embeddings = self._extract_embeddings(text_outputs, model_name) else: self._logger.debug(f"Using full model for {model_name}") outputs = model(**inputs) embeddings = self._extract_embeddings(outputs, model_name) self._logger.info(f"Text embedding shape: {embeddings.shape}") # Move back to CPU before converting to numpy embeddings_array = embeddings.cpu().numpy() return embeddings_array[0].tolist() async def generate_embedding(self, request: Request, model_name: str): """Main method to generate text embeddings""" embedding_request: EmbeddingRequest = await self.get_embedding_request(request) self._logger.info(f"Generating text embedding for: {embedding_request.inputs[:500]}...") # Load processor and model using auto-detection processor = self._load_processor(model_name) model = self._load_model(model_name, "_text") try: # Generate embeddings embeddings = self._generate_text_embeddings(embedding_request.inputs, model, processor, model_name) self._logger.info("Text embedding generation completed") return {"embeddings": embeddings} except Exception as e: self._logger.error(f"Embedding generation failed for model '{model_name}': {str(e)}") raise HTTPException( status_code=500, detail=f"Embedding generation failed: {str(e)}" )