File size: 8,230 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# This module contains utility classes and functions for interacting with Arcee API.
# For more information and updates, refer to the Arcee utils page:
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]

from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

import requests
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.retrievers import Document


class ArceeRoute(str, Enum):
    """Routes available for the Arcee API as enumerator.

    Each value is a path fragment; `ArceeWrapper._make_request_url` appends it
    to ``{arcee_api_url}/{arcee_api_version}/`` to form the request URL.
    """

    # Used by `ArceeWrapper.generate` (POST).
    generate = "models/generate"
    # Used by `ArceeWrapper.retrieve` (POST).
    retrieve = "models/retrieve"
    # Template: `{id_or_name}` is filled in with `str.format` before the
    # status request is made in `ArceeWrapper.__init__` (GET).
    model_training_status = "models/status/{id_or_name}"


class DALMFilterType(str, Enum):
    """Filter types available for a DALM retrieval as enumerator.

    The members are the accepted values of `DALMFilter.filter_type`; see the
    `DALMFilter` docstring for what each search mode matches.
    """

    # Fuzzy match against the chosen field.
    fuzzy_search = "fuzzy_search"
    # Substring containment against the chosen field.
    strict_search = "strict_search"


class DALMFilter(BaseModel):
    """Filters available for a DALM retrieval and generation.

    Arguments:
        field_name: The field to filter on. Can be 'document' or 'name' to filter
            on your document's raw text or title. Any other field will be presumed
            to be a metadata field you included when uploading your context data
        filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
            'fuzzy_search' means a fuzzy search on the provided field is performed.
            The exact string doesn't need to exist in the document
            for this to find a match.
            Very useful for scanning a document for some keyword terms.
            'strict_search' means that the exact string must appear
            in the provided field.
            This is NOT an exact eq filter. ie a document with content
            "the happy dog crossed the street" will match on a strict_search of
            "dog" but won't match on "the dog".
            Python equivalent of `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    # NOTE(review): this attribute is named `_is_metadata`, but `set_meta`
    # below writes the key "_is_meta". Confirm which name is intended (the
    # key may be what the Arcee API expects) before aligning them.
    _is_metadata: bool = False

    @root_validator()
    def set_meta(cls, values: Dict) -> Dict:
        """Flag whether the filter targets a metadata field.

        'document' and 'name' are reserved Arcee keys; any other `field_name`
        is treated as metadata. The flag is stored in `values` under the key
        "_is_meta" and the updated `values` dict is returned.
        """
        values["_is_meta"] = values.get("field_name") not in ["document", "name"]
        return values


class ArceeDocumentSource(BaseModel):
    """Source of an Arcee document.

    Nested under `ArceeDocument.source`; `ArceeDocumentAdapter.adapt` reads
    all three fields when building a langchain `Document`.
    """

    # Raw text; becomes the langchain document's `page_content`.
    document: str
    # Exposed as metadata key "name" (presumably the document title -- confirm).
    name: str
    # Exposed as metadata key "source_id".
    id: str


class ArceeDocument(BaseModel):
    """Arcee document.

    One entry of the "results" list returned by the retrieve endpoint, as
    parsed in `ArceeWrapper.retrieve`.
    """

    # Exposed as metadata key "index" (presumably the backing index name -- confirm).
    index: str
    # Exposed as metadata key "id".
    id: str
    # Exposed as metadata key "score" (presumably a relevance score -- confirm).
    score: float
    # Text, name and id of the underlying source document.
    source: ArceeDocumentSource


class ArceeDocumentAdapter:
    """Converts Arcee API documents into langchain documents."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Build a langchain `Document` from an `ArceeDocument`.

        Args:
            arcee_document: The parsed Arcee result to convert.

        Returns:
            A `Document` whose `page_content` is the source text and whose
            metadata carries the source name/id plus the result's index, id
            and score.
        """
        origin = arcee_document.source
        metadata = {
            # fields taken from the nested source
            "name": origin.name,
            "source_id": origin.id,
            # fields taken from the result itself
            "index": arcee_document.index,
            "id": arcee_document.id,
            "score": arcee_document.score,
        }
        return Document(page_content=origin.document, metadata=metadata)


class ArceeWrapper:
    """Wrapper for Arcee API.

    Construction performs a status request for the given model and stores the
    returned model id and training status. `generate` and `retrieve` then post
    queries for that model.
    """

    def __init__(
        self,
        arcee_api_key: str,
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize ArceeWrapper.

        Arguments:
            arcee_api_key: API key for Arcee API.
            arcee_api_url: URL for Arcee API.
            arcee_api_version: Version of Arcee API.
            model_kwargs: Keyword arguments for Arcee API.
            model_name: Name of an Arcee model.

        Raises:
            ValueError: If the model status request fails for any reason; the
                original error is chained as the cause.
        """
        self.arcee_api_key = arcee_api_key
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Raise if the status fetched at construction is not complete.

        Raises:
            Exception: If the stored training status is anything other than
                "training_complete".
        """
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Make a request to the Arcee API

        Args:
            method: The HTTP method to use
            route: The route to call
            body: The body of the request
            params: The query params of the request
            headers: The headers of the request

        Returns:
            The decoded JSON body of the response.

        Raises:
            Exception: If the response status code is neither 200 nor 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        # `method` names the module-level helper on `requests` (get/post).
        req_type = getattr(requests, method)

        # NOTE: no timeout is passed, so the call waits as long as the
        # server takes to answer.
        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Return request headers: caller headers plus the Arcee ones.

        The API token and content type always override caller-supplied values
        for the same keys. The caller's dict is copied, never mutated.
        """
        merged: Dict = dict(headers) if headers else {}
        merged.update(
            {
                "X-Token": self.arcee_api_key,
                "Content-Type": "application/json",
            }
        )
        return merged

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Return the full URL ``{api_url}/{api_version}/{route}``."""
        # Take the member's value explicitly: formatting a `(str, Enum)`
        # member inside an f-string renders as "ArceeRoute.<name>" on recent
        # Python versions instead of the route path.
        path = route.value if isinstance(route, ArceeRoute) else route
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{path}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Make the request body for generate/retrieve models endpoint

        Args:
            prompt: The query text to send.
            **kwargs: Per-call overrides of `model_kwargs`; the keys read here
                are "size" (default 3) and "filters" (default none).

        Returns:
            A JSON-serializable mapping with the model id, query, size and
            validated filters.
        """
        # Per-call kwargs take precedence over the wrapper-level defaults.
        merged_params = {**(self.model_kwargs or {}), **kwargs}

        # Treat a missing or explicit None "filters" as no filters.
        raw_filters = merged_params.get("filters") or []
        # Validate each filter through DALMFilter, then export it as a plain
        # dict: the body is sent through `requests`' `json=` argument, which
        # cannot serialize model instances.
        filters = [
            (f if isinstance(f, DALMFilter) else DALMFilter(**f)).dict()
            for f in raw_filters
        ]
        return dict(
            model_id=self.model_id,
            query=prompt,
            size=merged_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Args:
            prompt: Prompt to generate text from.
            size: The max number of context results to retrieve. Defaults to 3.
              (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The value of the "text" key of the API response.
        """

        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=self._make_request_body_for_models(
                prompt=prompt,
                **kwargs,
            ),
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve {size} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            size: The max number of context results to retrieve. Defaults to 3.
              (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            One langchain `Document` per entry of the response's "results"
            list.
        """

        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=self._make_request_body_for_models(
                prompt=query,
                **kwargs,
            ),
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]