File size: 1,114 Bytes
f373356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c64b1d
 
f373356
 
 
 
 
 
8c64b1d
 
 
f373356
 
 
 
 
 
8c64b1d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import { pipeline } from "@xenova/transformers";
import { Embeddings, EmbeddingsParams } from "langchain/embeddings/base";

/**
 * Constructor options for the `XenovaTransformersEmbeddings` class.
 *
 * Extends LangChain's `EmbeddingsParams`, so every field of that base type is
 * accepted here as well and forwarded to the `Embeddings` base constructor.
 */
export interface XenovaTransformersEmbeddingsParams extends EmbeddingsParams {
  /**
   * Model identifier handed to `@xenova/transformers`' `pipeline()`.
   * When omitted, the embeddings class falls back to "Xenova/all-MiniLM-L6-v2".
   */
  model?: string;
}

/**
 * LangChain `Embeddings` implementation backed by a locally executed
 * `@xenova/transformers` feature-extraction pipeline.
 *
 * The pipeline is created lazily on the first non-empty embedding request and
 * reused afterwards. Every text is embedded with
 * `{ pooling: "mean", normalize: true }`.
 */
export class XenovaTransformersEmbeddings
  extends Embeddings
  implements XenovaTransformersEmbeddingsParams
{
  /** Model identifier passed to `pipeline()`; defaults to "Xenova/all-MiniLM-L6-v2". */
  model: string;

  /**
   * The loaded pipeline, or undefined until the first load completes.
   * Kept public (and untyped) for backward compatibility: a value assigned
   * here directly is used as-is and no model is loaded.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- pipeline result is untyped in this file
  client: any;

  /**
   * In-flight pipeline load, shared so that concurrent first calls trigger a
   * single model load instead of one load per caller.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- mirrors the untyped `client`
  private clientPromise?: Promise<any>;

  /**
   * @param fields - Optional settings; `fields` (or `{}`) is forwarded to the
   *   `Embeddings` base constructor, and `fields.model` selects the model.
   */
  constructor(fields?: XenovaTransformersEmbeddingsParams) {
    super(fields ?? {});
    this.model = fields?.model ?? "Xenova/all-MiniLM-L6-v2";
  }

  /**
   * Returns the pipeline, loading it at most once even when called
   * concurrently. A failed load is forgotten so a later call can retry it.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- mirrors the untyped `client`
  private getClient(): Promise<any> {
    if (this.client) {
      return Promise.resolve(this.client);
    }
    if (!this.clientPromise) {
      // "feature-extraction" is the canonical task name; "embeddings" is only
      // a backwards-compatibility alias for it in transformers.js.
      this.clientPromise = pipeline("feature-extraction", this.model).then(
        (pipe) => {
          this.client = pipe;
          return pipe;
        },
        (error: unknown) => {
          // Drop the rejected promise so the next request attempts a fresh load.
          this.clientPromise = undefined;
          throw error;
        }
      );
    }
    return this.clientPromise;
  }

  /**
   * Embeds each text independently (in parallel) and returns one vector per
   * input, in input order. The batch runs inside `this.caller.call`, so the
   * base class's caller settings apply to it.
   *
   * @param texts - Texts to embed. An empty array resolves to `[]` without
   *   loading the model.
   * @returns One plain `number[]` embedding per text.
   */
  async _embed(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) {
      return [];
    }

    const client = await this.getClient();

    return this.caller.call(() =>
      Promise.all(
        texts.map(async (text) => {
          const output = await client(text, { pooling: "mean", normalize: true });
          // transformers.js tensors expose `.data` as a typed array, not a plain
          // array; copy it so the result really is number[] (and serializes as
          // a JSON array rather than an index-keyed object).
          const data: ArrayLike<number> = output.data;
          return Array.from(data);
        })
      )
    );
  }

  /**
   * Embeds a single query string.
   *
   * @param document - The query text.
   * @returns The embedding vector for `document`.
   */
  embedQuery(document: string): Promise<number[]> {
    return this._embed([document]).then((embeddings) => embeddings[0]);
  }

  /**
   * Embeds a list of documents.
   *
   * @param documents - Texts to embed.
   * @returns One embedding vector per document, in input order.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return this._embed(documents);
  }
}