matt HOFFNER commited on
Commit
8c64b1d
Β·
1 Parent(s): c63af75

less langchain more chroma

Browse files
src/components/ChatWindow.jsx CHANGED
@@ -3,9 +3,9 @@ import Image from "next/image";
3
  import { useCallback, useEffect, useState } from "react";
4
  import MessageList from './MessageList';
5
  import {FileLoader} from './FileLoader';
6
- import { db } from '@/utils/db-client';
7
  import Loader from "./Loader";
8
- import { Chroma } from "langchain/vectorstores/chroma";
 
9
 
10
  function ChatWindow({
11
  stopStrings,
@@ -27,10 +27,14 @@ function ChatWindow({
27
  }
28
 
29
  if (fileId) {
30
- const similarityMatches = 2;
31
- const fileContents = await db.docs.get(fileId);
32
- const vectorStore = await Chroma.fromDocuments(fileContents, new XenovaTransformersEmbeddings());
33
- const result = await vectorStore.similaritySearch(userInput, similarityMatches);
 
 
 
 
34
  const qaPrompt =
35
  `You are an AI assistant providing helpful advice. You are given the following extracted parts of a long document and a question. Provide a conversational answer based on the context provided.
36
  You should only provide hyperlinks that reference the context below. Do NOT make up hyperlinks.
 
3
  import { useCallback, useEffect, useState } from "react";
4
  import MessageList from './MessageList';
5
  import {FileLoader} from './FileLoader';
 
6
  import Loader from "./Loader";
7
+ const {ChromaClient} = require('chromadb');
8
+ const client = new ChromaClient();
9
 
10
  function ChatWindow({
11
  stopStrings,
 
27
  }
28
 
29
  if (fileId) {
30
+ // const fileContents = await db.docs.get(fileId);
31
+ const collection = await client.getCollection("docs")
32
+ const result = await collection.query({
33
+ nResults: 2,
34
+ queryTexts: [userPrompt]
35
+ });
36
+ console.log(result);
37
+
38
  const qaPrompt =
39
  `You are an AI assistant providing helpful advice. You are given the following extracted parts of a long document and a question. Provide a conversational answer based on the context provided.
40
  You should only provide hyperlinks that reference the context below. Do NOT make up hyperlinks.
src/embed/hf.ts CHANGED
@@ -1,41 +1,62 @@
1
- import { pipeline } from "@xenova/transformers";
2
- import { Embeddings, EmbeddingsParams } from "langchain/embeddings/base";
3
-
4
- export interface XenovaTransformersEmbeddingsParams extends EmbeddingsParams {
5
- model?: string;
6
- }
7
-
8
- export default class XenovaTransformersEmbeddings
9
- extends Embeddings
10
- implements XenovaTransformersEmbeddingsParams
11
- {
12
- model: string;
13
-
14
- client: any;
15
-
16
- constructor(fields?: XenovaTransformersEmbeddingsParams) {
17
- super(fields ?? {});
18
- this.model = fields?.model ?? "Xenova/all-MiniLM-L6-v2";
19
- }
20
-
21
- async _embed(texts: string[]): Promise<number[][]> {
22
- if (!this.client) {
23
- this.client = await pipeline("embeddings", this.model);
24
- }
25
- console.log(this.client, texts);
26
-
27
- return this.caller.call(async () => {
28
- return await Promise.all(
29
- texts.map(async (t) => (await this.client(t, { pooling: 'mean', normalize: true })).data)
30
- );
31
- });
32
- }
33
-
34
- embedQuery(document: string): Promise<number[]> {
35
- return this._embed([document]).then((embeddings) => embeddings[0]);
36
- }
37
-
38
- embedDocuments(documents: string[]): Promise<number[][]> {
39
- return this._embed(documents);
40
- }
41
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { IEmbeddingFunction } from "chromadb/src/embeddings/IEmbeddingFunction";
2
+
3
+ // Dynamically import module
4
+ let TransformersApi: Promise<any>;
5
+
6
+ export class TransformersEmbeddingFunction implements IEmbeddingFunction {
7
+ private pipelinePromise: Promise<any> | null;
8
+
9
+ /**
10
+ * TransformersEmbeddingFunction constructor.
11
+ * @param options The configuration options.
12
+ * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
13
+ * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
14
+ * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
15
+ * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
16
+ */
17
+ constructor({
18
+ model = "Xenova/all-MiniLM-L6-v2",
19
+ revision = "main",
20
+ quantized = false,
21
+ progress_callback = null,
22
+ }: {
23
+ model?: string;
24
+ revision?: string;
25
+ quantized?: boolean;
26
+ progress_callback?: Function | null;
27
+ } = {}) {
28
+ try {
29
+ // Since Transformers.js is an ESM package, we use the dynamic `import` syntax instead of `require`.
30
+ // Also, since we use `"module": "commonjs"` in tsconfig.json, we use the following workaround to ensure
31
+ // the dynamic import is not transpiled to a `require` statement.
32
+ // For more information, see https://github.com/microsoft/TypeScript/issues/43329#issuecomment-1008361973
33
+ TransformersApi = Function('return import("@xenova/transformers")')();
34
+ } catch (e) {
35
+ throw new Error(
36
+ "Please install the @xenova/transformers package to use the TransformersEmbeddingFunction, `npm install -S @xenova/transformers`."
37
+ );
38
+ }
39
+
40
+ // Store a promise that resolves to the pipeline
41
+ this.pipelinePromise = new Promise(async (resolve, reject) => {
42
+ try {
43
+ const { pipeline } = await TransformersApi;
44
+ resolve(
45
+ await pipeline("feature-extraction", model, {
46
+ quantized,
47
+ revision,
48
+ progress_callback,
49
+ })
50
+ );
51
+ } catch (e) {
52
+ reject(e);
53
+ }
54
+ });
55
+ }
56
+
57
+ public async generate(texts: string[]): Promise<number[][]> {
58
+ let pipe = await this.pipelinePromise;
59
+ let output = await pipe(texts, { pooling: "mean", normalize: true });
60
+ return output.tolist();
61
+ }
62
+ }
src/pages/api/docHandle.ts CHANGED
@@ -1,15 +1,20 @@
1
  import type { NextApiRequest, NextApiResponse } from 'next';
2
  import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
3
- import { Chroma } from "langchain/vectorstores/chroma";
4
- import XenovaTransformersEmbeddings from '../../embed/hf'
 
5
 
6
  async function handleDocs(text: string) {
7
  const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
8
  const docs = await textSplitter.createDocuments([text]);
9
- const vectorStore = await Chroma.fromDocuments(docs, new XenovaTransformersEmbeddings(), {
10
- collectionName: 'docs'
 
 
 
11
  });
12
- return vectorStore;
 
13
  }
14
 
15
  export default async function handler(
 
1
  import type { NextApiRequest, NextApiResponse } from 'next';
2
  import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';
3
+ import { TransformersEmbeddingFunction } from '../../embed/hf';
4
+ const {ChromaClient} = require('chromadb');
5
+ const client = new ChromaClient();
6
 
7
  async function handleDocs(text: string) {
8
  const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
9
  const docs = await textSplitter.createDocuments([text]);
10
+ const collection = await client.createCollection({name: "docs", embeddingFunction: TransformersEmbeddingFunction})
11
+ await collection.add({
12
+ ids: [...docs.map((v, k) => k)],
13
+ metadatas: [...docs.map(doc => doc.metadata)],
14
+ documents: [...docs.map(doc => doc.pageContent)],
15
  });
16
+
17
+ return collection;
18
  }
19
 
20
  export default async function handler(