Update myollama.py
myollama.py (+167 -15)
This commit ports myollama.py, a local copy of LlamaIndex's Ollama integration, from the llama_index.legacy package to llama_index.core. The legacy imports (llama_index.legacy.bridge.pydantic, llama_index.legacy.constants, llama_index.legacy.core.llms.types) are swapped for their llama_index.core equivalents, a class docstring and a json_mode field are added, and the chat and completion methods gain async counterparts (achat, acomplete, astream_complete) built on httpx.AsyncClient. When json_mode is set, every payload is stamped with format: "json" before being posted to the Ollama server.
@@ -1,35 +1,59 @@
+# https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
+
 import json
 from typing import Any, Dict, Sequence, Tuple
 
 import httpx
 from httpx import Timeout
-
-from llama_index.legacy.bridge.pydantic import Field
-from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
-from llama_index.legacy.core.llms.types import (
+from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
     ChatResponseGen,
+    ChatResponseAsyncGen,
     CompletionResponse,
+    CompletionResponseAsyncGen,
     CompletionResponseGen,
     LLMMetadata,
     MessageRole,
 )
-from llama_index.
-from llama_index.
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
+from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
+from llama_index.core.llms.custom import CustomLLM
 
 DEFAULT_REQUEST_TIMEOUT = 30.0
 
 
-def
+def get_additional_kwargs(
     response: Dict[str, Any], exclude: Tuple[str, ...]
 ) -> Dict[str, Any]:
     return {k: v for k, v in response.items() if k not in exclude}
 
 
 class Ollama(CustomLLM):
+    """Ollama LLM.
+
+    Visit https://ollama.com/ to download and install Ollama.
+
+    Run `ollama serve` to start a server.
+
+    Run `ollama pull <name>` to download a model to run.
+
+    Examples:
+        `pip install llama-index-llms-ollama`
+
+        ```python
+        from llama_index.llms.ollama import Ollama
+
+        llm = Ollama(model="llama2", request_timeout=60.0)
+
+        response = llm.complete("What is the capital of France?")
+        print(response)
+        ```
+    """
+
     base_url: str = Field(
-        default="http://localhost:
+        default="http://localhost:11435",
         description="Base url the model is hosted under.",
     )
     model: str = Field(description="The Ollama model to use.")
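The committed default base_url points at port 11435, one off from Ollama's usual 11434, so the Space presumably runs its server on a non-standard port. The field can be overridden per instance; a minimal sketch, assuming this Space exposes the class as myollama.Ollama:

```python
from myollama import Ollama  # hypothetical import path for this Space

# Point at a stock Ollama install on the standard port instead of the committed default.
llm = Ollama(model="llama2", base_url="http://localhost:11434")
```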
@@ -51,6 +75,10 @@ class Ollama(CustomLLM):
     prompt_key: str = Field(
         default="prompt", description="The key to use for the prompt in API calls."
     )
+    json_mode: bool = Field(
+        default=False,
+        description="Whether to use JSON mode for the Ollama API.",
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional model parameters for the Ollama API.",
@@ -98,6 +126,9 @@ class Ollama(CustomLLM):
             **kwargs,
         }
 
+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             response = client.post(
                 url=f"{self.base_url}/api/chat",
@@ -110,12 +141,12 @@
                 message=ChatMessage(
                     content=message.get("content"),
                     role=MessageRole(message.get("role")),
-                    additional_kwargs=
+                    additional_kwargs=get_additional_kwargs(
                         message, ("content", "role")
                     ),
                 ),
                 raw=raw,
-                additional_kwargs=
+                additional_kwargs=get_additional_kwargs(raw, ("message",)),
             )
 
     @llm_chat_callback()
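Because get_additional_kwargs keeps every key of the raw reply except the ones already consumed, per-call metadata from Ollama survives on the response object. A sketch of reading it back; the field names are what Ollama's /api/chat is documented to return, not something this file guarantees:

```python
from llama_index.core.base.llms.types import ChatMessage

from myollama import Ollama  # hypothetical import path for this Space

llm = Ollama(model="llama2")
resp = llm.chat([ChatMessage(role="user", content="Hello!")])

# Everything but "message" from the raw Ollama reply lands here,
# e.g. token counts and timings when the server reports them.
print(resp.additional_kwargs.get("eval_count"))
print(resp.additional_kwargs.get("total_duration"))
```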
@@ -137,6 +168,9 @@
             **kwargs,
         }
 
+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             with client.stream(
                 method="POST",
@@ -157,15 +191,59 @@
                                 message=ChatMessage(
                                     content=text,
                                     role=MessageRole(message.get("role")),
-                                    additional_kwargs=
+                                    additional_kwargs=get_additional_kwargs(
                                         message, ("content", "role")
                                     ),
                                 ),
                                 delta=delta,
                                 raw=chunk,
-                                additional_kwargs=
+                                additional_kwargs=get_additional_kwargs(
+                                    chunk, ("message",)
+                                ),
                             )
 
+    @llm_chat_callback()
+    async def achat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponse:
+        payload = {
+            "model": self.model,
+            "messages": [
+                {
+                    "role": message.role.value,
+                    "content": message.content,
+                    **message.additional_kwargs,
+                }
+                for message in messages
+            ],
+            "options": self._model_kwargs,
+            "stream": False,
+            **kwargs,
+        }
+
+        if self.json_mode:
+            payload["format"] = "json"
+
+        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
+            response = await client.post(
+                url=f"{self.base_url}/api/chat",
+                json=payload,
+            )
+            response.raise_for_status()
+            raw = response.json()
+            message = raw["message"]
+            return ChatResponse(
+                message=ChatMessage(
+                    content=message.get("content"),
+                    role=MessageRole(message.get("role")),
+                    additional_kwargs=get_additional_kwargs(
+                        message, ("content", "role")
+                    ),
+                ),
+                raw=raw,
+                additional_kwargs=get_additional_kwargs(raw, ("message",)),
+            )
+
     @llm_completion_callback()
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
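achat builds the same payload as chat but posts it through httpx.AsyncClient, and it returns a plain ChatResponse, so it can be awaited directly. A usage sketch under the same myollama import assumption:

```python
import asyncio

from llama_index.core.base.llms.types import ChatMessage

from myollama import Ollama  # hypothetical import path for this Space


async def main() -> None:
    llm = Ollama(model="llama2", request_timeout=60.0)
    resp = await llm.achat([ChatMessage(role="user", content="Say hi in French.")])
    print(resp.message.content)


asyncio.run(main())
```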
@@ -178,6 +256,9 @@
             **kwargs,
         }
 
+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             response = client.post(
                 url=f"{self.base_url}/api/generate",
@@ -189,7 +270,36 @@
             return CompletionResponse(
                 text=text,
                 raw=raw,
-                additional_kwargs=
+                additional_kwargs=get_additional_kwargs(raw, ("response",)),
+            )
+
+    @llm_completion_callback()
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        payload = {
+            self.prompt_key: prompt,
+            "model": self.model,
+            "options": self._model_kwargs,
+            "stream": False,
+            **kwargs,
+        }
+
+        if self.json_mode:
+            payload["format"] = "json"
+
+        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
+            response = await client.post(
+                url=f"{self.base_url}/api/generate",
+                json=payload,
+            )
+            response.raise_for_status()
+            raw = response.json()
+            text = raw.get("response")
+            return CompletionResponse(
+                text=text,
+                raw=raw,
+                additional_kwargs=get_additional_kwargs(raw, ("response",)),
             )
 
     @llm_completion_callback()
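Since acomplete suspends while the POST is in flight, several completions can share one event loop instead of blocking each other the way the synchronous complete does. A sketch, assuming the Ollama server accepts overlapping requests:

```python
import asyncio

from myollama import Ollama  # hypothetical import path for this Space


async def main() -> None:
    llm = Ollama(model="llama2", request_timeout=120.0)
    # Each call opens its own AsyncClient, so the requests can overlap.
    answers = await asyncio.gather(
        llm.acomplete("Capital of France?"),
        llm.acomplete("Capital of Japan?"),
    )
    for answer in answers:
        print(answer.text)


asyncio.run(main())
```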
@@ -204,6 +314,9 @@
             **kwargs,
         }
 
+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             with client.stream(
                 method="POST",
@@ -221,7 +334,46 @@
                                 delta=delta,
                                 text=text,
                                 raw=chunk,
-                                additional_kwargs=
+                                additional_kwargs=get_additional_kwargs(
                                     chunk, ("response",)
                                 ),
-        )
+                            )
+
+    @llm_completion_callback()
+    async def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        payload = {
+            self.prompt_key: prompt,
+            "model": self.model,
+            "options": self._model_kwargs,
+            "stream": True,
+            **kwargs,
+        }
+
+        if self.json_mode:
+            payload["format"] = "json"
+
+        async def gen() -> CompletionResponseAsyncGen:
+            async with httpx.AsyncClient(
+                timeout=Timeout(self.request_timeout)
+            ) as client:
+                async with client.stream(
+                    method="POST",
+                    url=f"{self.base_url}/api/generate",
+                    json=payload,
+                ) as response:
+                    async for line in response.aiter_lines():
+                        if line:
+                            chunk = json.loads(line)
+                            delta = chunk.get("response")
+                            yield CompletionResponse(
+                                delta=delta,
+                                text=delta,
+                                raw=chunk,
+                                additional_kwargs=get_additional_kwargs(
+                                    chunk, ("response",)
+                                ),
+                            )
+
+        return gen()
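astream_complete is awaited once to obtain the async generator, then iterated with async for. Note that, unlike the synchronous stream_complete above, which yields a cumulative text alongside each delta, each chunk here carries only the delta in both fields, so callers who want the full string must accumulate it themselves; a consumption sketch under the same import assumption:

```python
import asyncio

from myollama import Ollama  # hypothetical import path for this Space


async def main() -> None:
    llm = Ollama(model="llama2", request_timeout=120.0)
    gen = await llm.astream_complete("Write a haiku about autumn.")
    text = ""
    async for chunk in gen:
        text += chunk.delta or ""  # chunk.text also holds only the delta here
        print(chunk.delta or "", end="", flush=True)
    print()


asyncio.run(main())
```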