nsarrazin (HF staff) committed
Commit: 14f0244
Parent: cd5cd0c

Add ollama endpoint support (#569)


* Add ollama endpoint support
* Replace the if/else chain with a switch statement
* Add an Ollama example to the docs

README.md CHANGED
@@ -313,6 +313,41 @@ MODELS=[
 
 Start chat-ui with `npm run dev` and you should be able to chat with Zephyr locally.
 
+#### Ollama
+
+We also support the Ollama inference server. Spin up a model with
+
+```cli
+ollama run mistral
+```
+
+Then specify the endpoints like so:
+
+```env
+MODELS=[
+  {
+    "name": "Ollama Mistral",
+    "chatPromptTemplate": "<s>{{#each messages}}{{#ifUser}}[INST] {{#if @first}}{{#if @root.preprompt}}{{@root.preprompt}}\n{{/if}}{{/if}} {{content}} [/INST]{{/ifUser}}{{#ifAssistant}}{{content}}</s> {{/ifAssistant}}{{/each}}",
+    "parameters": {
+      "temperature": 0.1,
+      "top_p": 0.95,
+      "repetition_penalty": 1.2,
+      "top_k": 50,
+      "truncate": 3072,
+      "max_new_tokens": 1024,
+      "stop": ["</s>"]
+    },
+    "endpoints": [
+      {
+        "type": "ollama",
+        "url" : "http://127.0.0.1:11434",
+        "ollamaName" : "mistral"
+      }
+    ]
+  }
+]
+```
+
 #### Amazon
 
 You can also specify your Amazon SageMaker instance as an endpoint for chat-ui. The config goes like this:
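Before starting chat-ui against this config, it can help to confirm that the local Ollama server is reachable and that the model has been pulled. The snippet below is a minimal sketch, not part of this diff; it assumes Node 18+ (built-in `fetch`), the default Ollama port used above, and Ollama's `/api/tags` endpoint, which lists locally available models.

```ts
// Sketch only (not part of this commit): check that the Ollama server
// responds and that "mistral" has been pulled before pointing chat-ui at it.
const OLLAMA_URL = "http://127.0.0.1:11434";

async function checkOllama(modelName: string): Promise<void> {
  // /api/tags returns the models available in the local Ollama instance
  const res = await fetch(`${OLLAMA_URL}/api/tags`);
  if (!res.ok) {
    throw new Error(`Ollama server not reachable: ${res.status}`);
  }
  const { models } = (await res.json()) as { models: { name: string }[] };
  const found = models.some((m) => m.name.startsWith(modelName));
  console.log(found ? `${modelName} is ready` : `run \`ollama pull ${modelName}\` first`);
}

checkOllama("mistral").catch(console.error);
```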
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -5,6 +5,7 @@ import { z } from "zod";
 import endpointAws, { endpointAwsParametersSchema } from "./aws/endpointAws";
 import { endpointOAIParametersSchema, endpointOai } from "./openai/endpointOai";
 import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/endpointLlamacpp";
+import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
 
 // parameters passed when generating text
 interface EndpointParameters {
@@ -32,6 +33,7 @@ export const endpoints = {
   aws: endpointAws,
   openai: endpointOai,
   llamacpp: endpointLlamacpp,
+  ollama: endpointOllama,
 };
 
 export const endpointSchema = z.discriminatedUnion("type", [
@@ -39,5 +41,6 @@ export const endpointSchema = z.discriminatedUnion("type", [
   endpointOAIParametersSchema,
   endpointTgiParametersSchema,
   endpointLlamacppParametersSchema,
+  endpointOllamaParametersSchema,
 ]);
 export default endpoints;
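Because `endpointSchema` is a `z.discriminatedUnion` on `type`, each endpoint entry in `MODELS` is validated against the schema matching its `type` field, and defaults are filled in. The following is a standalone sketch of how the new `ollama` variant behaves; it mirrors `endpointOllamaParametersSchema` (added below) rather than importing it from chat-ui.

```ts
// Standalone sketch of the schema introduced in this commit (requires zod).
import { z } from "zod";

const ollamaSchema = z.object({
  weight: z.number().int().positive().default(1),
  model: z.any(),
  type: z.literal("ollama"),
  url: z.string().url().default("http://127.0.0.1:11434"),
  ollamaName: z.string().min(1).optional(),
});

// Parsing a minimal MODELS entry fills in the defaults,
// so `url` and `weight` can be omitted in .env.local.
const parsed = ollamaSchema.parse({ type: "ollama", ollamaName: "mistral" });
console.log(parsed.url); // "http://127.0.0.1:11434"
console.log(parsed.weight); // 1
```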
src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts CHANGED
@@ -8,7 +8,7 @@ export const endpointLlamacppParametersSchema = z.object({
   weight: z.number().int().positive().default(1),
   model: z.any(),
   type: z.literal("llamacpp"),
-  url: z.string().url(),
+  url: z.string().url().default("http://127.0.0.1:8080"),
   accessToken: z.string().min(1).default(HF_ACCESS_TOKEN),
 });
 
src/lib/server/endpoints/ollama/endpointOllama.ts ADDED
@@ -0,0 +1,108 @@
+import { buildPrompt } from "$lib/buildPrompt";
+import type { TextGenerationStreamOutput } from "@huggingface/inference";
+import type { Endpoint } from "../endpoints";
+import { z } from "zod";
+
+export const endpointOllamaParametersSchema = z.object({
+  weight: z.number().int().positive().default(1),
+  model: z.any(),
+  type: z.literal("ollama"),
+  url: z.string().url().default("http://127.0.0.1:11434"),
+  ollamaName: z.string().min(1).optional(),
+});
+
+export function endpointOllama({
+  url,
+  model,
+  ollamaName,
+}: z.infer<typeof endpointOllamaParametersSchema>): Endpoint {
+  return async ({ conversation }) => {
+    const prompt = await buildPrompt({
+      messages: conversation.messages,
+      webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
+      preprompt: conversation.preprompt,
+      model,
+    });
+
+    const r = await fetch(`${url}/api/generate`, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+      },
+      body: JSON.stringify({
+        prompt,
+        model: ollamaName ?? model.name,
+        raw: true,
+        options: {
+          top_p: model.parameters.top_p,
+          top_k: model.parameters.top_k,
+          temperature: model.parameters.temperature,
+          repeat_penalty: model.parameters.repetition_penalty,
+          stop: model.parameters.stop,
+          num_predict: model.parameters.max_new_tokens,
+        },
+      }),
+    });
+
+    if (!r.ok) {
+      throw new Error(`Failed to generate text: ${await r.text()}`);
+    }
+
+    const encoder = new TextDecoderStream();
+    const reader = r.body?.pipeThrough(encoder).getReader();
+
+    return (async function* () {
+      let generatedText = "";
+      let tokenId = 0;
+      let stop = false;
+      while (!stop) {
+        // read the stream and log the outputs to console
+        const out = (await reader?.read()) ?? { done: false, value: undefined };
+        // we read, if it's done we cancel
+        if (out.done) {
+          reader?.cancel();
+          return;
+        }
+
+        if (!out.value) {
+          return;
+        }
+
+        let data = null;
+        try {
+          data = JSON.parse(out.value);
+        } catch (e) {
+          return;
+        }
+        if (!data.done) {
+          generatedText += data.response;
+
+          yield {
+            token: {
+              id: tokenId++,
+              text: data.response ?? "",
+              logprob: 0,
+              special: false,
+            },
+            generated_text: null,
+            details: null,
+          } satisfies TextGenerationStreamOutput;
+        } else {
+          stop = true;
+          yield {
+            token: {
+              id: tokenId++,
+              text: data.response ?? "",
+              logprob: 0,
+              special: true,
+            },
+            generated_text: generatedText,
+            details: null,
+          } satisfies TextGenerationStreamOutput;
+        }
+      }
+    })();
+  };
+}
+
+export default endpointOllama;
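The generator above consumes Ollama's streaming response: `/api/generate` returns newline-delimited JSON objects, each carrying a `response` fragment and a `done` flag, and `raw: true` passes the already-templated prompt through without Ollama applying its own template. Below is a stripped-down, standalone sketch of the same pattern, assuming Node 18+ and a reachable Ollama server; like the endpoint code, it treats each decoded chunk as whole JSON lines rather than buffering partial objects.

```ts
// Sketch of the streaming pattern used by endpointOllama, outside chat-ui.
async function generate(prompt: string, model = "mistral"): Promise<string> {
  const res = await fetch("http://127.0.0.1:11434/api/generate", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ model, prompt, raw: true }),
  });
  if (!res.ok || !res.body) throw new Error(await res.text());

  const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
  let text = "";
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    // each streamed line is a JSON object: { response: "...", done: boolean, ... }
    for (const line of value.split("\n").filter(Boolean)) {
      const chunk = JSON.parse(line);
      if (!chunk.done) text += chunk.response;
    }
  }
  return text;
}

generate("[INST] Hello! [/INST]").then(console.log).catch(console.error);
```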
src/lib/server/models.ts CHANGED
@@ -48,7 +48,7 @@ const modelConfig = z.object({
   parameters: z
     .object({
       temperature: z.number().min(0).max(1),
-      truncate: z.number().int().positive(),
+      truncate: z.number().int().positive().optional(),
       max_new_tokens: z.number().int().positive(),
       stop: z.array(z.string()).optional(),
       top_p: z.number().positive().optional(),
@@ -92,17 +92,21 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
   for (const endpoint of m.endpoints) {
     if (random < endpoint.weight) {
       const args = { ...endpoint, model: m };
-      if (args.type === "tgi") {
-        return endpoints.tgi(args);
-      } else if (args.type === "aws") {
-        return await endpoints.aws(args);
-      } else if (args.type === "openai") {
-        return await endpoints.openai(args);
-      } else if (args.type === "llamacpp") {
-        return await endpoints.llamacpp(args);
-      } else {
-        // for legacy reason
-        return await endpoints.tgi(args);
+
+      switch (args.type) {
+        case "tgi":
+          return endpoints.tgi(args);
+        case "aws":
+          return await endpoints.aws(args);
+        case "openai":
+          return await endpoints.openai(args);
+        case "llamacpp":
+          return endpoints.llamacpp(args);
+        case "ollama":
+          return endpoints.ollama(args);
+        default:
+          // for legacy reason
+          return endpoints.tgi(args);
       }
     }
     random -= endpoint.weight;
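The surrounding loop, only partially visible in this hunk, picks one endpoint per request in proportion to each endpoint's `weight`. The sketch below illustrates that selection logic in isolation, under the assumption that `random` starts as a uniform draw over the total weight; it is not chat-ui code, and the endpoint shape is reduced to what the example needs.

```ts
// Sketch of weighted endpoint selection, matching the loop around the switch above.
interface WeightedEndpoint {
  type: "tgi" | "aws" | "openai" | "llamacpp" | "ollama";
  weight: number;
}

function pickEndpoint(endpoints: WeightedEndpoint[]): WeightedEndpoint {
  const totalWeight = endpoints.reduce((sum, e) => sum + e.weight, 0);
  let random = Math.random() * totalWeight;
  for (const endpoint of endpoints) {
    if (random < endpoint.weight) {
      return endpoint;
    }
    random -= endpoint.weight;
  }
  // floating-point edge case: fall back to the last endpoint
  return endpoints[endpoints.length - 1];
}

// e.g. the ollama endpoint is picked roughly twice as often as the tgi one
const chosen = pickEndpoint([
  { type: "tgi", weight: 1 },
  { type: "ollama", weight: 2 },
]);
console.log(chosen.type);
```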