Entz committed
Commit f69fed3
1 Parent(s): 3d53a45

Upload myollama.py

Files changed (1)
  1. myollama.py +347 -0
myollama.py ADDED
@@ -0,0 +1,347 @@
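Note: the uploaded file begins at the class definition, so the imports and the small `get_additional_kwargs` helper the class relies on are not part of this commit. A minimal preamble inferred from the names used below (an assumption, not content of the diff) might look like:

```python
import json
from typing import Any, Dict, Sequence, Tuple

import httpx
from httpx import Timeout

from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.core.bridge.pydantic import Field
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
from llama_index.core.llms.custom import CustomLLM

# Assumed default; the uploaded snippet references DEFAULT_REQUEST_TIMEOUT without defining it.
DEFAULT_REQUEST_TIMEOUT = 30.0


def get_additional_kwargs(
    response: Dict[str, Any], exclude: Tuple[str, ...]
) -> Dict[str, Any]:
    """Return every key of an Ollama response except the ones already consumed."""
    return {k: v for k, v in response.items() if k not in exclude}
```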
class Ollama(CustomLLM):
    """Ollama LLM.

    Visit https://ollama.com/ to download and install Ollama.

    Run `ollama serve` to start a server.

    Run `ollama pull <name>` to download a model to run.

    Examples:
        `pip install llama-index-llms-ollama`

        ```python
        from llama_index.llms.ollama import Ollama

        llm = Ollama(model="llama2", request_timeout=60.0)

        response = llm.complete("What is the capital of France?")
        print(response)
        ```
    """

    base_url: str = Field(
        default="http://localhost:11434",
        description="Base url the model is hosted under.",
    )
    model: str = Field(description="The Ollama model to use.")
    temperature: float = Field(
        default=0.75,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of context tokens for the model.",
        gt=0,
    )
    request_timeout: float = Field(
        default=DEFAULT_REQUEST_TIMEOUT,
        description="The timeout for making http request to Ollama API server",
    )
    prompt_key: str = Field(
        default="prompt", description="The key to use for the prompt in API calls."
    )
    json_mode: bool = Field(
        default=False,
        description="Whether to use JSON mode for the Ollama API.",
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional model parameters for the Ollama API.",
    )

    @classmethod
    def class_name(cls) -> str:
        return "Ollama_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=DEFAULT_NUM_OUTPUTS,
            model_name=self.model,
            is_chat_model=True,  # Ollama supports chat API for all models
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "temperature": self.temperature,
            "num_ctx": self.context_window,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": message.role.value,
                    "content": message.content,
                    **message.additional_kwargs,
                }
                for message in messages
            ],
            "options": self._model_kwargs,
            "stream": False,
            **kwargs,
        }

        if self.json_mode:
            payload["format"] = "json"

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            response = client.post(
                url=f"{self.base_url}/api/chat",
                json=payload,
            )
            response.raise_for_status()
            raw = response.json()
            message = raw["message"]
            return ChatResponse(
                message=ChatMessage(
                    content=message.get("content"),
                    role=MessageRole(message.get("role")),
                    additional_kwargs=get_additional_kwargs(
                        message, ("content", "role")
                    ),
                ),
                raw=raw,
                additional_kwargs=get_additional_kwargs(raw, ("message",)),
            )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": message.role.value,
                    "content": message.content,
                    **message.additional_kwargs,
                }
                for message in messages
            ],
            "options": self._model_kwargs,
            "stream": True,
            **kwargs,
        }

        if self.json_mode:
            payload["format"] = "json"

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            with client.stream(
                method="POST",
                url=f"{self.base_url}/api/chat",
                json=payload,
            ) as response:
                response.raise_for_status()
                text = ""
                for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        if "done" in chunk and chunk["done"]:
                            break
                        message = chunk["message"]
                        delta = message.get("content")
                        text += delta
                        yield ChatResponse(
                            message=ChatMessage(
                                content=text,
                                role=MessageRole(message.get("role")),
                                additional_kwargs=get_additional_kwargs(
                                    message, ("content", "role")
                                ),
                            ),
                            delta=delta,
                            raw=chunk,
                            additional_kwargs=get_additional_kwargs(
                                chunk, ("message",)
                            ),
                        )

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": message.role.value,
                    "content": message.content,
                    **message.additional_kwargs,
                }
                for message in messages
            ],
            "options": self._model_kwargs,
            "stream": False,
            **kwargs,
        }

        if self.json_mode:
            payload["format"] = "json"

        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
            response = await client.post(
                url=f"{self.base_url}/api/chat",
                json=payload,
            )
            response.raise_for_status()
            raw = response.json()
            message = raw["message"]
            return ChatResponse(
                message=ChatMessage(
                    content=message.get("content"),
                    role=MessageRole(message.get("role")),
                    additional_kwargs=get_additional_kwargs(
                        message, ("content", "role")
                    ),
                ),
                raw=raw,
                additional_kwargs=get_additional_kwargs(raw, ("message",)),
            )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        payload = {
            self.prompt_key: prompt,
            "model": self.model,
            "options": self._model_kwargs,
            "stream": False,
            **kwargs,
        }

        if self.json_mode:
            payload["format"] = "json"

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            response = client.post(
                url=f"{self.base_url}/api/generate",
                json=payload,
            )
            response.raise_for_status()
            raw = response.json()
            text = raw.get("response")
            return CompletionResponse(
                text=text,
                raw=raw,
                additional_kwargs=get_additional_kwargs(raw, ("response",)),
            )

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        payload = {
            self.prompt_key: prompt,
            "model": self.model,
            "options": self._model_kwargs,
            "stream": False,
            **kwargs,
        }

        if self.json_mode:
            payload["format"] = "json"

        async with httpx.AsyncClient(timeout=Timeout(self.request_timeout)) as client:
            response = await client.post(
                url=f"{self.base_url}/api/generate",
                json=payload,
            )
            response.raise_for_status()
            raw = response.json()
            text = raw.get("response")
            return CompletionResponse(
                text=text,
                raw=raw,
                additional_kwargs=get_additional_kwargs(raw, ("response",)),
            )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        payload = {
            self.prompt_key: prompt,
            "model": self.model,
            "options": self._model_kwargs,
            "stream": True,
            **kwargs,
        }

        if self.json_mode:
            payload["format"] = "json"

        with httpx.Client(timeout=Timeout(self.request_timeout)) as client:
            with client.stream(
                method="POST",
                url=f"{self.base_url}/api/generate",
                json=payload,
            ) as response:
                response.raise_for_status()
                text = ""
                for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        delta = chunk.get("response")
                        text += delta
                        yield CompletionResponse(
                            delta=delta,
                            text=text,
                            raw=chunk,
                            additional_kwargs=get_additional_kwargs(
                                chunk, ("response",)
                            ),
                        )

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        payload = {
            self.prompt_key: prompt,
            "model": self.model,
            "options": self._model_kwargs,
            "stream": True,
            **kwargs,
        }

        if self.json_mode:
            payload["format"] = "json"

        async def gen() -> CompletionResponseAsyncGen:
            async with httpx.AsyncClient(
                timeout=Timeout(self.request_timeout)
            ) as client:
                async with client.stream(
                    method="POST",
                    url=f"{self.base_url}/api/generate",
                    json=payload,
                ) as response:
                    # Mirror stream_complete: surface HTTP errors and keep a running text.
                    response.raise_for_status()
                    text = ""
                    async for line in response.aiter_lines():
                        if line:
                            chunk = json.loads(line)
                            delta = chunk.get("response")
                            text += delta
                            yield CompletionResponse(
                                delta=delta,
                                text=text,
                                raw=chunk,
                                additional_kwargs=get_additional_kwargs(
                                    chunk, ("response",)
                                ),
                            )

        return gen()
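A quick usage sketch for the class above, assuming a local `ollama serve` instance with the `llama2` model already pulled and the file importable as `myollama` (both assumptions, not part of the commit):

```python
from myollama import Ollama
from llama_index.core.base.llms.types import ChatMessage, MessageRole

llm = Ollama(model="llama2", request_timeout=120.0)

# Single-turn chat against the /api/chat endpoint
response = llm.chat(
    [ChatMessage(role=MessageRole.USER, content="Name three French cheeses.")]
)
print(response.message.content)

# Token-by-token streaming against the /api/generate endpoint
for chunk in llm.stream_complete("Write a haiku about llamas."):
    print(chunk.delta, end="", flush=True)
```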