|
|
| import os |
| import logging |
| import torch |
| from typing import Optional, List, Dict, Any, Union |
| from fastapi import FastAPI, HTTPException, Form |
| from pydantic import BaseModel, Field |
| from dotenv import load_dotenv |
| from enum import Enum |
|
|
| from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage |
| from openai.types.embedding import Embedding |
|
|
| import os |
| import torch |
| import torch.nn.functional as F |
| import unicodedata |
| import numpy as np |
| import logging |
|
|
| from PIL import Image |
| from urllib.parse import urlparse |
| from dataclasses import dataclass |
| from typing import Optional, List, Union, Dict, Any |
| from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLPreTrainedModel, Qwen3VLModel, Qwen3VLConfig |
| from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs |
| from transformers.cache_utils import Cache |
| from transformers.utils.generic import check_model_inputs |
| from qwen_vl_utils.vision_process import process_vision_info |
|
|
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# --- Defaults for text and vision preprocessing ------------------------------

# Default token budget; used as the default `max_length` of Qwen3VLEmbedder.
MAX_LENGTH = 2048

# Pixel budgets are expressed in units of IMAGE_FACTOR x IMAGE_FACTOR tiles.
# NOTE(review): 16 matches the `image_patch_size=16` passed to
# process_vision_info elsewhere in this file -- presumably the vision patch
# size; confirm against the model configuration.
IMAGE_BASE_FACTOR = 16
IMAGE_FACTOR = 2 * IMAGE_BASE_FACTOR

# Defaults for the `min_pixels` / `max_pixels` arguments of Qwen3VLEmbedder.
MIN_PIXELS = 4 * IMAGE_FACTOR ** 2
MAX_PIXELS = 1800 * IMAGE_FACTOR ** 2

# Defaults for the `fps` / `max_frames` arguments of Qwen3VLEmbedder.
FPS = 1
MAX_FRAMES = 64

# Per-frame budget, and the total budget (ten frames' worth) that is the
# default `total_pixels` of Qwen3VLEmbedder.
FRAME_MAX_PIXELS = 768 * IMAGE_FACTOR ** 2
MAX_TOTAL_PIXELS = 10 * FRAME_MAX_PIXELS

# Not referenced by any other code visible in this file.
PAD_TOKEN = "<|endoftext|>"
|
|
| |
@dataclass
class Qwen3VLForEmbeddingOutput(ModelOutput):
    """Return type of ``Qwen3VLForEmbedding.forward``.

    Both fields default to ``None``.
    """

    # `last_hidden_state` of the wrapped Qwen3VLModel's output.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # The attention mask exactly as it was passed to `forward`.
    attention_mask: Optional[torch.Tensor] = None
|
|
| |
class Qwen3VLForEmbedding(Qwen3VLPreTrainedModel):
    """Qwen3-VL wrapper whose forward pass yields hidden states for embedding.

    Holds a ``Qwen3VLModel`` as ``self.model`` and delegates the embedding,
    decoder and vision-feature accessors to it.  No output head is defined
    here: ``forward`` returns the wrapped model's ``last_hidden_state``
    together with the attention mask that was passed in.
    """

    _checkpoint_conversion_mapping = {}
    accepts_loss_kwargs = False
    config: Qwen3VLConfig

    def __init__(self, config):
        """Create the wrapped ``Qwen3VLModel`` from ``config`` and call ``post_init``."""
        super().__init__(config)
        self.model = Qwen3VLModel(config)
        self.post_init()

    # ------------------------------------------------------------------
    # Sub-module shortcuts exposed by the wrapped model.
    # ------------------------------------------------------------------
    @property
    def language_model(self):
        """The wrapped model's ``language_model`` attribute."""
        return self.model.language_model

    @property
    def visual(self):
        """The wrapped model's ``visual`` attribute."""
        return self.model.visual

    # ------------------------------------------------------------------
    # Accessors forwarded verbatim to the wrapped model.
    # ------------------------------------------------------------------
    def get_input_embeddings(self):
        """Return whatever ``self.model.get_input_embeddings()`` returns."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Forward ``value`` to ``self.model.set_input_embeddings``."""
        self.model.set_input_embeddings(value)

    def get_decoder(self):
        """Return whatever ``self.model.get_decoder()`` returns."""
        return self.model.get_decoder()

    def set_decoder(self, decoder):
        """Forward ``decoder`` to ``self.model.set_decoder``."""
        self.model.set_decoder(decoder)

    def get_image_features(self, pixel_values: torch.FloatTensor,
                           image_grid_thw: Optional[torch.LongTensor] = None):
        """Delegate to ``self.model.get_image_features`` (arguments passed positionally)."""
        return self.model.get_image_features(pixel_values, image_grid_thw)

    def get_video_features(self, pixel_values_videos: torch.FloatTensor,
                           video_grid_thw: Optional[torch.LongTensor] = None):
        """Delegate to ``self.model.get_video_features`` (arguments passed positionally)."""
        return self.model.get_video_features(pixel_values_videos, video_grid_thw)

    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[Cache] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                pixel_values: Optional[torch.Tensor] = None,
                pixel_values_videos: Optional[torch.FloatTensor] = None,
                image_grid_thw: Optional[torch.LongTensor] = None,
                video_grid_thw: Optional[torch.LongTensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **kwargs: Unpack[TransformersKwargs],
                ) -> Union[tuple, Qwen3VLForEmbeddingOutput]:
        """Run the wrapped model and package its last hidden state.

        Every argument except ``logits_to_keep`` is forwarded by keyword to
        ``self.model``; ``logits_to_keep`` is accepted but not used in this
        method.  Extra ``**kwargs`` are passed through unchanged.

        Returns:
            A ``Qwen3VLForEmbeddingOutput`` holding the wrapped model's
            ``last_hidden_state`` and the ``attention_mask`` argument as
            received (``None`` if the caller did not pass one).
        """
        backbone_outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
            **kwargs,
        )
        hidden_states = backbone_outputs.last_hidden_state

        return Qwen3VLForEmbeddingOutput(
            last_hidden_state=hidden_states,
            attention_mask=attention_mask,
        )
|
|
|
|
|
|
|
|
|
|
| |
class Qwen3VLEmbedder:
    """Inference wrapper that turns chat-style conversations into embeddings.

    Loads a ``Qwen3VLForEmbedding`` model and a ``Qwen3VLProcessor`` from the
    same path, and embeds inputs by taking the hidden state at each
    sequence's last attended position.
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = MAX_LENGTH,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        total_pixels: int = MAX_TOTAL_PIXELS,
        fps: float = FPS,
        max_frames: int = MAX_FRAMES,
        default_instruction: str = "Represent the user's input.",
        **kwargs
    ):
        """Load the model and processor and put the model in eval mode.

        Args:
            model_name_or_path: Passed to both ``from_pretrained`` calls.
            max_length: Truncation length handed to the processor in
                ``_preprocess_inputs``.
            min_pixels, max_pixels, total_pixels, fps, max_frames: Stored on
                the instance; no method of this class reads them.
            default_instruction: Stored on the instance; no method of this
                class reads it.
            **kwargs: Forwarded to ``Qwen3VLForEmbedding.from_pretrained``.
        """
        # CUDA is used whenever it is available, otherwise CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.max_length = max_length
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.total_pixels = total_pixels
        self.fps = fps
        self.max_frames = max_frames

        self.default_instruction = default_instruction

        self.model = Qwen3VLForEmbedding.from_pretrained(
            model_name_or_path, trust_remote_code=True, **kwargs
        ).to(device)
        self.processor = Qwen3VLProcessor.from_pretrained(
            model_name_or_path, padding_side='right'
        )
        self.model.eval()

    @torch.no_grad()
    def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Run the model on processor output with gradients disabled.

        Args:
            inputs: Keyword arguments for the model (expanded with ``**``).

        Returns:
            A dict with ``'last_hidden_state'`` from the model output and
            ``'attention_mask'`` taken from ``inputs`` (``None`` if absent).
        """
        outputs = self.model(**inputs)
        return {
            'last_hidden_state': outputs.last_hidden_state,
            'attention_mask': inputs.get('attention_mask')
        }

    def _preprocess_inputs(self, conversations: List[List[Dict]]) -> Dict[str, torch.Tensor]:
        """Convert conversations into padded, truncated processor output.

        The conversations are rendered with the chat template (untokenized,
        generation prompt appended) and their vision inputs are extracted with
        ``process_vision_info``.  If that extraction raises, the failure is
        logged and the input is replaced by a single text-only ``'NULL'``
        prompt with no images or videos, so the call still returns.

        Args:
            conversations: Chat messages in the format accepted by the
                processor's chat template and ``process_vision_info``.

        Returns:
            The processor's output (``return_tensors='pt'``).
        """
        text = self.processor.apply_chat_template(
            conversations, add_generation_prompt=True, tokenize=False
        )

        try:
            images, video_inputs, video_kwargs = process_vision_info(
                conversations, image_patch_size=16,
                return_video_metadata=True, return_video_kwargs=True
            )
        except Exception as e:
            # Best-effort fallback kept from the original behaviour; the
            # traceback is logged so the root cause is not lost.
            logger.exception("Error in processing vision info: %s", e)
            images = None
            video_inputs = None
            video_kwargs = {'do_sample_frames': False}
            text = self.processor.apply_chat_template(
                [{'role': 'user', 'content': [{'type': 'text', 'text': 'NULL'}]}],
                add_generation_prompt=True, tokenize=False
            )

        # Each entry of video_inputs is unpacked as a (video, metadata) pair.
        # An empty list is treated like "no videos": unpacking zip(*[]) into
        # two names would raise ValueError.
        if video_inputs:
            videos, video_metadata = zip(*video_inputs)
            videos = list(videos)
            video_metadata = list(video_metadata)
        else:
            videos, video_metadata = None, None

        inputs = self.processor(
            text=text, images=images, videos=videos, video_metadata=video_metadata, truncation=True,
            max_length=self.max_length, padding=True, do_resize=False, return_tensors='pt',
            **video_kwargs
        )
        return inputs

    @staticmethod
    def _pooling_last(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select, per row, the hidden state at the last position whose mask is set.

        Args:
            hidden_state: Indexed as ``hidden_state[row, col]``; presumably
                shaped (batch, seq, hidden) -- confirm against the model.
            attention_mask: Mask whose dim 1 is the sequence axis; presumably
                0/1 valued and shaped (batch, seq).

        Returns:
            One vector per row of ``hidden_state``.
        """
        # argmax over the reversed mask gives the distance from the end of
        # the sequence to the last set position; this works regardless of
        # which side the padding is on.
        flipped_tensor = attention_mask.flip(dims=[1])
        last_one_positions = flipped_tensor.argmax(dim=1)
        col = attention_mask.shape[1] - last_one_positions - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]

    def process(self, inputs: List[List[Dict]], normalize: bool = True) -> torch.Tensor:
        """Embed conversations end to end.

        Preprocesses ``inputs``, moves every tensor to the model's device,
        runs the model, pools the last attended position and optionally
        L2-normalizes the result.

        Args:
            inputs: Conversations, as accepted by ``_preprocess_inputs``.
            normalize: When true, L2-normalize along the last dimension.

        Returns:
            The pooled embeddings as a tensor on the model's device, one row
            per processed sequence.
        """
        processed_inputs = self._preprocess_inputs(inputs)
        processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()}

        outputs = self.forward(processed_inputs)
        embeddings = self._pooling_last(outputs['last_hidden_state'], outputs['attention_mask'])

        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)

        return embeddings
| |
|
|
|
|
class EmbeddingRequest(BaseModel):
    # Request body of the POST /v1/embeddings endpoint.
    # (Documented with comments rather than a class docstring so the
    # generated schema description stays unchanged.)

    # Required: list of message dicts; the endpoint hands this list to the
    # embedder as the conversation.
    messages: List[Dict[str, Any]] = Field(default=..., description="输入文本或文本列表")

    # Echoed back in the `model` field of the response.
    model: str = Field("Qwen3VL", description="模型名称")

    # The three fields below are accepted, but the endpoint visible in this
    # file does not read them.
    encoding_format: str = Field("float", description="输出格式")
    continue_final_message: Optional[bool] = Field(False, description="是否继续生成最终消息")
    add_special_tokens: Optional[bool] = Field(False, description="是否添加特殊标记")
|
|
|
|
|
|
|
|
# Load variables from a .env file before any os.getenv lookups below.
load_dotenv()

# Root logging configuration: INFO level, "time - level - message" lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# FastAPI application that the embedding route is registered on.
app = FastAPI(
    version="1.0.0",
    title="Qwen3VL Embedding API",
    description="API for Qwen3VL Embedding model",
)
|
|
| |
# Load the embedding model once at import time.  A failure is logged and
# leaves `embedding_model` as None instead of aborting start-up; the
# endpoint checks for that and answers with an error.
logger.info("Loading Qwen3VL Embedding model...")

# Model location comes from the environment; defaults to an empty string.
model_path = os.getenv("EMBEDDING_MODEL_PATH", "")

try:
    embedding_model = Qwen3VLEmbedder(model_name_or_path=model_path, max_length=2048)
except Exception as e:
    embedding_model = None
    logger.error(f"Failed to load model: {e}")
else:
    logger.info("Qwen3VL Embedding model loaded successfully!")
|
|
|
|
|
|
def wrap_embedding_list(embedding_list: List[float], index: int = 0) -> Embedding:
    """Wrap a single embedding vector in an ``Embedding`` object.

    Args:
        embedding_list: The embedding vector as a list of floats.
        index: Position of this vector within the list of embeddings.

    Returns:
        An ``Embedding`` instance whose ``object`` field is ``"embedding"``.
    """
    wrapped = Embedding(
        object="embedding",
        index=index,
        embedding=embedding_list,
    )
    return wrapped
|
|
|
|
def wrap_multiple_embedding_lists(embedding_lists: List[List[float]]) -> List[Embedding]:
    """Wrap several embedding vectors in ``Embedding`` objects.

    Args:
        embedding_lists: A list of embedding vectors.

    Returns:
        One ``Embedding`` per input vector, in input order, each carrying
        its position in ``embedding_lists`` as ``index``.
    """
    return [
        wrap_embedding_list(embedding_list=vector, index=position)
        for position, vector in enumerate(embedding_lists)
    ]
|
|
|
|
|
|
|
|
@app.post("/v1/embeddings", response_model=CreateEmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """OpenAI-compatible embeddings endpoint.

    Embeds ``request.messages`` as one conversation with the module-level
    ``embedding_model`` (L2-normalized) and returns the result in the
    ``CreateEmbeddingResponse`` shape.

    Raises:
        HTTPException: 500 when the model failed to load at start-up, or
            when embedding raises; ``detail`` carries the error message.
    """
    # Checked outside the try block so this HTTPException reaches the client
    # as raised, instead of being caught and re-wrapped by the generic
    # handler below (which replaced its detail with str(HTTPException)).
    if embedding_model is None:
        logger.error("Error during embedding: embedding model is not loaded")
        raise HTTPException(status_code=500, detail="模型未正确加载")

    try:
        # NOTE(review): this call is synchronous inside an `async def`, so it
        # occupies the event loop while the model runs -- confirm that this
        # is acceptable for the expected load.
        embedding_result = embedding_model.process(request.messages, normalize=True)
        embedding_list = embedding_result.cpu().tolist()
        embedding_objects = wrap_multiple_embedding_lists(embedding_list)

        # NOTE(review): both usage fields report the number of messages, not
        # a token count; the embedder does not expose token counts here.
        message_count = len(request.messages)

        return CreateEmbeddingResponse(
            data=embedding_objects,
            model=request.model,
            object="list",
            usage=Usage(
                prompt_tokens=message_count,
                total_tokens=message_count,
            ),
        )
    except HTTPException:
        # Never downgrade or re-wrap an HTTP error raised on purpose.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Error during embedding: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    # Port comes from EMBEDDING_API_PORT, falling back to 8006.
    listen_port = int(os.environ.get("EMBEDDING_API_PORT", 8006))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")