File size: 3,571 Bytes
879455c ba986c0 879455c ba986c0 6463491 879455c 6463491 879455c 6463491 879455c 6463491 1e641f1 6463491 1e641f1 6463491 879455c 6463491 879455c ba986c0 879455c 5dd2af5 879455c 9bfb451 879455c 9052a89 879455c 9052a89 879455c 9052a89 879455c 9052a89 879455c 9bfb451 81eb27e 879455c 9052a89 879455c 9052a89 879455c 9052a89 879455c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"use server"
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
import { LLMEngine } from "@/types"
import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
/**
 * Markers that end the streaming loop as soon as any of them shows up in the
 * accumulated output.
 */
const STOP_MARKERS = [
  "</s>",
  "<s>",
  "/s>",
  "[INST]",
  "[/INST]",
  "<SYS>",
  "<<SYS>>",
  "</SYS>",
  "<</SYS>>",
  "<|user|>",
  "<|end|>",
  "<|system|>",
  "<|assistant|>",
] as const

/**
 * Markers stripped from the final text. They are removed one after another in
 * this exact order, so the order is part of the behavior and must be kept.
 */
const CLEANUP_MARKERS = [
  "<|end|>",
  "<s>",
  "</s>",
  "/s>",
  "[INST]",
  "[/INST]",
  "<SYS>",
  "<<SYS>>",
  "</SYS>",
  "<</SYS>>",
  "<|system|>",
  "<|user|>",
  "<|all|>",
  "<|assistant|>",
] as const

/** Logs a configuration problem and returns the matching Error, ready to be thrown. */
function configurationError(message: string): Error {
  console.error(message)
  return new Error(message)
}

/**
 * Streams a text generation from Hugging Face and returns the cleaned-up output.
 *
 * The backend is selected by the `LLM_ENGINE` environment variable:
 * - `"INFERENCE_ENDPOINT"` uses the URL in `LLM_HF_INFERENCE_ENDPOINT_URL`
 * - `"INFERENCE_API"` uses the model named in `LLM_HF_INFERENCE_API_MODEL`
 *
 * The client is authenticated with `AUTH_HF_API_TOKEN`.
 *
 * The prompt is built by `createZephyrPrompt` from the system and user messages,
 * and `"\n[{"` is appended to it to steer the model's first tokens. That prefix is
 * sent as part of the input; this function does not add it to the returned text
 * (the request sets `return_full_text: false`).
 *
 * Generation is best-effort: if the stream fails, the error is logged and whatever
 * was accumulated so far is returned (an empty string when the model reports that
 * it is overloaded).
 *
 * @param systemPrompt - content of the "system" message
 * @param userPrompt - content of the "user" message
 * @param nbMaxNewTokens - forwarded as `max_new_tokens`
 * @returns the concatenated streamed tokens, with the cleanup markers removed and
 *          doubled double-quotes (`""`) collapsed into one
 * @throws Error when `LLM_ENGINE` is not a supported value, or when the endpoint
 *         URL / model required by the selected engine is not set
 */
export async function predict({
  systemPrompt,
  userPrompt,
  nbMaxNewTokens,
}: {
  systemPrompt: string
  userPrompt: string
  nbMaxNewTokens: number
}): Promise<string> {
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)

  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
  const inferenceEndpoint = `${process.env.LLM_HF_INFERENCE_ENDPOINT_URL || ""}`
  const inferenceModel = `${process.env.LLM_HF_INFERENCE_API_MODEL || ""}`

  let hfie: HfInferenceEndpoint = hf

  switch (llmEngine) {
    case "INFERENCE_ENDPOINT":
      if (!inferenceEndpoint) {
        throw configurationError("No Inference Endpoint URL defined")
      }
      hfie = hf.endpoint(inferenceEndpoint)
      break

    case "INFERENCE_API":
      if (!inferenceModel) {
        throw configurationError("No Inference API model defined")
      }
      break

    default:
      throw configurationError(
        "Please check your Hugging Face Inference API or Inference Endpoint settings"
      )
  }

  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  let instructions = ""
  try {
    for await (const output of api.textGenerationStream({
      // a dedicated endpoint is called without a model name
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs: createZephyrPrompt([
        { role: "system", content: systemPrompt },
        { role: "user", content: userPrompt }
      ]) + "\n[{", // <-- important: we force its hand
      parameters: {
        do_sample: true,
        max_new_tokens: nbMaxNewTokens,
        return_full_text: false,
      }
    })) {
      instructions += output.token.text

      // stop reading the stream once the model starts emitting chat/template markers
      if (STOP_MARKERS.some((marker) => instructions.includes(marker))) {
        break
      }
    }
  } catch (err) {
    // Generation stays best-effort (we still return what we have), but the failure
    // is logged so that auth, network or model errors are not lost silently.
    console.error(`error during generation: ${err}`)

    // a common issue with Llama-2 might be that the model receives too many requests
    if (`${err}` === "Error: Model is overloaded") {
      instructions = ``
    }
  }

  // clean up the template markers the LLM might have given us
  const withoutMarkers = CLEANUP_MARKERS.reduce(
    (text: string, marker) => text.replaceAll(marker, ""),
    instructions
  )

  return withoutMarkers.replaceAll('""', '"')
}
|