Entz committed
Commit fd806b4
1 Parent(s): 04d2598

Update myollama.py

Files changed (1):
  myollama.py +167 -15
myollama.py CHANGED
@@ -1,35 +1,59 @@
+# https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
+
 import json
 from typing import Any, Dict, Sequence, Tuple

 import httpx
 from httpx import Timeout
-
-from llama_index.legacy.bridge.pydantic import Field
-from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
-from llama_index.legacy.core.llms.types import (
+from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
     ChatResponseGen,
+    ChatResponseAsyncGen,
     CompletionResponse,
+    CompletionResponseAsyncGen,
     CompletionResponseGen,
     LLMMetadata,
     MessageRole,
 )
-from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
-from llama_index.legacy.llms.custom import CustomLLM
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
+from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
+from llama_index.core.llms.custom import CustomLLM

 DEFAULT_REQUEST_TIMEOUT = 30.0


-def get_addtional_kwargs(
+def get_additional_kwargs(
     response: Dict[str, Any], exclude: Tuple[str, ...]
 ) -> Dict[str, Any]:
     return {k: v for k, v in response.items() if k not in exclude}


 class Ollama(CustomLLM):
+    """Ollama LLM.
+
+    Visit https://ollama.com/ to download and install Ollama.
+
+    Run `ollama serve` to start a server.
+
+    Run `ollama pull <name>` to download a model to run.
+
+    Examples:
+        `pip install llama-index-llms-ollama`
+
+        ```python
+        from llama_index.llms.ollama import Ollama
+
+        llm = Ollama(model="llama2", request_timeout=60.0)
+
+        response = llm.complete("What is the capital of France?")
+        print(response)
+        ```
+    """
+
     base_url: str = Field(
-        default="http://localhost:11434",
+        default="http://localhost:11435",
         description="Base url the model is hosted under.",
     )
     model: str = Field(description="The Ollama model to use.")
@@ -51,6 +75,10 @@ class Ollama(CustomLLM):
     prompt_key: str = Field(
         default="prompt", description="The key to use for the prompt in API calls."
     )
+    json_mode: bool = Field(
+        default=False,
+        description="Whether to use JSON mode for the Ollama API.",
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional model parameters for the Ollama API.",
@@ -98,6 +126,9 @@ class Ollama(CustomLLM):
             **kwargs,
         }

+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             response = client.post(
                 url=f"{self.base_url}/api/chat",
@@ -110,12 +141,12 @@ class Ollama(CustomLLM):
                 message=ChatMessage(
                     content=message.get("content"),
                     role=MessageRole(message.get("role")),
-                    additional_kwargs=get_addtional_kwargs(
+                    additional_kwargs=get_additional_kwargs(
                         message, ("content", "role")
                     ),
                 ),
                 raw=raw,
-                additional_kwargs=get_addtional_kwargs(raw, ("message",)),
+                additional_kwargs=get_additional_kwargs(raw, ("message",)),
             )

     @llm_chat_callback()
@@ -137,6 +168,9 @@ class Ollama(CustomLLM):
             **kwargs,
         }

+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             with client.stream(
                 method="POST",
@@ -157,15 +191,59 @@ class Ollama(CustomLLM):
                             message=ChatMessage(
                                 content=text,
                                 role=MessageRole(message.get("role")),
-                                additional_kwargs=get_addtional_kwargs(
+                                additional_kwargs=get_additional_kwargs(
                                     message, ("content", "role")
                                 ),
                             ),
                             delta=delta,
                             raw=chunk,
-                            additional_kwargs=get_addtional_kwargs(chunk, ("message",)),
+                            additional_kwargs=get_additional_kwargs(
+                                chunk, ("message",)
+                            ),
                         )

+    @llm_chat_callback()
+    async def achat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
+        payload = {
+            "model": self.model,
+            "messages": [
+                {
+                    "role": message.role.value,
+                    "content": message.content,
+                    **message.additional_kwargs,
+                }
+                for message in messages
+            ],
+            "options": self._model_kwargs,
+            "stream": False,
+            **kwargs,
+        }
+
+        if self.json_mode:
+            payload["format"] = "json"
+
+        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
+            response = await client.post(
+                url=f"{self.base_url}/api/chat",
+                json=payload,
+            )
+            response.raise_for_status()
+            raw = response.json()
+            message = raw["message"]
+            return ChatResponse(
+                message=ChatMessage(
+                    content=message.get("content"),
+                    role=MessageRole(message.get("role")),
+                    additional_kwargs=get_additional_kwargs(
+                        message, ("content", "role")
+                    ),
+                ),
+                raw=raw,
+                additional_kwargs=get_additional_kwargs(raw, ("message",)),
+            )
+
     @llm_completion_callback()
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -178,6 +256,9 @@ class Ollama(CustomLLM):
             **kwargs,
         }

+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             response = client.post(
                 url=f"{self.base_url}/api/generate",
@@ -189,7 +270,36 @@ class Ollama(CustomLLM):
             return CompletionResponse(
                 text=text,
                 raw=raw,
-                additional_kwargs=get_addtional_kwargs(raw, ("response",)),
+                additional_kwargs=get_additional_kwargs(raw, ("response",)),
+            )
+
+    @llm_completion_callback()
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        payload = {
+            self.prompt_key: prompt,
+            "model": self.model,
+            "options": self._model_kwargs,
+            "stream": False,
+            **kwargs,
+        }
+
+        if self.json_mode:
+            payload["format"] = "json"
+
+        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
+            response = await client.post(
+                url=f"{self.base_url}/api/generate",
+                json=payload,
+            )
+            response.raise_for_status()
+            raw = response.json()
+            text = raw.get("response")
+            return CompletionResponse(
+                text=text,
+                raw=raw,
+                additional_kwargs=get_additional_kwargs(raw, ("response",)),
             )

     @llm_completion_callback()
@@ -204,6 +314,9 @@ class Ollama(CustomLLM):
             **kwargs,
         }

+        if self.json_mode:
+            payload["format"] = "json"
+
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             with client.stream(
                 method="POST",
@@ -221,7 +334,46 @@ class Ollama(CustomLLM):
                             delta=delta,
                             text=text,
                             raw=chunk,
-                            additional_kwargs=get_addtional_kwargs(
+                            additional_kwargs=get_additional_kwargs(
                                 chunk, ("response",)
                             ),
-                        )
+                        )
+
+    @llm_completion_callback()
+    async def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        payload = {
+            self.prompt_key: prompt,
+            "model": self.model,
+            "options": self._model_kwargs,
+            "stream": True,
+            **kwargs,
+        }
+
+        if self.json_mode:
+            payload["format"] = "json"
+
+        async def gen() -> CompletionResponseAsyncGen:
+            async with httpx.AsyncClient(
+                timeout=Timeout(self.request_timeout)
+            ) as client:
+                async with client.stream(
+                    method="POST",
+                    url=f"{self.base_url}/api/generate",
+                    json=payload,
+                ) as response:
+                    async for line in response.aiter_lines():
+                        if line:
+                            chunk = json.loads(line)
+                            delta = chunk.get("response")
+                            yield CompletionResponse(
+                                delta=delta,
+                                text=delta,
+                                raw=chunk,
+                                additional_kwargs=get_additional_kwargs(
+                                    chunk, ("response",)
+                                ),
+                            )
+
+        return gen()
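
For context, a minimal usage sketch of what this commit adds (`json_mode` plus the async `achat`, `acomplete`, and `astream_complete` methods). The `myollama` import path, the model name, and the server port are illustrative assumptions, not part of the diff.

```python
# Hypothetical usage of the updated class; assumes this file is importable as
# `myollama` and that an Ollama server is reachable at the given base_url.
import asyncio

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from myollama import Ollama  # assumed import path for the file in this commit

llm = Ollama(
    model="llama2",                     # any locally pulled Ollama model
    base_url="http://localhost:11434",  # override if your server uses another port
    request_timeout=60.0,
    json_mode=True,                     # new field: adds format="json" to each request payload
)

async def main() -> None:
    # acomplete() posts to /api/generate with stream=False and returns a CompletionResponse.
    completion = await llm.acomplete("Describe Paris as a JSON object.")
    print(completion.text)

    # achat() posts to /api/chat with stream=False and returns a ChatResponse.
    chat = await llm.achat(
        [ChatMessage(role=MessageRole.USER, content="Say hello as JSON.")]
    )
    print(chat.message.content)

asyncio.run(main())
```

Note that `achat` and `acomplete` in this commit are non-streaming (`"stream": False`); only `astream_complete` yields incremental deltas.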