Entz committed on
Commit 04d2598 · 1 parent: f69fed3

Update myollama.py

Files changed (1)
  1. myollama.py +30 -150
myollama.py CHANGED
@@ -1,25 +1,33 @@
-class Ollama(CustomLLM):
-    """Ollama LLM.
-
-    Visit https://ollama.com/ to download and install Ollama.
-
-    Run `ollama serve` to start a server.
-
-    Run `ollama pull <name>` to download a model to run.
-
-    Examples:
-        `pip install llama-index-llms-ollama`
-
-        ```python
-        from llama_index.llms.ollama import Ollama
-
-        llm = Ollama(model="llama2", request_timeout=60.0)
-
-        response = llm.complete("What is the capital of France?")
-        print(response)
-        ```
-    """
-
+import json
+from typing import Any, Dict, Sequence, Tuple
+
+import httpx
+from httpx import Timeout
+
+from llama_index.legacy.bridge.pydantic import Field
+from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
+from llama_index.legacy.core.llms.types import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseGen,
+    CompletionResponse,
+    CompletionResponseGen,
+    LLMMetadata,
+    MessageRole,
+)
+from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
+from llama_index.legacy.llms.custom import CustomLLM
+
+DEFAULT_REQUEST_TIMEOUT = 30.0
+
+
+def get_addtional_kwargs(
+    response: Dict[str, Any], exclude: Tuple[str, ...]
+) -> Dict[str, Any]:
+    return {k: v for k, v in response.items() if k not in exclude}
+
+
+class Ollama(CustomLLM):
     base_url: str = Field(
         default="http://localhost:11434",
         description="Base url the model is hosted under.",
@@ -43,10 +51,6 @@ class Ollama(CustomLLM):
     prompt_key: str = Field(
         default="prompt", description="The key to use for the prompt in API calls."
     )
-    json_mode: bool = Field(
-        default=False,
-        description="Whether to use JSON mode for the Ollama API.",
-    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional model parameters for the Ollama API.",
@@ -94,9 +98,6 @@ class Ollama(CustomLLM):
             **kwargs,
         }

-        if self.json_mode:
-            payload["format"] = "json"
-
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             response = client.post(
                 url=f"{self.base_url}/api/chat",
@@ -109,12 +110,12 @@ class Ollama(CustomLLM):
             message=ChatMessage(
                 content=message.get("content"),
                 role=MessageRole(message.get("role")),
-                additional_kwargs=get_additional_kwargs(
+                additional_kwargs=get_addtional_kwargs(
                     message, ("content", "role")
                 ),
             ),
             raw=raw,
-            additional_kwargs=get_additional_kwargs(raw, ("message",)),
+            additional_kwargs=get_addtional_kwargs(raw, ("message",)),
         )

     @llm_chat_callback()
@@ -136,9 +137,6 @@ class Ollama(CustomLLM):
             **kwargs,
         }

-        if self.json_mode:
-            payload["format"] = "json"
-
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             with client.stream(
                 method="POST",
@@ -159,59 +157,15 @@ class Ollama(CustomLLM):
                             message=ChatMessage(
                                 content=text,
                                 role=MessageRole(message.get("role")),
-                                additional_kwargs=get_additional_kwargs(
+                                additional_kwargs=get_addtional_kwargs(
                                     message, ("content", "role")
                                 ),
                             ),
                             delta=delta,
                             raw=chunk,
-                            additional_kwargs=get_additional_kwargs(
-                                chunk, ("message",)
-                            ),
+                            additional_kwargs=get_addtional_kwargs(chunk, ("message",)),
                         )

-    @llm_chat_callback()
-    async def achat(
-        self, messages: Sequence[ChatMessage], **kwargs: Any
-    ) -> ChatResponseAsyncGen:
-        payload = {
-            "model": self.model,
-            "messages": [
-                {
-                    "role": message.role.value,
-                    "content": message.content,
-                    **message.additional_kwargs,
-                }
-                for message in messages
-            ],
-            "options": self._model_kwargs,
-            "stream": False,
-            **kwargs,
-        }
-
-        if self.json_mode:
-            payload["format"] = "json"
-
-        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
-            response = await client.post(
-                url=f"{self.base_url}/api/chat",
-                json=payload,
-            )
-            response.raise_for_status()
-            raw = response.json()
-            message = raw["message"]
-            return ChatResponse(
-                message=ChatMessage(
-                    content=message.get("content"),
-                    role=MessageRole(message.get("role")),
-                    additional_kwargs=get_additional_kwargs(
-                        message, ("content", "role")
-                    ),
-                ),
-                raw=raw,
-                additional_kwargs=get_additional_kwargs(raw, ("message",)),
-            )
-
     @llm_completion_callback()
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -224,9 +178,6 @@ class Ollama(CustomLLM):
             **kwargs,
         }

-        if self.json_mode:
-            payload["format"] = "json"
-
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             response = client.post(
                 url=f"{self.base_url}/api/generate",
@@ -238,36 +189,7 @@ class Ollama(CustomLLM):
         return CompletionResponse(
             text=text,
             raw=raw,
-            additional_kwargs=get_additional_kwargs(raw, ("response",)),
-        )
-
-    @llm_completion_callback()
-    async def acomplete(
-        self, prompt: str, formatted: bool = False, **kwargs: Any
-    ) -> CompletionResponse:
-        payload = {
-            self.prompt_key: prompt,
-            "model": self.model,
-            "options": self._model_kwargs,
-            "stream": False,
-            **kwargs,
-        }
-
-        if self.json_mode:
-            payload["format"] = "json"
-
-        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
-            response = await client.post(
-                url=f"{self.base_url}/api/generate",
-                json=payload,
-            )
-            response.raise_for_status()
-            raw = response.json()
-            text = raw.get("response")
-            return CompletionResponse(
-                text=text,
-                raw=raw,
-                additional_kwargs=get_additional_kwargs(raw, ("response",)),
+            additional_kwargs=get_addtional_kwargs(raw, ("response",)),
         )

     @llm_completion_callback()
@@ -282,9 +204,6 @@ class Ollama(CustomLLM):
             **kwargs,
         }

-        if self.json_mode:
-            payload["format"] = "json"
-
         with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
             with client.stream(
                 method="POST",
@@ -302,46 +221,7 @@ class Ollama(CustomLLM):
                             delta=delta,
                             text=text,
                             raw=chunk,
-                            additional_kwargs=get_additional_kwargs(
+                            additional_kwargs=get_addtional_kwargs(
                                 chunk, ("response",)
                             ),
-                        )
-
-    @llm_completion_callback()
-    async def astream_complete(
-        self, prompt: str, formatted: bool = False, **kwargs: Any
-    ) -> CompletionResponseAsyncGen:
-        payload = {
-            self.prompt_key: prompt,
-            "model": self.model,
-            "options": self._model_kwargs,
-            "stream": True,
-            **kwargs,
-        }
-
-        if self.json_mode:
-            payload["format"] = "json"
-
-        async def gen() -> CompletionResponseAsyncGen:
-            async with httpx.AsyncClient(
-                timeout=Timeout(self.request_timeout)
-            ) as client:
-                async with client.stream(
-                    method="POST",
-                    url=f"{self.base_url}/api/generate",
-                    json=payload,
-                ) as response:
-                    async for line in response.aiter_lines():
-                        if line:
-                            chunk = json.loads(line)
-                            delta = chunk.get("response")
-                            yield CompletionResponse(
-                                delta=delta,
-                                text=delta,
-                                raw=chunk,
-                                additional_kwargs=get_additional_kwargs(
-                                    chunk, ("response",)
-                                ),
-                            )
-
-        return gen()
+                        )
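
After this change the class is synchronous-only: `achat`, `acomplete`, `astream_complete`, and the `json_mode` field are gone. A minimal usage sketch of the trimmed class, assuming the file is importable as `myollama`, that `ollama serve` is running locally with a pulled `llama2` model (as in the removed docstring), and that a per-call `format="json"` kwarg reaches the Ollama API only because the request payloads spread `**kwargs`:

```python
# Sketch only: the `myollama` module name and `llama2` model tag are
# assumptions, not fixed by this commit.
from myollama import Ollama

llm = Ollama(model="llama2", request_timeout=60.0)

# Synchronous completion (the async variants were removed in this commit).
response = llm.complete("What is the capital of France?")
print(response)

# Streaming completion: each chunk carries the incremental `delta`.
for chunk in llm.stream_complete("Name three French cities."):
    print(chunk.delta, end="", flush=True)

# json_mode is gone; passing format="json" per call appears to have the same
# effect, since the payloads spread **kwargs (assumption, not verified here).
print(llm.complete("Return a JSON object with keys 'city' and 'country'.", format="json"))
```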