"use server"

import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"

import { LLMEngine } from "@/types"
import { createZephyrPrompt } from "@/lib/createZephyrPrompt"

/**
 * Chat-template control tokens (Llama-2 style `[INST]` / `<<SYS>>` / `</s>`,
 * plus `<|user|>`-style role markers). As soon as any of them appears in the
 * streamed output, generation is stopped.
 *
 * Longer tokens are listed BEFORE the shorter tokens they contain
 * (e.g. "<<SYS>>" before "<SYS>", "</s>" before "/s>") so that the cleanup
 * pass removes them whole instead of leaving fragments such as "<>" behind.
 */
const STOP_TOKENS: readonly string[] = [
  "<</SYS>>",
  "<<SYS>>",
  "</SYS>",
  "<SYS>",
  "[/INST]",
  "[INST]",
  "</s>",
  "<s>",
  "/s>",
  "<|system|>",
  "<|user|>",
  "<|assistant|>",
  "<|end|>",
]

/**
 * Tokens stripped from the final text: every stop token, plus "<|all|>",
 * which is removed from the output but does not stop generation.
 * NOTE(review): "<|all|>" is unusual for a chat template — confirm it is intended.
 */
const CLEANUP_TOKENS: readonly string[] = [...STOP_TOKENS, "<|all|>"]

/** Returns true when `text` contains any of the stop tokens. */
function containsStopToken(text: string): boolean {
  return STOP_TOKENS.some((token) => text.includes(token))
}

/**
 * Removes control tokens the LLM may have emitted, then collapses doubled
 * quotes (`""` -> `"`).
 */
function cleanupGeneratedText(text: string): string {
  let cleaned = text
  for (const token of CLEANUP_TOKENS) {
    cleaned = cleaned.replaceAll(token, "")
  }
  return cleaned.replaceAll('""', '"')
}

/** Logs `message` and throws it as an `Error`. */
function fail(message: string): never {
  console.error(message)
  throw new Error(message)
}

/**
 * Generates text with a Hugging Face model, through either a dedicated
 * Inference Endpoint or the hosted Inference API depending on `LLM_ENGINE`.
 *
 * Environment variables read:
 * - `AUTH_HF_API_TOKEN`: token passed to `HfInference`
 * - `LLM_ENGINE`: must be "INFERENCE_ENDPOINT" or "INFERENCE_API"
 * - `LLM_HF_INFERENCE_ENDPOINT_URL`: required for "INFERENCE_ENDPOINT"
 * - `LLM_HF_INFERENCE_API_MODEL`: required for "INFERENCE_API"
 *
 * The prompt is suffixed with `\n[{` to steer the model into continuing a
 * JSON array of objects. Because `return_full_text` is false, that prefix is
 * not part of the returned text.
 * NOTE(review): callers presumably re-add `[{` before parsing — confirm.
 *
 * @param systemPrompt - system message of the chat prompt
 * @param userPrompt - user message of the chat prompt
 * @param nbMaxNewTokens - forwarded as `max_new_tokens`
 * @returns the generated text with control tokens removed. It may be partial if
 *   the stream failed midway, and is empty if the model reported being overloaded.
 * @throws Error when `LLM_ENGINE` is unsupported or its required setting is missing
 */
export async function predict({
  systemPrompt,
  userPrompt,
  nbMaxNewTokens,
}: {
  systemPrompt: string
  userPrompt: string
  nbMaxNewTokens: number
}): Promise<string> {
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)

  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
  const inferenceEndpoint = `${process.env.LLM_HF_INFERENCE_ENDPOINT_URL || ""}`
  const inferenceModel = `${process.env.LLM_HF_INFERENCE_API_MODEL || ""}`

  let hfie: HfInferenceEndpoint = hf

  switch (llmEngine) {
    case "INFERENCE_ENDPOINT":
      if (!inferenceEndpoint) {
        fail("No Inference Endpoint URL defined")
      }
      hfie = hf.endpoint(inferenceEndpoint)
      break

    case "INFERENCE_API":
      if (!inferenceModel) {
        fail("No Inference API model defined")
      }
      break

    default:
      fail("Please check your Hugging Face Inference API or Inference Endpoint settings")
  }

  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  let instructions = ""
  try {
    for await (const output of api.textGenerationStream({
      // only the Inference API is given a model id; the endpoint client gets none
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs: createZephyrPrompt([
        { role: "system", content: systemPrompt },
        { role: "user", content: userPrompt },
      ]) + "\n[{", // <-- important: we force its hand
      parameters: {
        do_sample: true,
        max_new_tokens: nbMaxNewTokens,
        return_full_text: false,
      },
    })) {
      instructions += output.token.text
      // stop as soon as the model starts emitting chat-template control tokens
      if (containsStopToken(instructions)) {
        break
      }
    }
  } catch (err) {
    if (`${err}` === "Error: Model is overloaded") {
      // a common issue with Llama-2 might be that the model receives too many
      // requests; discard any partial output in that case
      instructions = ``
    } else {
      // best-effort: keep what was streamed before the failure, but don't hide it
      console.error(`error during generation: ${err}`)
    }
  }

  // need to do some cleanup of the garbage the LLM might have given us
  return cleanupGeneratedText(instructions)
}