File size: 4,871 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService
from lpm_kernel.configs.config import Config
from typing import List, Union
import requests
import numpy as np
from lpm_kernel.configs.logging import get_train_process_logger
logger = get_train_process_logger()
import lpm_kernel.common.strategy.classification as classification
from sentence_transformers import SentenceTransformer
import json

class EmbeddingError(Exception):
    """Custom exception class for embedding-related errors"""
    def __init__(self, message, original_error=None):
        super().__init__(message)
        self.original_error = original_error

class LLMClient:
    """LLM client utility class"""

    def __init__(self):
        self.config = Config.from_env()
        self.user_llm_config_service = UserLLMConfigService()
        self.embedding_max_text_length = int(self.config.get('EMBEDDING_MAX_TEXT_LENGTH', 8000))
        # self.user_llm_config = self.user_llm_config_service.get_available_llm()

        # self.chat_api_key = self.user_llm_config.chat_api_key
        # self.chat_base_url = self.user_llm_config.chat_endpoint
        # self.chat_model = self.user_llm_config.chat_model_name
        # self.embedding_api_key = self.user_llm_config.embedding_api_key
        # self.embedding_base_url = self.user_llm_config.embedding_endpoint
        # self.embedding_model = self.user_llm_config.embedding_model_name


    def get_embedding(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Calculate text embedding

        Args:
            texts (str or list): Input text or list of texts

        Returns:
            numpy.ndarray: Embedding vector of the text
        """
        # Ensure texts is in list format
        if isinstance(texts, str):
            texts = [texts]

        # Split long texts into chunks using configured max length
        chunked_texts = []
        text_chunk_counts = []  # Keep track of how many chunks each text was split into
        
        for text in texts:
            if len(text) > self.embedding_max_text_length:
                # Split into chunks of embedding_max_text_length
                chunks = [text[i:i + self.embedding_max_text_length] 
                         for i in range(0, len(text), self.embedding_max_text_length)]
                chunked_texts.extend(chunks)
                text_chunk_counts.append(len(chunks))
            else:
                chunked_texts.append(text)
                text_chunk_counts.append(1)

        user_llm_config = self.user_llm_config_service.get_available_llm()
        if not user_llm_config:
            raise EmbeddingError("No LLM configuration found")
        
        try:
            # Send request to embedding endpoint
            embeddings_array = classification.strategy_classification(user_llm_config, chunked_texts)

            # If we split any texts, we need to merge their embeddings back
            if sum(text_chunk_counts) > len(texts):
                final_embeddings = []
                start_idx = 0
                for chunk_count in text_chunk_counts:
                    if chunk_count > 1:
                        # Average embeddings for split text
                        chunk_embeddings = embeddings_array[start_idx:start_idx + chunk_count]
                        avg_embedding = np.mean(chunk_embeddings, axis=0)
                        final_embeddings.append(avg_embedding)
                    else:
                        final_embeddings.append(embeddings_array[start_idx])
                    start_idx += chunk_count
                return np.array(final_embeddings)
            
            return embeddings_array

        except requests.exceptions.RequestException as e:
            # Handle request errors
            error_msg = f"Request error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e)
        except json.JSONDecodeError as e:
            # Handle JSON parsing errors
            error_msg = f"Invalid JSON response from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e)
        except (KeyError, IndexError, ValueError) as e:
            # Handle response structure errors
            error_msg = f"Invalid response structure from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e)
        except Exception as e:
            # Fallback for any other errors
            error_msg = f"Unexpected error getting embeddings: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise EmbeddingError(error_msg, e)

    @property
    def chat_credentials(self):
        """Get LLM authentication information"""
        return {"api_key": self.chat_api_key, "base_url": self.chat_base_url}