matt HOFFNER committed on
Commit
f373356
•
1 Parent(s): 1300e36
Files changed (3) hide show
  1. package.json +0 -2
  2. src/components/ChatWindow.jsx +16 -25
  3. src/embed/hf.ts +35 -55
package.json CHANGED
@@ -15,8 +15,6 @@
15
  "@types/react": "18.2.6",
16
  "@types/react-dom": "18.2.4",
17
  "@xenova/transformers": "^2.1.1",
18
- "chromadb": "^1.5.2",
19
- "cohere-ai": "^5.1.0",
20
  "dexie": "^3.2.4",
21
  "eslint": "8.40.0",
22
  "eslint-config-next": "13.4.2",
 
15
  "@types/react": "18.2.6",
16
  "@types/react-dom": "18.2.4",
17
  "@xenova/transformers": "^2.1.1",
 
 
18
  "dexie": "^3.2.4",
19
  "eslint": "8.40.0",
20
  "eslint-config-next": "13.4.2",
src/components/ChatWindow.jsx CHANGED
@@ -5,11 +5,8 @@ import MessageList from './MessageList';
5
  import {FileLoader} from './FileLoader';
6
  import Loader from "./Loader";
7
  import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
8
- import { TransformersEmbeddingFunction } from '../embed/hf';
9
- import { ChromaClient } from "chromadb";
10
-
11
- const client = new ChromaClient();
12
- const embedder = new TransformersEmbeddingFunction({});
13
 
14
  function ChatWindow({
15
  stopStrings,
@@ -34,37 +31,31 @@ function ChatWindow({
34
  console.log('found file text splitting into chunks')
35
  const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
36
  const docs = await textSplitter.createDocuments([fileText]);
37
- console.log(`split docs: ${docs}`);
38
- const collection = await client.createCollection({name: "docs", embeddingFunction: embedder })
39
- console.log(`collection: ${collection}`);
40
  let queryResult;
 
 
41
  try {
42
- await collection.add({
43
- ids: [...docs.map((v, k) => k)],
44
- metadatas: [...docs.map(doc => doc.metadata)],
45
- documents: [...docs.map(doc => doc.pageContent)],
46
- });
47
- const queryResult = await collection.query({
48
- nResults: 2,
49
- queryTexts: [userPrompt]
50
- });
51
- console.log(queryResult);
52
- } catch (err) {
53
- console.log(err);
54
- }
55
-
56
-
57
- const qaPrompt =
58
  `You are an AI assistant providing helpful advice. You are given the following extracted parts of a long document and a question. Provide a conversational answer based on the context provided.
59
  You should only provide hyperlinks that reference the context below. Do NOT make up hyperlinks.
60
  If you can't find the answer in the context below, just say "Hmm, I'm not sure." Don't try to make up an answer.
61
  If the question is not related to the context, politely respond that you are tuned to only answer questions that are related to the context.
62
  Question: ${userInput}
63
  =========
64
- ${queryResult}
65
  =========
66
  Answer:
67
  `
 
 
 
68
  send(qaPrompt, maxTokens, stopStrings);
69
  } else {
70
  send(userInput, maxTokens, stopStrings);
 
5
  import {FileLoader} from './FileLoader';
6
  import Loader from "./Loader";
7
  import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
8
+ import { XenovaTransformersEmbeddings } from '../embed/hf';
9
+ import { MemoryVectorStore } from "langchain/vectorstores/memory";
 
 
 
10
 
11
  function ChatWindow({
12
  stopStrings,
 
31
  console.log('found file text splitting into chunks')
32
  const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
33
  const docs = await textSplitter.createDocuments([fileText]);
 
 
 
34
  let queryResult;
35
+ let qaPrompt;
36
+ console.log(docs);
37
  try {
38
+ const vectorStore = await MemoryVectorStore.fromTexts(
39
+ [...docs.map(doc => doc.pageContent)],
40
+ [...docs.map((v, k) => k)],
41
+ new XenovaTransformersEmbeddings()
42
+ )
43
+ let queryResult = await vectorStore.similaritySearch(userInput, 1);
44
+ console.log("queryResult", queryResult);
45
+ qaPrompt =
 
 
 
 
 
 
 
 
46
  `You are an AI assistant providing helpful advice. You are given the following extracted parts of a long document and a question. Provide a conversational answer based on the context provided.
47
  You should only provide hyperlinks that reference the context below. Do NOT make up hyperlinks.
48
  If you can't find the answer in the context below, just say "Hmm, I'm not sure." Don't try to make up an answer.
49
  If the question is not related to the context, politely respond that you are tuned to only answer questions that are related to the context.
50
  Question: ${userInput}
51
  =========
52
+ ${queryResult[0].pageContent}
53
  =========
54
  Answer:
55
  `
56
+ } catch (err) {
57
+ console.log(err);
58
+ }
59
  send(qaPrompt, maxTokens, stopStrings);
60
  } else {
61
  send(userInput, maxTokens, stopStrings);
src/embed/hf.ts CHANGED
@@ -1,62 +1,42 @@
1
- import { IEmbeddingFunction } from "chromadb/src/embeddings/IEmbeddingFunction";
2
-
3
- // Dynamically import module
4
- let TransformersApi: Promise<any>;
5
-
6
- export class TransformersEmbeddingFunction implements IEmbeddingFunction {
7
- private pipelinePromise: Promise<any> | null;
8
-
9
- /**
10
- * TransformersEmbeddingFunction constructor.
11
- * @param options The configuration options.
12
- * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
13
- * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
14
- * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
15
- * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
16
- */
17
- constructor({
18
- model = "Xenova/all-MiniLM-L6-v2",
19
- revision = "main",
20
- quantized = false,
21
- progress_callback = null,
22
- }: {
23
- model?: string;
24
- revision?: string;
25
- quantized?: boolean;
26
- progress_callback?: Function | null;
27
- } = {}) {
28
- try {
29
- // Since Transformers.js is an ESM package, we use the dynamic `import` syntax instead of `require`.
30
- // Also, since we use `"module": "commonjs"` in tsconfig.json, we use the following workaround to ensure
31
- // the dynamic import is not transpiled to a `require` statement.
32
- // For more information, see https://github.com/microsoft/TypeScript/issues/43329#issuecomment-1008361973
33
- TransformersApi = Function('return import("@xenova/transformers")')();
34
- } catch (e) {
35
- throw new Error(
36
- "Please install the @xenova/transformers package to use the TransformersEmbeddingFunction, `npm install -S @xenova/transformers`."
37
- );
38
  }
39
 
40
- // Store a promise that resolves to the pipeline
41
- this.pipelinePromise = new Promise(async (resolve, reject) => {
42
- try {
43
- const { pipeline } = await TransformersApi;
44
- resolve(
45
- await pipeline("feature-extraction", model, {
46
- quantized,
47
- revision,
48
- progress_callback,
49
- })
50
- );
51
- } catch (e) {
52
- reject(e);
53
- }
54
  });
55
  }
56
 
57
- public async generate(texts: string[]): Promise<number[][]> {
58
- let pipe = await this.pipelinePromise;
59
- let output = await pipe(texts, { pooling: "mean", normalize: true });
60
- return output.tolist();
 
 
61
  }
62
  }
 
1
+ import { pipeline } from "@xenova/transformers";
2
+ import { Embeddings, EmbeddingsParams } from "langchain/embeddings/base";
3
+
4
/**
 * Constructor options for {@link XenovaTransformersEmbeddings}.
 */
export interface XenovaTransformersEmbeddingsParams extends EmbeddingsParams {
  /** Model identifier passed to the Transformers.js pipeline factory. */
  model?: string;
}

/**
 * LangChain `Embeddings` implementation that computes vectors locally through
 * a Transformers.js pipeline.
 *
 * The pipeline is created lazily on the first embedding request and reused
 * afterwards.
 */
export class XenovaTransformersEmbeddings
  extends Embeddings
  implements XenovaTransformersEmbeddingsParams
{
  /** Model identifier used when the pipeline is created. */
  model: string;

  /**
   * The loaded pipeline, or `undefined` until the first embedding request.
   * Kept loosely typed because the pipeline is invoked as a plain callable.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  client: any;

  /**
   * In-flight pipeline load. Sharing one promise ensures that concurrent
   * callers (e.g. `embedQuery` racing `embedDocuments`) trigger a single load.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private clientLoading: Promise<any> | undefined;

  /**
   * @param fields Optional settings; `model` defaults to
   *   `"Xenova/all-MiniLM-L6-v2"`.
   */
  constructor(fields?: XenovaTransformersEmbeddingsParams) {
    super(fields ?? {});
    this.model = fields?.model ?? "Xenova/all-MiniLM-L6-v2";
  }

  /**
   * Returns the pipeline, creating it on first use.
   *
   * @throws Whatever the pipeline factory rejects with; a failed load is
   *   forgotten so that the next call can retry.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private async loadClient(): Promise<any> {
    if (this.client) {
      return this.client;
    }
    if (this.clientLoading === undefined) {
      this.clientLoading = pipeline("feature-extraction", this.model);
    }
    try {
      this.client = await this.clientLoading;
    } catch (err: unknown) {
      // Drop the rejected promise so a later call is able to retry the load.
      this.clientLoading = undefined;
      throw err;
    }
    return this.client;
  }

  /**
   * Embeds every text with mean pooling and normalisation enabled.
   *
   * @param texts Texts to embed; an empty array yields an empty result
   *   without loading the model.
   * @returns One vector per input text, in input order.
   */
  async _embed(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) {
      return [];
    }

    const extractor = await this.loadClient();

    return this.caller.call(async () =>
      Promise.all(
        texts.map(async (text) => {
          const output = await extractor(text, {
            pooling: "mean",
            normalize: true,
          });
          // `output.data` is presumably a typed array (TODO confirm against the
          // Transformers.js Tensor docs); copy it into a plain array so the
          // declared `number[]` contract actually holds.
          return Array.from(output.data as ArrayLike<number>);
        })
      )
    );
  }

  /**
   * Embeds a single query string.
   *
   * @param document Text to embed.
   * @returns The embedding vector for `document`.
   * @throws Error if no vector was produced for the input.
   */
  async embedQuery(document: string): Promise<number[]> {
    const [embedding] = await this._embed([document]);
    if (embedding === undefined) {
      throw new Error("No embedding was produced for the query.");
    }
    return embedding;
  }

  /**
   * Embeds a batch of documents.
   *
   * @param documents Texts to embed.
   * @returns One vector per document, in input order.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return this._embed(documents);
  }
}