"use server"

import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
import { LLMEngine } from "@/types"

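/**
 * Server action: streams a completion for `inputs` from the configured
 * Hugging Face backend (shared Inference API or a dedicated Inference
 * Endpoint) and returns the accumulated text, with chat-template control
 * tokens stripped out.
 */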
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)

  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
  const inferenceEndpoint = `${process.env.LLM_HF_INFERENCE_ENDPOINT_URL || ""}`
  const inferenceModel = `${process.env.LLM_HF_INFERENCE_API_MODEL || ""}`

  let hfie: HfInferenceEndpoint = hf

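  // pick the backend according to LLM_ENGINE: either a dedicated Inference
  // Endpoint URL or a model served through the shared Inference API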
  switch (llmEngine) {
    case "INFERENCE_ENDPOINT":
      if (inferenceEndpoint) {
       //  console.log("Using a custom HF Inference Endpoint")
        hfie = hf.endpoint(inferenceEndpoint)
      } else {
        const error = "No Inference Endpoint URL defined"
        console.error(error)
        throw new Error(error)
      }
      break;
    
    case "INFERENCE_API":
      if (inferenceModel) {
        // console.log("Using an HF Inference API Model")
      } else {
        const error = "No Inference API model defined"
        console.error(error)
        throw new Error(error)
      }
      break;

    default: {
      const error = "Please check your Hugging Face Inference API or Inference Endpoint settings"
      console.error(error)
      throw new Error(error)
    }
  }
    
  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  let instructions = ""
  try {
    for await (const output of api.textGenerationStream({
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs,
      parameters: {
        do_sample: true,
        max_new_tokens: nbMaxNewTokens,
        return_full_text: false,
      }
    })) {
      instructions += output.token.text
      // process.stdout.write(output.token.text)
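      // stop streaming as soon as a chat-template control token appears;
      // anything emitted after it is not part of the answer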
      if (
        instructions.includes("</s>") || 
        instructions.includes("<s>") ||
        instructions.includes("/s>") ||
        instructions.includes("[INST]") ||
        instructions.includes("[/INST]") ||
        instructions.includes("<SYS>") ||
        instructions.includes("<<SYS>>") ||
        instructions.includes("</SYS>") ||
        instructions.includes("<</SYS>>") ||
        instructions.includes("<|user|>") ||
        instructions.includes("<|end|>") ||
        instructions.includes("<|system|>") ||
        instructions.includes("<|assistant|>")
      ) {
        break
      }
    }
  } catch (err) {
    // console.error(`error during generation: ${err}`)

    // a common issue with Llama-2 is that the model can be overloaded with requests
    if (`${err}`.includes("Model is overloaded")) {
      instructions = ``
    }
  }

  // clean up the template garbage the LLM might have given us
  return (
    instructions
    .replaceAll("<|end|>", "")
    .replaceAll("<s>", "")
    .replaceAll("</s>", "")
    .replaceAll("/s>", "")
    .replaceAll("[INST]", "")
    .replaceAll("[/INST]", "") 
    .replaceAll("<SYS>", "")
    .replaceAll("<<SYS>>", "")
    .replaceAll("</SYS>", "")
    .replaceAll("<</SYS>>", "")
    .replaceAll("<|system|>", "")
    .replaceAll("<|user|>", "")
    .replaceAll("<|all|>", "")
    .replaceAll("<|assistant|>", "")
    .replaceAll('""', '"')
  )
}
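
// Usage sketch (hypothetical caller, not part of this module), assuming
// AUTH_HF_API_TOKEN is set along with either LLM_ENGINE="INFERENCE_API" and
// LLM_HF_INFERENCE_API_MODEL, or LLM_ENGINE="INFERENCE_ENDPOINT" and
// LLM_HF_INFERENCE_ENDPOINT_URL:
//
//   import { predict } from "./predict"
//
//   const caption = await predict("<s>[INST] Write a short comic panel caption. [/INST]", 200)
//   console.log(caption)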