nsarrazin committed
Commit 0819256
Parent: e027921

Use jinja template for chat formatting (#730) (#744)


* Use jinja template for chat formatting

* Add support for transformers js chat template

* update to latest transformers version

* Make sure to `add_generation_prompt`

* unindent
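
In short, the change replaces hand-maintained Handlebars prompt templates with the Jinja chat template that ships in a model's tokenizer config. A minimal sketch of the idea, assuming an ESM context with top-level await (the model id is only an example):

import { AutoTokenizer } from "@xenova/transformers";

// Any Hub model whose tokenizer_config.json defines a chat template works here.
const tokenizer = await AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2");

// Render the conversation with the model's own Jinja template;
// `add_generation_prompt: true` appends the tokens that cue the assistant's next turn.
const prompt = tokenizer.apply_chat_template(
	[
		{ role: "user", content: "Hello!" },
		{ role: "assistant", content: "Hi! How can I help?" },
		{ role: "user", content: "Tell me a joke." },
	],
	{ tokenize: false, add_generation_prompt: true }
);

console.log(prompt); // a single prompt string in the model's native chat format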

.env.template CHANGED
@@ -64,7 +64,7 @@ MODELS=`[
     "description": "The latest and biggest model from Meta, fine-tuned for chat.",
     "logoUrl": "https://huggingface.co/datasets/huggingchat/models-logo/resolve/main/meta-logo.png",
     "websiteUrl": "https://ai.meta.com/llama/",
-    "preprompt": " ",
+    "preprompt": "",
     "chatPromptTemplate" : "<s>[INST] <<SYS>>\n{{preprompt}}\n<</SYS>>\n\n{{#each messages}}{{#ifUser}}{{content}} [/INST] {{/ifUser}}{{#ifAssistant}}{{content}} </s><s>[INST] {{/ifAssistant}}{{/each}}",
     "promptExamples": [
       {
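
Note that chatPromptTemplate is a Handlebars template using chat-ui's custom ifUser/ifAssistant block helpers (defined in src/lib/utils/template.ts below). A self-contained sketch of how it expands, with simplified helpers re-registered inline:

import Handlebars from "handlebars";

// Simplified re-registration of chat-ui's block helpers (see src/lib/utils/template.ts).
Handlebars.registerHelper("ifUser", function (this: { from: string }, options: Handlebars.HelperOptions) {
	return this.from === "user" ? options.fn(this) : "";
});
Handlebars.registerHelper("ifAssistant", function (this: { from: string }, options: Handlebars.HelperOptions) {
	return this.from === "assistant" ? options.fn(this) : "";
});

const render = Handlebars.compile(
	"<s>[INST] <<SYS>>\n{{preprompt}}\n<</SYS>>\n\n" +
		"{{#each messages}}{{#ifUser}}{{content}} [/INST] {{/ifUser}}" +
		"{{#ifAssistant}}{{content}} </s><s>[INST] {{/ifAssistant}}{{/each}}",
	{ noEscape: true }
);

console.log(render({ preprompt: "", messages: [{ from: "user", content: "Hello!" }] }));
// <s>[INST] <<SYS>>
//
// <</SYS>>
//
// Hello! [/INST]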
package-lock.json CHANGED
@@ -12,7 +12,7 @@
         "@huggingface/inference": "^2.6.3",
         "@iconify-json/bi": "^1.1.21",
         "@resvg/resvg-js": "^2.6.0",
-        "@xenova/transformers": "^2.6.0",
+        "@xenova/transformers": "^2.16.1",
         "autoprefixer": "^10.4.14",
         "browser-image-resizer": "^2.4.1",
         "date-fns": "^2.29.3",
@@ -660,6 +660,14 @@
         "node": ">=18"
       }
     },
+    "node_modules/@huggingface/jinja": {
+      "version": "0.2.2",
+      "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.2.tgz",
+      "integrity": "sha512-/KPde26khDUIPkTGU82jdtTW9UAuvUTumCAbFs/7giR0SxsvZC4hru51PBvpijH6BVkHcROcvZM/lpy5h1jRRA==",
+      "engines": {
+        "node": ">=18"
+      }
+    },
     "node_modules/@humanwhocodes/config-array": {
       "version": "0.11.8",
       "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz",
@@ -2407,10 +2415,11 @@
       }
     },
     "node_modules/@xenova/transformers": {
-      "version": "2.6.0",
-      "resolved": "https://registry.npmjs.org/@xenova/transformers/-/transformers-2.6.0.tgz",
-      "integrity": "sha512-k9bs+reiwhn+kx0d4FYnlBTWtl8D5Q4fIzoKYxKbTTSVyS33KXbQESRpdIxiU9gtlMKML2Sw0Oep4FYK9dQCsQ==",
+      "version": "2.16.1",
+      "resolved": "https://registry.npmjs.org/@xenova/transformers/-/transformers-2.16.1.tgz",
+      "integrity": "sha512-p2ii7v7oC3Se0PC012dn4vt196GCroaN5ngOYJYkfg0/ce8A5frsrnnnktOBJuejG3bW5Hreb7JZ/KxtUaKd8w==",
       "dependencies": {
+        "@huggingface/jinja": "^0.2.2",
         "onnxruntime-web": "1.14.0",
         "sharp": "^0.32.0"
       },
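
The new @huggingface/jinja entry is the minimal Jinja engine that @xenova/transformers uses under the hood to evaluate chat templates. Roughly, its standalone API (the template string here is a toy example, not a real model's):

import { Template } from "@huggingface/jinja";

// Real chat templates ship in a model's tokenizer_config.json.
const template = new Template(
	"{% for message in messages %}[{{ message.role }}]: {{ message.content }}\n{% endfor %}"
);

const rendered = template.render({
	messages: [
		{ role: "system", content: "Be concise." },
		{ role: "user", content: "Hi!" },
	],
});
// "[system]: Be concise.\n[user]: Hi!\n"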
package.json CHANGED
@@ -54,7 +54,7 @@
     "@huggingface/inference": "^2.6.3",
     "@iconify-json/bi": "^1.1.21",
     "@resvg/resvg-js": "^2.6.0",
-    "@xenova/transformers": "^2.6.0",
+    "@xenova/transformers": "^2.16.1",
     "autoprefixer": "^10.4.14",
     "browser-image-resizer": "^2.4.1",
     "date-fns": "^2.29.3",
@@ -83,8 +83,8 @@
   },
   "optionalDependencies": {
     "@anthropic-ai/sdk": "^0.17.1",
+    "@google-cloud/vertexai": "^0.5.0",
     "aws4fetch": "^1.0.17",
-    "openai": "^4.14.2",
-    "@google-cloud/vertexai": "^0.5.0"
+    "openai": "^4.14.2"
   }
 }
src/lib/components/TokensCounter.svelte CHANGED
@@ -1,6 +1,7 @@
 <script lang="ts">
 	import type { Model } from "$lib/types/Model";
-	import { AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
+	import { getTokenizer } from "$lib/utils/getTokenizer";
+	import type { PreTrainedTokenizer } from "@xenova/transformers";

 	export let classNames = "";
 	export let prompt = "";
@@ -9,23 +10,6 @@

 	let tokenizer: PreTrainedTokenizer | undefined = undefined;

-	async function getTokenizer(_modelTokenizer: Exclude<Model["tokenizer"], undefined>) {
-		if (typeof _modelTokenizer === "string") {
-			// return auto tokenizer
-			return await AutoTokenizer.from_pretrained(_modelTokenizer);
-		}
-		{
-			// construct & return pretrained tokenizer
-			const { tokenizerUrl, tokenizerConfigUrl } = _modelTokenizer satisfies {
-				tokenizerUrl: string;
-				tokenizerConfigUrl: string;
-			};
-			const tokenizerJSON = await (await fetch(tokenizerUrl)).json();
-			const tokenizerConfig = await (await fetch(tokenizerConfigUrl)).json();
-			return new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);
-		}
-	}
-
 	async function tokenizeText(_prompt: string) {
 		if (!tokenizer) {
 			return;
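
The removed helper now lives in src/lib/utils/getTokenizer.ts (added below) so that models.ts can reuse it. Counting tokens stays a matter of encoding the prompt; a sketch, with an example model id:

import { AutoTokenizer } from "@xenova/transformers";

// The component loads whatever the active model's `tokenizer` field names.
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");

// Token count = length of the encoded id sequence.
const tokens = tokenizer.encode("Hello world, how are you?");
console.log(tokens.length);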
src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts CHANGED
@@ -1,6 +1,6 @@
 import { z } from "zod";
 import type { EmbeddingEndpoint } from "../embeddingEndpoints";
-import type { Tensor, Pipeline } from "@xenova/transformers";
+import type { Tensor, FeatureExtractionPipeline } from "@xenova/transformers";
 import { pipeline } from "@xenova/transformers";

 export const embeddingEndpointTransformersJSParametersSchema = z.object({
@@ -11,9 +11,9 @@ export const embeddingEndpointTransformersJSParametersSchema = z.object({

 // Use the Singleton pattern to enable lazy construction of the pipeline.
 class TransformersJSModelsSingleton {
-	static instances: Array<[string, Promise<Pipeline>]> = [];
+	static instances: Array<[string, Promise<FeatureExtractionPipeline>]> = [];

-	static async getInstance(modelName: string): Promise<Pipeline> {
+	static async getInstance(modelName: string): Promise<FeatureExtractionPipeline> {
 		const modelPipelineInstance = this.instances.find(([name]) => name === modelName);

 		if (modelPipelineInstance) {
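
The switch from the generic Pipeline type to FeatureExtractionPipeline matches what pipeline("feature-extraction", ...) resolves to in recent transformers.js versions; runtime behavior is unchanged. A usage sketch (the model id is an example):

import { pipeline } from "@xenova/transformers";

// "feature-extraction" yields a FeatureExtractionPipeline.
const extractor = await pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2");

// Mean-pooled, L2-normalized sentence embeddings, returned as a Tensor.
const output = await extractor(["Hello world", "How are you?"], {
	pooling: "mean",
	normalize: true,
});
console.log(output.dims); // e.g. [2, 384] for this model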
src/lib/server/models.ts CHANGED
@@ -14,7 +14,10 @@ import endpointTgi from "./endpoints/tgi/endpointTgi";
 import { sum } from "$lib/utils/sum";
 import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";

+import type { PreTrainedTokenizer } from "@xenova/transformers";
+
 import JSON5 from "json5";
+import { getTokenizer } from "$lib/utils/getTokenizer";

 type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;

@@ -39,23 +42,9 @@ const modelConfig = z.object({
 		.optional(),
 	datasetName: z.string().min(1).optional(),
 	datasetUrl: z.string().url().optional(),
-	userMessageToken: z.string().default(""),
-	userMessageEndToken: z.string().default(""),
-	assistantMessageToken: z.string().default(""),
-	assistantMessageEndToken: z.string().default(""),
-	messageEndToken: z.string().default(""),
 	preprompt: z.string().default(""),
 	prepromptUrl: z.string().url().optional(),
-	chatPromptTemplate: z
-		.string()
-		.default(
-			"{{preprompt}}" +
-				"{{#each messages}}" +
-				"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
-				"{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
-				"{{/each}}" +
-				"{{assistantMessageToken}}"
-		),
+	chatPromptTemplate: z.string().optional(),
 	promptExamples: z
 		.array(
 			z.object({

@@ -84,11 +73,64 @@

 const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));

+async function getChatPromptRender(
+	m: z.infer<typeof modelConfig>
+): Promise<ReturnType<typeof compileTemplate<ChatTemplateInput>>> {
+	if (m.chatPromptTemplate) {
+		return compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m);
+	}
+	let tokenizer: PreTrainedTokenizer;
+
+	if (!m.tokenizer) {
+		throw new Error(
+			"No tokenizer specified and no chat prompt template specified for model " + m.name
+		);
+	}
+
+	try {
+		tokenizer = await getTokenizer(m.tokenizer);
+	} catch (e) {
+		throw Error(
+			"Failed to load tokenizer for model " +
+				m.name +
+				" consider setting chatPromptTemplate manually or making sure the model is available on the hub."
+		);
+	}
+
+	const renderTemplate = ({ messages, preprompt }: ChatTemplateInput) => {
+		let formattedMessages: { role: string; content: string }[] = messages.map((message) => ({
+			content: message.content,
+			role: message.from,
+		}));
+
+		if (preprompt) {
+			formattedMessages = [
+				{
+					role: "system",
+					content: preprompt,
+				},
+				...formattedMessages,
+			];
+		}
+
+		const output = tokenizer.apply_chat_template(formattedMessages, {
+			tokenize: false,
+			add_generation_prompt: true,
+		});
+
+		if (typeof output !== "string") {
+			throw new Error("Failed to apply chat template, the output is not a string");
+		}
+
+		return output;
+	};
+
+	return renderTemplate;
+}
+
 const processModel = async (m: z.infer<typeof modelConfig>) => ({
 	...m,
-	userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
-	assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
-	chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
+	chatPromptRender: await getChatPromptRender(m),
 	id: m.id || m.name,
 	displayName: m.displayName || m.name,
 	preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
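
So a model entry that sets tokenizer but omits chatPromptTemplate gets its preprompt folded in as a leading system message, and the tokenizer's template does the rest. A sketch of that path outside the helper (example model id; output shown for illustration):

import { AutoTokenizer } from "@xenova/transformers";

// Zephyr is an example of a model shipping a chat template in tokenizer_config.json.
const tokenizer = await AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta");

const preprompt = "You are a friendly assistant.";
const messages = [{ from: "user", content: "Hello!" }];

// Same mapping as renderTemplate above: preprompt -> system message, from -> role.
const prompt = tokenizer.apply_chat_template(
	[
		{ role: "system", content: preprompt },
		...messages.map((m) => ({ role: m.from, content: m.content })),
	],
	{ tokenize: false, add_generation_prompt: true }
);
// e.g. "<|system|>\nYou are a friendly assistant.</s>\n<|user|>\nHello!</s>\n<|assistant|>\n"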
src/lib/types/Template.ts CHANGED
@@ -1,13 +1,5 @@
 import type { Message } from "./Message";

-export type LegacyParamatersTemplateInput = {
-	preprompt?: string;
-	userMessageToken: string;
-	userMessageEndToken: string;
-	assistantMessageToken: string;
-	assistantMessageEndToken: string;
-};
-
 export type ChatTemplateInput = {
 	messages: Pick<Message, "from" | "content">[];
 	preprompt?: string;
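
Every renderer, Handlebars-based or tokenizer-based, now takes a value of this shape inside the chat-ui codebase, e.g.:

import type { ChatTemplateInput } from "$lib/types/Template";

const input: ChatTemplateInput = {
	preprompt: "You are a helpful assistant.",
	messages: [
		{ from: "user", content: "Hello!" },
		{ from: "assistant", content: "Hi! How can I help?" },
	],
};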
src/lib/utils/getTokenizer.ts ADDED
@@ -0,0 +1,18 @@
+import type { Model } from "$lib/types/Model";
+import { AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
+
+export async function getTokenizer(_modelTokenizer: Exclude<Model["tokenizer"], undefined>) {
+	if (typeof _modelTokenizer === "string") {
+		// return auto tokenizer
+		return await AutoTokenizer.from_pretrained(_modelTokenizer);
+	} else {
+		// construct & return pretrained tokenizer
+		const { tokenizerUrl, tokenizerConfigUrl } = _modelTokenizer satisfies {
+			tokenizerUrl: string;
+			tokenizerConfigUrl: string;
+		};
+		const tokenizerJSON = await (await fetch(tokenizerUrl)).json();
+		const tokenizerConfig = await (await fetch(tokenizerConfigUrl)).json();
+		return new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);
+	}
+}
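
A usage sketch covering both call forms the helper accepts (the URLs are placeholders):

import { getTokenizer } from "$lib/utils/getTokenizer";

// String form: resolved from the Hub via AutoTokenizer.
const hubTokenizer = await getTokenizer("Xenova/gpt2");

// Object form: explicit tokenizer.json / tokenizer_config.json URLs.
const customTokenizer = await getTokenizer({
	tokenizerUrl: "https://example.com/tokenizer.json",
	tokenizerConfigUrl: "https://example.com/tokenizer_config.json",
});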
src/lib/utils/template.ts CHANGED
@@ -1,5 +1,4 @@
 import type { Message } from "$lib/types/Message";
-import type { LegacyParamatersTemplateInput } from "$lib/types/Template";
 import Handlebars from "handlebars";

 Handlebars.registerHelper("ifUser", function (this: Pick<Message, "from" | "content">, options) {
@@ -13,8 +12,8 @@ Handlebars.registerHelper(
 	}
 );

-export function compileTemplate<T>(input: string, model: LegacyParamatersTemplateInput) {
-	const template = Handlebars.compile<T & LegacyParamatersTemplateInput>(input, {
+export function compileTemplate<T>(input: string, model: { preprompt: string }) {
+	const template = Handlebars.compile<T>(input, {
 		knownHelpers: { ifUser: true, ifAssistant: true },
 		knownHelpersOnly: true,
 		noEscape: true,
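
compileTemplate keeps working for models that still declare a Handlebars chatPromptTemplate; only the legacy token parameters are gone. A usage sketch, assuming the returned render function spreads its inputs over the model defaults as before:

import { compileTemplate } from "$lib/utils/template";
import type { ChatTemplateInput } from "$lib/types/Template";

const render = compileTemplate<ChatTemplateInput>(
	"{{preprompt}}{{#each messages}}{{#ifUser}}User: {{content}}\n{{/ifUser}}" +
		"{{#ifAssistant}}Assistant: {{content}}\n{{/ifAssistant}}{{/each}}",
	{ preprompt: "" }
);

console.log(render({ preprompt: "Be concise.\n", messages: [{ from: "user", content: "Hi" }] }));
// "Be concise.\nUser: Hi\n"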