import pipmaster as pm

# Install required dependencies before importing from llama_index,
# otherwise the imports below fail on machines where it is missing.
if not pm.is_installed("llama-index"):
    pm.install("llama-index")

from typing import List, Optional

from llama_index.core.llms import (
    ChatMessage,
    MessageRole,
    ChatResponse,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.settings import Settings as LlamaIndexSettings
from lightrag.utils import logger
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
)
from lightrag.exceptions import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
import numpy as np


def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
    """
    Configure LlamaIndex settings.

    Args:
        settings: LlamaIndex Settings object to update. If None, the global
            ``Settings`` singleton is updated in place.
        **kwargs: Attribute overrides (e.g. ``llm``, ``embed_model``).
    """
    # ``Settings`` (imported here as LlamaIndexSettings) is a module-level
    # singleton, not a class: it cannot be instantiated and has no
    # ``set_global`` method, so attributes are set directly on it.
    if settings is None:
        settings = LlamaIndexSettings

    # Update settings with any provided kwargs
    for key, value in kwargs.items():
        if hasattr(settings, key):
            setattr(settings, key, value)
        else:
            logger.warning(f"Unknown LlamaIndex setting: {key}")

    return settings
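
# Minimal usage sketch (assumptions: the optional ``llama-index-llms-openai``
# package is installed, OPENAI_API_KEY is set, and the model name below is
# purely illustrative):
#
#   from llama_index.llms.openai import OpenAI
#   configure_llama_index(llm=OpenAI(model="gpt-4o-mini"))
#
# Components that read the global LlamaIndex Settings then pick up this LLM.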


def format_chat_messages(messages):
    """Format chat messages into LlamaIndex format."""
    formatted_messages = []

    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if role == "system":
            formatted_messages.append(
                ChatMessage(role=MessageRole.SYSTEM, content=content)
            )
        elif role == "assistant":
            formatted_messages.append(
                ChatMessage(role=MessageRole.ASSISTANT, content=content)
            )
        elif role == "user":
            formatted_messages.append(
                ChatMessage(role=MessageRole.USER, content=content)
            )
        else:
            logger.warning(f"Unknown role {role}, treating as user message")
            formatted_messages.append(
                ChatMessage(role=MessageRole.USER, content=content)
            )

    return formatted_messages
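
# Illustrative input/output (the message dicts mirror the OpenAI-style
# {"role", "content"} convention used throughout this module):
#
#   format_chat_messages([
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello"},
#   ])
#   # -> [ChatMessage(role=SYSTEM, ...), ChatMessage(role=USER, ...)]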


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_complete_if_cache(
    model,
    prompt: str,
    system_prompt: Optional[str] = None,
    history_messages: Optional[List[dict]] = None,
    chat_kwargs: Optional[dict] = None,
) -> str:
    """Complete the prompt using a LlamaIndex LLM instance (anything exposing ``achat``)."""
    # Avoid mutable default arguments.
    if history_messages is None:
        history_messages = []
    if chat_kwargs is None:
        chat_kwargs = {}
    try:
        # Format messages for chat
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append(
                ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
            )

        # Add history messages
        for msg in history_messages:
            formatted_messages.append(
                ChatMessage(
                    role=MessageRole.USER
                    if msg["role"] == "user"
                    else MessageRole.ASSISTANT,
                    content=msg["content"],
                )
            )

        # Add current prompt
        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))

        response: ChatResponse = await model.achat(
            messages=formatted_messages, **chat_kwargs
        )

        # In newer versions, the response is in message.content
        content = response.message.content
        return content

    except Exception as e:
        logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
        raise


async def llama_index_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    settings: LlamaIndexSettings = None,
    **kwargs,
) -> str:
    """
    Main completion function for LlamaIndex

    Args:
        prompt: Input prompt
        system_prompt: Optional system prompt
        history_messages: Optional chat history
        keyword_extraction: Whether to extract keywords from response
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments
    """
    if history_messages is None:
        history_messages = []

    if settings is not None:
        configure_llama_index(settings)

    # Respect the explicit argument, but allow an override passed via kwargs.
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)

    # Pop the LLM instance so it is not forwarded to the chat call as well.
    llm_instance = kwargs.pop("llm_instance", None)
    if llm_instance is None:
        raise ValueError("llm_instance must be provided via kwargs")

    result = await llama_index_complete_if_cache(
        llm_instance,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        # Only an explicit chat_kwargs dict is forwarded to the model; other
        # leftover kwargs are not chat parameters and are dropped here.
        chat_kwargs=kwargs.pop("chat_kwargs", {}),
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_embed(
    texts: list[str],
    embed_model: BaseEmbedding = None,
    settings: LlamaIndexSettings = None,
    **kwargs,
) -> np.ndarray:
    """
    Generate embeddings using LlamaIndex

    Args:
        texts: List of texts to embed
        embed_model: LlamaIndex embedding model
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments
    """
    if embed_model is None:
        raise ValueError("embed_model must be provided")

    if settings:
        configure_llama_index(settings)

    # Embed the whole batch through the public async API rather than the
    # private synchronous ``_get_text_embeddings`` helper.
    embeddings = await embed_model.aget_text_embedding_batch(texts)
    return np.array(embeddings)
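

# Minimal end-to-end wiring sketch (assumptions: the optional
# ``llama-index-llms-openai`` and ``llama-index-embeddings-openai`` packages
# are installed, OPENAI_API_KEY is set, and the model names are illustrative):
#
#   from llama_index.llms.openai import OpenAI
#   from llama_index.embeddings.openai import OpenAIEmbedding
#
#   llm = OpenAI(model="gpt-4o-mini")
#   embedder = OpenAIEmbedding(model="text-embedding-3-small")
#
#   async def my_llm_func(prompt, **kwargs):
#       return await llama_index_complete(prompt, llm_instance=llm, **kwargs)
#
#   async def my_embed_func(texts):
#       return await llama_index_embed(texts, embed_model=embedder)
#
# These two async wrappers can then be passed to LightRAG as its LLM and
# embedding callbacks.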