matt HOFFNER committed on
Commit
c93058d
•
1 Parent(s): 28a4c08

port over until merged

Browse files
Files changed (1) hide show
  1. src/embed/hf.ts +40 -0
src/embed/hf.ts ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { pipeline } from "@xenova/transformers";
2
+ import { Embeddings, EmbeddingsParams } from "langchain/embeddings/base";
3
+
4
/**
 * Constructor options for {@link XenovaTransformersEmbeddings}.
 *
 * Extends the base LangChain `EmbeddingsParams` with the model selection.
 */
export interface XenovaTransformersEmbeddingsParams extends EmbeddingsParams {
  /**
   * Model identifier handed to the Transformers.js `pipeline` factory.
   * When omitted, the embeddings class falls back to its built-in default.
   */
  model?: string;
}
7
+
8
/**
 * LangChain `Embeddings` implementation backed by a locally executed
 * Transformers.js pipeline.
 *
 * The pipeline is created lazily on the first embedding request and then
 * reused for every later request made through the same instance.
 */
export class XenovaTransformersEmbeddings
  extends Embeddings
  implements XenovaTransformersEmbeddingsParams
{
  /** Model identifier passed to `pipeline` when the client is created. */
  model: string;

  /**
   * The loaded pipeline. Stays unset until the first embedding request.
   * Left as `any` because the pipeline's call signature is not declared here.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  client: any;

  /**
   * In-flight (or completed) pipeline load, shared by concurrent callers so
   * the model is only loaded once per instance.
   */
  private clientLoad?: Promise<unknown>;

  /**
   * @param fields - Optional settings; `model` defaults to
   *   `"Xenova/all-MiniLM-L6-v2"` when not provided.
   */
  constructor(fields?: XenovaTransformersEmbeddingsParams) {
    super(fields ?? {});
    this.model = fields?.model ?? "Xenova/all-MiniLM-L6-v2";
  }

  /**
   * Returns the pipeline, creating it on first use.
   *
   * Concurrent callers await the same load promise instead of each starting
   * their own. If the load fails, the cached promise is dropped so that a
   * later call can try again.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private async loadClient(): Promise<any> {
    if (this.client) {
      return this.client;
    }

    if (this.clientLoad === undefined) {
      // NOTE(review): the "embeddings" task name is kept as-is; confirm it is
      // still accepted by the installed @xenova/transformers version.
      this.clientLoad = pipeline("embeddings", this.model);
    }

    try {
      this.client = await this.clientLoad;
    } catch (error: unknown) {
      this.clientLoad = undefined;
      throw error;
    }

    return this.client;
  }

  /**
   * Embeds every text by invoking the pipeline once per input, wrapped in the
   * inherited `caller`.
   *
   * @param texts - Inputs to embed; an empty array yields an empty result.
   * @returns One vector per input, in input order.
   */
  async _embed(texts: string[]): Promise<number[][]> {
    const client = await this.loadClient();

    return this.caller.call(async () =>
      Promise.all(
        texts.map(async (text) => {
          const output = await client(text);
          // Copy into a plain array so the declared `number[]` element type
          // holds even when `data` is a typed array.
          return Array.from<number>(output.data);
        })
      )
    );
  }

  /**
   * Embeds a single text.
   *
   * @param document - The text to embed.
   * @returns The vector produced for `document`.
   */
  async embedQuery(document: string): Promise<number[]> {
    const [embedding] = await this._embed([document]);
    return embedding;
  }

  /**
   * Embeds a batch of texts.
   *
   * @param documents - The texts to embed.
   * @returns One vector per document, in input order.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return this._embed(documents);
  }
}