| """Dashscope (Alibaba Cloud) ModelClient integration.""" | |
| import os | |
| import pickle | |
| from typing import ( | |
| Dict, | |
| Optional, | |
| Any, | |
| Callable, | |
| Generator, | |
| Union, | |
| Literal, | |
| List, | |
| Sequence, | |
| ) | |
| import logging | |
| import backoff | |
| from copy import deepcopy | |
| from tqdm import tqdm | |
| # optional import | |
| from adalflow.utils.lazy_import import safe_import, OptionalPackages | |
| openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1]) | |
| from openai import OpenAI, AsyncOpenAI, Stream | |
| from openai import ( | |
| APITimeoutError, | |
| InternalServerError, | |
| RateLimitError, | |
| UnprocessableEntityError, | |
| BadRequestError, | |
| ) | |
| from openai.types import ( | |
| Completion, | |
| CreateEmbeddingResponse, | |
| ) | |
| from openai.types.chat import ChatCompletionChunk, ChatCompletion | |
| from adalflow.core.model_client import ModelClient | |
| from adalflow.core.types import ( | |
| ModelType, | |
| EmbedderOutput, | |
| CompletionUsage, | |
| GeneratorOutput, | |
| Document, | |
| Embedding, | |
| EmbedderOutputType, | |
| EmbedderInputType, | |
| ) | |
| from adalflow.core.component import DataComponent | |
| from adalflow.core.embedder import ( | |
| BatchEmbedderOutputType, | |
| BatchEmbedderInputType, | |
| ) | |
| import adalflow.core.functional as F | |
| from adalflow.components.model_client.utils import parse_embedding_response | |
| from api.logging_config import setup_logging | |
| # # Disable tqdm progress bars | |
| # os.environ["TQDM_DISABLE"] = "1" | |
| setup_logging() | |
| log = logging.getLogger(__name__) | |
def get_first_message_content(completion: ChatCompletion) -> str:
    """When we only need the content of the first message."""
    log.info(f"🔍 get_first_message_content called with: {type(completion)}")
    log.debug(f"raw completion: {completion}")
    try:
        if hasattr(completion, 'choices') and len(completion.choices) > 0:
            choice = completion.choices[0]
            if hasattr(choice, 'message') and hasattr(choice.message, 'content'):
                content = choice.message.content
                log.info(f"✅ Successfully extracted content: {type(content)}, length: {len(content) if content else 0}")
                return content
            else:
                log.error("❌ Choice doesn't have message.content")
                return str(completion)
        else:
            log.error("❌ Completion doesn't have choices")
            return str(completion)
    except Exception as e:
        log.error(f"❌ Error in get_first_message_content: {e}")
        return str(completion)

def parse_stream_response(completion: ChatCompletionChunk) -> str:
    """Parse the response of the stream API."""
    return completion.choices[0].delta.content


def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
    """Handle the streaming response."""
    for completion in generator:
        log.debug(f"Raw chunk completion: {completion}")
        parsed_content = parse_stream_response(completion)
        yield parsed_content

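
async def handle_async_streaming_response(generator):
    """Async counterpart of ``handle_streaming_response``.

    The AsyncOpenAI client returns an ``AsyncStream`` that can only be
    consumed with ``async for``; this helper mirrors the sync version so the
    async call path below can stream correctly.
    """
    async for completion in generator:
        log.debug(f"Raw chunk completion: {completion}")
        yield parse_stream_response(completion)


# A minimal usage sketch for the sync streaming path (assumptions:
# "qwen-plus" is only an example model name and DASHSCOPE_API_KEY is set):
#
#   client = DashscopeClient()
#   api_kwargs = client.convert_inputs_to_api_kwargs(
#       input="Hello",
#       model_kwargs={"model": "qwen-plus", "stream": True},
#       model_type=ModelType.LLM,
#   )
#   for text in client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM):
#       print(text, end="")
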
class DashscopeClient(ModelClient):
    """A component wrapper for the Dashscope (Alibaba Cloud) API client.

    Dashscope provides access to Alibaba Cloud's Qwen and other models through an OpenAI-compatible API.

    Args:
        api_key (Optional[str], optional): Dashscope API key. Defaults to None.
        workspace_id (Optional[str], optional): Dashscope workspace ID. Defaults to None.
        base_url (str): The API base URL. Defaults to "https://dashscope.aliyuncs.com/compatible-mode/v1".
        env_api_key_name (str): Environment variable name for the API key. Defaults to "DASHSCOPE_API_KEY".
        env_workspace_id_name (str): Environment variable name for the workspace ID. Defaults to "DASHSCOPE_WORKSPACE_ID".

    References:
        - Dashscope API Documentation: https://help.aliyun.com/zh/dashscope/
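
    Example:
        A minimal sketch; "qwen-plus" is an assumed example model name and
        DASHSCOPE_API_KEY must be set in the environment::

            client = DashscopeClient()
            api_kwargs = client.convert_inputs_to_api_kwargs(
                input="Hello",
                model_kwargs={"model": "qwen-plus"},
                model_type=ModelType.LLM,
            )
            output = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
            print(output.data)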
| """ | |
| def __init__( | |
| self, | |
| api_key: Optional[str] = None, | |
| workspace_id: Optional[str] = None, | |
| chat_completion_parser: Callable[[Completion], Any] = None, | |
| input_type: Literal["text", "messages"] = "text", | |
| base_url: Optional[str] = None, | |
| env_base_url_name: str = "DASHSCOPE_BASE_URL", | |
| env_api_key_name: str = "DASHSCOPE_API_KEY", | |
| env_workspace_id_name: str = "DASHSCOPE_WORKSPACE_ID", | |
| ): | |
| super().__init__() | |
| self._api_key = api_key | |
| self._workspace_id = workspace_id | |
| self._env_api_key_name = env_api_key_name | |
| self._env_workspace_id_name = env_workspace_id_name | |
| self._env_base_url_name = env_base_url_name | |
| self.base_url = base_url or os.getenv(self._env_base_url_name, "https://dashscope.aliyuncs.com/compatible-mode/v1") | |
| self.sync_client = self.init_sync_client() | |
| self.async_client = None | |
| # Force use of get_first_message_content to ensure string output | |
| self.chat_completion_parser = get_first_message_content | |
| self._input_type = input_type | |
| self._api_kwargs = {} | |
    def _prepare_client_config(self):
        """
        Private helper method to prepare client configuration.

        Returns:
            tuple: (api_key, workspace_id, base_url) for client initialization

        Raises:
            ValueError: If API key is not provided
        """
        api_key = self._api_key or os.getenv(self._env_api_key_name)
        workspace_id = self._workspace_id or os.getenv(self._env_workspace_id_name)

        if not api_key:
            raise ValueError(
                f"Environment variable {self._env_api_key_name} must be set"
            )
        if not workspace_id:
            log.warning(
                f"Environment variable {self._env_workspace_id_name} not set. "
                "Some features may not work properly."
            )

        # The workspace ID is not encoded into the URL; it is sent later via the
        # X-DashScope-WorkSpace header in convert_inputs_to_api_kwargs, so the
        # base URL is returned unchanged.
        base_url = self.base_url
        return api_key, workspace_id, base_url

    def init_sync_client(self):
        api_key, workspace_id, base_url = self._prepare_client_config()
        client = OpenAI(api_key=api_key, base_url=base_url)
        # Store workspace_id for later use in requests
        if workspace_id:
            client._workspace_id = workspace_id
        return client

    def init_async_client(self):
        api_key, workspace_id, base_url = self._prepare_client_config()
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        # Store workspace_id for later use in requests
        if workspace_id:
            client._workspace_id = workspace_id
        return client

    def parse_chat_completion(
        self,
        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
    ) -> "GeneratorOutput":
        """Parse the completion response to a GeneratorOutput."""
        try:
            # If the completion is already a GeneratorOutput, return it directly (prevent recursion)
            if isinstance(completion, GeneratorOutput):
                return completion

            # Check if it's a ChatCompletion object (non-streaming response)
            if hasattr(completion, 'choices') and hasattr(completion, 'usage'):
                # ALWAYS extract the string content directly
                try:
                    # Direct extraction of message content
                    if (len(completion.choices) > 0
                            and hasattr(completion.choices[0], 'message')
                            and hasattr(completion.choices[0].message, 'content')):
                        content = completion.choices[0].message.content
                        parsed_data = content if isinstance(content, str) else str(content)
                    else:
                        # Fallback: convert entire completion to string
                        parsed_data = str(completion)
                except Exception as e:
                    # Ultimate fallback
                    log.debug(f"Falling back to str(completion): {e}")
                    parsed_data = str(completion)

                return GeneratorOutput(
                    data=parsed_data,
                    usage=CompletionUsage(
                        completion_tokens=completion.usage.completion_tokens,
                        prompt_tokens=completion.usage.prompt_tokens,
                        total_tokens=completion.usage.total_tokens,
                    ),
                    raw_response=str(completion),
                )
            else:
                # Handle streaming response - collect all content parts into a single string
                content_parts = []
                usage_info = None
                for chunk in completion:
                    # Guard: a final usage-only chunk may carry no choices.
                    if chunk.choices and chunk.choices[0].delta.content:
                        content_parts.append(chunk.choices[0].delta.content)
                    # Try to get usage info from the last chunk
                    if hasattr(chunk, 'usage') and chunk.usage:
                        usage_info = chunk.usage

                # Join all content parts into a single string
                full_content = ''.join(content_parts)

                # Create usage object
                usage = None
                if usage_info:
                    usage = CompletionUsage(
                        completion_tokens=usage_info.completion_tokens,
                        prompt_tokens=usage_info.prompt_tokens,
                        total_tokens=usage_info.total_tokens,
                    )

                return GeneratorOutput(
                    data=full_content,
                    usage=usage,
                    raw_response="streaming",
                )
        except Exception as e:
            log.error(f"Error parsing completion: {e}")
            raise

    def track_completion_usage(
        self,
        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
    ) -> CompletionUsage:
        """Track the completion usage."""
        if isinstance(completion, ChatCompletion):
            return CompletionUsage(
                completion_tokens=completion.usage.completion_tokens,
                prompt_tokens=completion.usage.prompt_tokens,
                total_tokens=completion.usage.total_tokens,
            )
        else:
            # For streaming, we can't track usage accurately
            return CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

    def parse_embedding_response(
        self, response: CreateEmbeddingResponse
    ) -> EmbedderOutput:
        """Parse the embedding response to an EmbedderOutput."""
        # Add detailed debugging
        try:
            result = parse_embedding_response(response)
            if result.data:
                log.info(f"🔍 Number of embeddings: {len(result.data)}")
                if len(result.data) > 0:
                    log.info(f"🔍 First embedding length: {len(result.data[0].embedding) if hasattr(result.data[0], 'embedding') else 'N/A'}")
            else:
                log.warning("🔍 No embedding data found in result")
            return result
        except Exception as e:
            log.error(f"🔍 Error parsing DashScope embedding response: {e}")
            log.error(f"🔍 Raw response details: {repr(response)}")
            return EmbedderOutput(data=[], error=str(e), raw_response=response)

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Dict = {},
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        """Convert inputs to API kwargs."""
        final_model_kwargs = model_kwargs.copy()

        if model_type == ModelType.LLM:
            messages = []
            if isinstance(input, str):
                messages = [{"role": "user", "content": input}]
            elif isinstance(input, list):
                messages = input
            else:
                raise ValueError(f"Unsupported input type: {type(input)}")

            api_kwargs = {
                "messages": messages,
                **final_model_kwargs,
            }

            # Add workspace ID to headers if available
            workspace_id = getattr(self.sync_client, '_workspace_id', None) or getattr(self.async_client, '_workspace_id', None)
            if workspace_id:
                # Dashscope may require the workspace ID in headers
                api_kwargs.setdefault('extra_headers', {})
                api_kwargs['extra_headers']['X-DashScope-WorkSpace'] = workspace_id

            return api_kwargs
        elif model_type == ModelType.EMBEDDER:
            # Convert Documents to text strings for embedding
            processed_input = input
            if isinstance(input, list):
                # Extract text from Document objects
                processed_input = []
                for item in input:
                    if hasattr(item, 'text'):
                        # It's a Document object, extract text
                        processed_input.append(item.text)
                    elif isinstance(item, str):
                        # It's already a string
                        processed_input.append(item)
                    else:
                        # Try to convert to string
                        processed_input.append(str(item))
            elif hasattr(input, 'text'):
                # Single Document object
                processed_input = input.text
            elif isinstance(input, str):
                # Single string
                processed_input = input
            else:
                # Convert to string as fallback
                processed_input = str(input)

            api_kwargs = {
                "input": processed_input,
                **final_model_kwargs,
            }

            # Add workspace ID to headers if available
            workspace_id = getattr(self.sync_client, '_workspace_id', None) or getattr(self.async_client, '_workspace_id', None)
            if workspace_id:
                api_kwargs.setdefault('extra_headers', {})
                api_kwargs['extra_headers']['X-DashScope-WorkSpace'] = workspace_id

            return api_kwargs
        else:
            raise ValueError(f"model_type {model_type} is not supported")

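    # Shape sketch of the kwargs produced above for an LLM call (values are
    # illustrative; "qwen-plus" is an assumed model name):
    #
    #   {
    #       "messages": [{"role": "user", "content": "Hello"}],
    #       "model": "qwen-plus",
    #       "extra_headers": {"X-DashScope-WorkSpace": "<workspace-id>"},
    #   }
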
    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
        """Call the Dashscope API."""
        if model_type == ModelType.LLM:
            if not api_kwargs.get("stream", False):
                # For non-streaming, enable_thinking must be false.
                # Pass it via extra_body to avoid TypeError from openai client validation.
                extra_body = api_kwargs.get("extra_body", {})
                extra_body["enable_thinking"] = False
                api_kwargs["extra_body"] = extra_body

            completion = self.sync_client.chat.completions.create(**api_kwargs)
            if api_kwargs.get("stream", False):
                return handle_streaming_response(completion)
            else:
                return self.parse_chat_completion(completion)
        elif model_type == ModelType.EMBEDDER:
            # Extract input texts from api_kwargs
            texts = api_kwargs.get("input", [])
            if not texts:
                log.warning("😭 No input texts provided")
                return EmbedderOutput(data=[], error="No input texts provided", raw_response=None)

            # Ensure texts is a list
            if isinstance(texts, str):
                texts = [texts]

            # Filter out empty or None texts - following HuggingFace client pattern
            valid_texts = []
            valid_indices = []
            for i, text in enumerate(texts):
                if text and isinstance(text, str) and text.strip():
                    valid_texts.append(text)
                    valid_indices.append(i)
                else:
                    log.warning(f"🔍 Skipping empty or invalid text at index {i}: type={type(text)}, length={len(text) if hasattr(text, '__len__') else 'N/A'}, repr={repr(text)[:100]}")

            if not valid_texts:
                log.error("😭 No valid texts found after filtering")
                return EmbedderOutput(data=[], error="No valid texts found after filtering", raw_response=None)

            if len(valid_texts) != len(texts):
                filtered_count = len(texts) - len(valid_texts)
                log.warning(f"🔍 Filtered out {filtered_count} empty/invalid texts out of {len(texts)} total texts")

            # Create modified api_kwargs with only valid texts
            filtered_api_kwargs = api_kwargs.copy()
            filtered_api_kwargs["input"] = valid_texts

            log.info(f"🔍 DashScope embedding API call with {len(valid_texts)} valid texts out of {len(texts)} total")
            try:
                response = self.sync_client.embeddings.create(**filtered_api_kwargs)
                log.info(f"🔍 DashScope API call successful, response type: {type(response)}")
                result = self.parse_embedding_response(response)

                # If we filtered texts, we need to create embeddings for the original indices
                if len(valid_texts) != len(texts):
                    log.info(f"🔍 Creating embeddings for {len(texts)} original positions")

                    # Get the correct embedding dimension from the first valid embedding
                    embedding_dim = None  # Must be determined from a successful response
                    if result.data and len(result.data) > 0 and hasattr(result.data[0], 'embedding'):
                        embedding_dim = len(result.data[0].embedding)
                        log.info(f"🔍 Using embedding dimension: {embedding_dim}")

                    final_data = []
                    valid_idx = 0
                    for i in range(len(texts)):
                        if i in valid_indices:
                            # Use the embedding from valid texts
                            final_data.append(result.data[valid_idx])
                            valid_idx += 1
                        else:
                            # Create a zero embedding for filtered texts; fall back to an
                            # empty vector if the dimension could not be determined.
                            log.warning(f"🔍 Creating zero embedding for filtered text at index {i}")
                            final_data.append(Embedding(
                                embedding=([0.0] * embedding_dim) if embedding_dim else [],
                                index=i,
                            ))

                    result = EmbedderOutput(
                        data=final_data,
                        error=None,
                        raw_response=result.raw_response,
                    )

                return result
            except Exception as e:
                log.error(f"🔍 DashScope API call failed: {e}")
                return EmbedderOutput(data=[], error=str(e), raw_response=None)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    async def acall(
        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
    ):
        """Async call to the Dashscope API."""
        if not self.async_client:
            self.async_client = self.init_async_client()

        if model_type == ModelType.LLM:
            if not api_kwargs.get("stream", False):
                # For non-streaming, enable_thinking must be false.
                extra_body = api_kwargs.get("extra_body", {})
                extra_body["enable_thinking"] = False
                api_kwargs["extra_body"] = extra_body

            completion = await self.async_client.chat.completions.create(**api_kwargs)
            if api_kwargs.get("stream", False):
                # An AsyncStream can only be consumed with `async for`,
                # so use the async helper rather than handle_streaming_response.
                return handle_async_streaming_response(completion)
            else:
                return self.parse_chat_completion(completion)
        elif model_type == ModelType.EMBEDDER:
            # Extract input texts from api_kwargs
            texts = api_kwargs.get("input", [])
            if not texts:
                log.warning("😭 No input texts provided")
                return EmbedderOutput(data=[], error="No input texts provided", raw_response=None)

            # Ensure texts is a list
            if isinstance(texts, str):
                texts = [texts]

            # Filter out empty or None texts - following HuggingFace client pattern
            valid_texts = []
            valid_indices = []
            for i, text in enumerate(texts):
                if text and isinstance(text, str) and text.strip():
                    valid_texts.append(text)
                    valid_indices.append(i)
                else:
                    log.warning(f"🔍 Skipping empty or invalid text at index {i}: type={type(text)}, length={len(text) if hasattr(text, '__len__') else 'N/A'}, repr={repr(text)[:100]}")

            if not valid_texts:
                log.error("😭 No valid texts found after filtering")
                return EmbedderOutput(data=[], error="No valid texts found after filtering", raw_response=None)

            if len(valid_texts) != len(texts):
                filtered_count = len(texts) - len(valid_texts)
                log.warning(f"🔍 Filtered out {filtered_count} empty/invalid texts out of {len(texts)} total texts")

            # Create modified api_kwargs with only valid texts
            filtered_api_kwargs = api_kwargs.copy()
            filtered_api_kwargs["input"] = valid_texts

            log.info(f"🔍 DashScope async embedding API call with {len(valid_texts)} valid texts out of {len(texts)} total")
            try:
                response = await self.async_client.embeddings.create(**filtered_api_kwargs)
                log.info(f"🔍 DashScope async API call successful, response type: {type(response)}")
                result = self.parse_embedding_response(response)

                # If we filtered texts, we need to create embeddings for the original indices
                if len(valid_texts) != len(texts):
                    log.info(f"🔍 Creating embeddings for {len(texts)} original positions")

                    # Get the correct embedding dimension from the first valid embedding
                    embedding_dim = 256  # Default fallback based on config
                    if result.data and len(result.data) > 0 and hasattr(result.data[0], 'embedding'):
                        embedding_dim = len(result.data[0].embedding)
                        log.info(f"🔍 Using embedding dimension: {embedding_dim}")

                    final_data = []
                    valid_idx = 0
                    for i in range(len(texts)):
                        if i in valid_indices:
                            # Use the embedding from valid texts
                            final_data.append(result.data[valid_idx])
                            valid_idx += 1
                        else:
                            # Create zero embedding for filtered texts with correct dimension
                            log.warning(f"🔍 Creating zero embedding for filtered text at index {i}")
                            final_data.append(Embedding(
                                embedding=[0.0] * embedding_dim,
                                index=i,
                            ))

                    result = EmbedderOutput(
                        data=final_data,
                        error=None,
                        raw_response=result.raw_response,
                    )

                return result
            except Exception as e:
                log.error(f"🔍 DashScope async API call failed: {e}")
                return EmbedderOutput(data=[], error=str(e), raw_response=None)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """Create an instance from a dictionary."""
        return cls(**data)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            "api_key": self._api_key,
            "workspace_id": self._workspace_id,
            "base_url": self.base_url,
            "input_type": self._input_type,
        }

    def __getstate__(self):
        """
        Customize serialization to exclude non-picklable client objects.
        This method is called by pickle when saving the object's state.
        """
        state = self.__dict__.copy()
        # Remove the unpicklable client instances
        state.pop('sync_client', None)
        state.pop('async_client', None)
        return state

    def __setstate__(self, state):
        """
        Customize deserialization to re-create the client objects.
        This method is called by pickle when loading the object's state.
        """
        self.__dict__.update(state)
        # Re-initialize the clients after unpickling
        self.sync_client = self.init_sync_client()
        self.async_client = None  # It will be lazily initialized when acall is used
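
    # Persistence sketch: pickling drops the live clients (__getstate__) and
    # re-creates the sync client on load (__setstate__), so DASHSCOPE_API_KEY
    # must be set when unpickling. Hypothetical usage:
    #
    #   blob = pickle.dumps(client)
    #   restored = pickle.loads(blob)  # sync_client re-initialized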


class DashScopeEmbedder(DataComponent):
    r"""
    A user-facing component that orchestrates an embedder model via the DashScope model client and output processors.

    Args:
        model_client (ModelClient): The DashScope model client to use for the embedder.
        model_kwargs (Dict[str, Any], optional): The model kwargs to pass to the model client. Defaults to {}.
        output_processors (Optional[Component], optional): The output processors after model call. Defaults to None.
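
    Example:
        A minimal sketch; "text-embedding-v3" is an assumed example model
        name, not a value mandated by this module::

            embedder = DashScopeEmbedder(
                model_client=DashscopeClient(),
                model_kwargs={"model": "text-embedding-v3"},
            )
            output = embedder(input="hello world")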
| """ | |
| model_type: ModelType = ModelType.EMBEDDER | |
| model_client: ModelClient | |
| output_processors: Optional[DataComponent] | |
| def __init__( | |
| self, | |
| *, | |
| model_client: ModelClient, | |
| model_kwargs: Dict[str, Any] = {}, | |
| output_processors: Optional[DataComponent] = None, | |
| ) -> None: | |
| super().__init__(model_kwargs=model_kwargs) | |
| if not isinstance(model_kwargs, Dict): | |
| raise TypeError( | |
| f"{type(self).__name__} requires a dictionary for model_kwargs, not a string" | |
| ) | |
| self.model_kwargs = model_kwargs.copy() | |
| if not isinstance(model_client, ModelClient): | |
| raise TypeError( | |
| f"{type(self).__name__} requires a ModelClient instance for model_client." | |
| ) | |
| self.model_client = model_client | |
| self.output_processors = output_processors | |
    def call(
        self,
        input: EmbedderInputType,
        model_kwargs: Optional[Dict] = {},
    ) -> EmbedderOutputType:
        log.debug(f"Calling {self.__class__.__name__} with input: {input}")
        api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
            input=input,
            model_kwargs=self._compose_model_kwargs(**model_kwargs),
            model_type=self.model_type,
        )
        try:
            # DashscopeClient.call already returns a parsed EmbedderOutput
            output = self.model_client.call(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
        except Exception as e:
            log.error(f"🤡 Error calling the DashScope model: {e}")
            output = EmbedderOutput(error=str(e))
        output.input = [input] if isinstance(input, str) else input
        return output

    async def acall(
        self,
        input: EmbedderInputType,
        model_kwargs: Optional[Dict] = {},
    ) -> EmbedderOutputType:
        log.debug(f"Calling {self.__class__.__name__} with input: {input}")
        api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
            input=input,
            model_kwargs=self._compose_model_kwargs(**model_kwargs),
            model_type=self.model_type,
        )
        output: EmbedderOutputType = None
        try:
            response = await self.model_client.acall(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
            # DashscopeClient.acall already returns a parsed EmbedderOutput;
            # only parse raw responses, to avoid double-parsing.
            if isinstance(response, EmbedderOutput):
                output = response
            else:
                output = self.model_client.parse_embedding_response(response)
        except Exception as e:
            log.error(f"Error calling the DashScope model: {e}")
            output = EmbedderOutput(error=str(e))
        output.input = [input] if isinstance(input, str) else input
        log.debug(f"Output from {self.__class__.__name__}: {output}")
        return output

    def _compose_model_kwargs(self, **model_kwargs) -> Dict[str, object]:
        return F.compose_model_kwargs(self.model_kwargs, model_kwargs)


# Batch Embedding Components for DashScope
class DashScopeBatchEmbedder(DataComponent):
    """Batch embedder specifically designed for the DashScope API."""
    def __init__(self, embedder, batch_size: int = 100, embedding_cache_file_name: str = "default") -> None:
        super().__init__(batch_size=batch_size)
        self.embedder = embedder
        self.batch_size = batch_size
        if self.batch_size > 25:
            log.warning(
                f"Requested batch size {self.batch_size} exceeds the DashScope "
                "batch embedding limit of 25; automatically capping at 25."
            )
            self.batch_size = 25
        self.cache_path = f'./embedding_cache/{embedding_cache_file_name}_{self.embedder.__class__.__name__}_dashscope_embeddings.pkl'

    def call(
        self, input: BatchEmbedderInputType, model_kwargs: Optional[Dict] = {}, force_recreate: bool = False
    ) -> BatchEmbedderOutputType:
        """
        Batch call to the DashScope embedder.

        Args:
            input: List of input texts
            model_kwargs: Model parameters
            force_recreate: If True, ignore any cached embeddings and re-embed

        Returns:
            Batch embedding output
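
        Example:
            A minimal sketch (assumes ``embedder`` is a DashScopeEmbedder
            configured as above)::

                batch_embedder = DashScopeBatchEmbedder(embedder, batch_size=25)
                outputs = batch_embedder(input=["first text", "second text"])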
| """ | |
| # Check cache first | |
| if not force_recreate and os.path.exists(self.cache_path): | |
| try: | |
| with open(self.cache_path, 'rb') as f: | |
| embeddings = pickle.load(f) | |
| log.info(f"Loaded cached DashScope embeddings from: {self.cache_path}") | |
| return embeddings | |
| except Exception as e: | |
| log.warning(f"Failed to load cache file {self.cache_path}: {e}, proceeding with fresh embedding") | |
| if isinstance(input, str): | |
| input = [input] | |
| n = len(input) | |
| embeddings: List[EmbedderOutput] = [] | |
| log.info(f"Starting DashScope batch embedding processing, total {n} texts, batch size: {self.batch_size}") | |
| for i in tqdm( | |
| range(0, n, self.batch_size), | |
| desc="DashScope batch embedding", | |
| disable=False, | |
| ): | |
| batch_input = input[i : min(i + self.batch_size, n)] | |
| try: | |
| # Use correct calling method: directly call embedder instance | |
| batch_output = self.embedder( | |
| input=batch_input, model_kwargs=model_kwargs | |
| ) | |
| embeddings.append(batch_output) | |
| # Validate batch output | |
| if batch_output.error: | |
| log.error(f"Batch {i//self.batch_size + 1} embedding failed: {batch_output.error}") | |
| elif batch_output.data: | |
| log.debug(f"Batch {i//self.batch_size + 1} successfully generated {len(batch_output.data)} embedding vectors") | |
| else: | |
| log.warning(f"Batch {i//self.batch_size + 1} returned no embedding data") | |
| except Exception as e: | |
| log.error(f"Batch {i//self.batch_size + 1} processing exception: {e}") | |
| # Create error embedding output | |
| error_output = EmbedderOutput( | |
| data=[], | |
| error=str(e), | |
| raw_response=None | |
| ) | |
| embeddings.append(error_output) | |
| log.info(f"DashScope batch embedding completed, processed {len(embeddings)} batches") | |
| # Save to cache | |
| try: | |
| # Use a more robust cache directory path | |
| cache_dir = os.path.dirname(self.cache_path) | |
| if not os.path.exists(cache_dir): | |
| os.makedirs(cache_dir, exist_ok=True) | |
| with open(self.cache_path, 'wb') as f: | |
| pickle.dump(embeddings, f) | |
| log.info(f"Saved DashScope embeddings cache to: {self.cache_path}") | |
| except Exception as e: | |
| log.warning(f"Failed to save cache to {self.cache_path}: {e}") | |
| return embeddings | |
    def __call__(self, input: BatchEmbedderInputType, model_kwargs: Optional[Dict] = {}, force_recreate: bool = False) -> BatchEmbedderOutputType:
        """
        Call operator interface; delegates to the call method.
        """
        return self.call(input=input, model_kwargs=model_kwargs, force_recreate=force_recreate)


class DashScopeToEmbeddings(DataComponent):
    """Component that converts a sequence of documents to a sequence of embedding vectors, optimized for the DashScope API."""
    def __init__(self, embedder, batch_size: int = 100, force_recreate_db: bool = False, embedding_cache_file_name: str = "default") -> None:
        super().__init__(batch_size=batch_size)
        self.embedder = embedder
        self.batch_size = batch_size
        self.batch_embedder = DashScopeBatchEmbedder(embedder=embedder, batch_size=batch_size, embedding_cache_file_name=embedding_cache_file_name)
        self.force_recreate_db = force_recreate_db

    def __call__(self, input: List[Document]) -> List[Document]:
        """
        Process a list of documents, generating an embedding vector for each document.

        Args:
            input: List of input documents

        Returns:
            List of documents with embedding vectors attached
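
        Example:
            A minimal sketch (assumes ``docs`` is a list of ``Document``
            objects whose ``text`` fields are populated)::

                to_embeddings = DashScopeToEmbeddings(embedder, batch_size=25)
                docs_with_vectors = to_embeddings(docs)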
| """ | |
| output = deepcopy(input) | |
| # Convert to text list | |
| embedder_input: List[str] = [chunk.text for chunk in output] | |
| log.info(f"Starting to process embeddings for {len(embedder_input)} documents") | |
| # Batch process embeddings | |
| outputs: List[EmbedderOutput] = self.batch_embedder( | |
| input=embedder_input, | |
| force_recreate=self.force_recreate_db | |
| ) | |
| # Validate output | |
| total_embeddings = 0 | |
| error_batches = 0 | |
| for batch_output in outputs: | |
| if batch_output.error: | |
| error_batches += 1 | |
| log.error(f"Found error batch: {batch_output.error}") | |
| elif batch_output.data: | |
| total_embeddings += len(batch_output.data) | |
| log.info(f"Embedding statistics: total {total_embeddings} valid embeddings, {error_batches} error batches") | |
| # Assign embedding vectors back to documents | |
| doc_idx = 0 | |
| for batch_idx, batch_output in tqdm( | |
| enumerate(outputs), | |
| desc="Assigning embedding vectors to documents", | |
| disable=False | |
| ): | |
| if batch_output.error: | |
| # Create empty vectors for documents in error batches | |
| batch_size_actual = min(self.batch_size, len(output) - doc_idx) | |
| log.warning(f"Creating empty vectors for {batch_size_actual} documents in batch {batch_idx}") | |
| for i in range(batch_size_actual): | |
| if doc_idx < len(output): | |
| output[doc_idx].vector = [] | |
| doc_idx += 1 | |
| else: | |
| # Assign normal embedding vectors | |
| for embedding in batch_output.data: | |
| if doc_idx < len(output): | |
| if hasattr(embedding, 'embedding'): | |
| output[doc_idx].vector = embedding.embedding | |
| else: | |
| log.warning(f"Invalid embedding format for document {doc_idx}") | |
| output[doc_idx].vector = [] | |
| doc_idx += 1 | |
| # Validate results | |
| valid_count = 0 | |
| empty_count = 0 | |
| for doc in output: | |
| if hasattr(doc, 'vector') and doc.vector and len(doc.vector) > 0: | |
| valid_count += 1 | |
| else: | |
| empty_count += 1 | |
| log.info(f"Embedding results: {valid_count} valid vectors, {empty_count} empty vectors") | |
| if valid_count == 0: | |
| log.error("❌ All documents have empty embedding vectors!") | |
| elif empty_count > 0: | |
| log.warning(f"⚠️ Found {empty_count} empty embedding vectors") | |
| else: | |
| log.info("✅ All documents successfully generated embedding vectors") | |
| return output | |
| def _extra_repr(self) -> str: | |
| return f"batch_size={self.batch_size}" |