Spaces:
Running
Running
File size: 2,591 Bytes
0c4cf03 2e6d1bb 0c4cf03 0134fe1 0c4cf03 2e6d1bb ebac87f 0c4cf03 2e6d1bb ebac87f 2e6d1bb 0134fe1 2e6d1bb 0134fe1 2e6d1bb 1e19fc8 0c4cf03 2e6d1bb d2a650e 0c4cf03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import { defaultModel } from "$lib/server/models";
import { modelEndpoint } from "./modelEndpoint";
import { trimSuffix } from "$lib/utils/trimSuffix";
import { trimPrefix } from "$lib/utils/trimPrefix";
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
import { AwsClient } from "aws4fetch";
/**
 * Generation settings accepted by `generateFromDefaultEndpoint`. Callers pass
 * a partial set; it is merged over `defaultModel.parameters` and sent as the
 * request's `parameters` field.
 *
 * NOTE(review): this module-local name shadows TypeScript's built-in
 * `Parameters<F>` utility type inside this file.
 */
interface Parameters {
	/** Stop sequences; a trailing match is also stripped from the returned text. */
	stop: string[];
	/** Presumably the maximum number of generated tokens — confirm with the backend. */
	max_new_tokens: number;
	/** Presumably the input truncation length in tokens — confirm with the backend. */
	truncate: number;
	/** Presumably the sampling temperature — confirm with the backend. */
	temperature: number;
}
/**
 * Sends `prompt` to the endpoint chosen by `modelEndpoint(defaultModel)` and
 * returns the generated continuation as plain text.
 *
 * Request parameters are `defaultModel.parameters`, overridden by
 * `parameters`, with `return_full_text` forced to `false`. When the endpoint's
 * host is `"sagemaker"` the request goes through aws4fetch's `AwsClient` using
 * the endpoint's credentials; otherwise it is a plain `fetch` carrying the
 * endpoint's `authorization` header.
 *
 * The response body is either JSON or an event stream beginning with
 * `data:`; in the latter case only the last `data:` payload is parsed.
 *
 * Post-processing removes a leading `<|startoftext|>` and the echoed prompt,
 * a trailing `PUBLIC_SEP_TOKEN`, trailing whitespace, and then (one pass, in
 * order) any configured stop sequence followed by `<|endoftext|>` when the
 * text ends with it.
 *
 * @param prompt - Text sent as the request's `inputs`.
 * @param parameters - Overrides merged over `defaultModel.parameters`.
 * @returns The cleaned generated text.
 * @throws Error carrying the response body when the HTTP status is not OK.
 * @throws Error("Body is empty") when the response has no body.
 * @throws Error when the payload has no string `generated_text`.
 */
export async function generateFromDefaultEndpoint(
	prompt: string,
	parameters?: Partial<Parameters>
): Promise<string> {
	const newParameters = {
		...defaultModel.parameters,
		...parameters,
		return_full_text: false,
	};

	const randomEndpoint = modelEndpoint(defaultModel);
	// NOTE(review): nothing ever calls abort(); wire this to a timeout or a
	// caller-supplied signal if request cancellation is needed.
	const abortController = new AbortController();

	// Both backends receive the same JSON payload, so build it once.
	const requestBody = JSON.stringify({
		parameters: newParameters,
		inputs: prompt,
	});

	let resp: Response;
	if (randomEndpoint.host === "sagemaker") {
		const aws = new AwsClient({
			accessKeyId: randomEndpoint.accessKey,
			secretAccessKey: randomEndpoint.secretKey,
			sessionToken: randomEndpoint.sessionToken,
			service: "sagemaker",
		});
		resp = await aws.fetch(randomEndpoint.url, {
			method: "POST",
			body: requestBody,
			signal: abortController.signal,
			headers: {
				"Content-Type": "application/json",
			},
		});
	} else {
		resp = await fetch(randomEndpoint.url, {
			headers: {
				"Content-Type": "application/json",
				Authorization: randomEndpoint.authorization,
			},
			method: "POST",
			body: requestBody,
			signal: abortController.signal,
		});
	}

	if (!resp.ok) {
		throw new Error(await resp.text());
	}
	if (!resp.body) {
		throw new Error("Body is empty");
	}

	// resp.text() drains the whole stream, releases the reader even on error,
	// and flushes any trailing partial UTF-8 sequence. The previous manual
	// TextDecoder loop never made the final flushing decode() call.
	const rawText = await resp.text();

	let generated_text = trimSuffix(
		trimPrefix(trimPrefix(extractGeneratedText(rawText), "<|startoftext|>"), prompt),
		PUBLIC_SEP_TOKEN
	).trimEnd();

	for (const stop of [...(newParameters.stop ?? []), "<|endoftext|>"]) {
		// Skip empty stops: "".length is 0 and slice(0, -0) would return "",
		// erasing the whole generation.
		if (stop.length > 0 && generated_text.endsWith(stop)) {
			generated_text = generated_text.slice(0, -stop.length).trimEnd();
		}
	}
	return generated_text;
}

/**
 * Parses a raw response body and returns the first `generated_text`.
 * Bodies starting with `data:` are treated as an event stream, and only the
 * last `data:` payload is parsed. The parsed value may be an array (its first
 * element is used) or a single object.
 *
 * @throws Error when the parsed payload has no string `generated_text`.
 */
function extractGeneratedText(rawText: string): string {
	const payload: unknown = rawText.startsWith("data:")
		? JSON.parse(rawText.split("data:").pop() ?? "")
		: JSON.parse(rawText);
	const first: unknown = Array.isArray(payload) ? payload[0] : payload;
	const generatedText = (first as { generated_text?: unknown } | null | undefined)
		?.generated_text;
	if (typeof generatedText !== "string") {
		throw new Error("Unexpected response format: missing generated_text");
	}
	return generatedText;
}
|