Spaces:
Sleeping
Sleeping
File size: 7,494 Bytes
afa4e5a e201dc3 afa4e5a ab9170f afa4e5a e201dc3 afa4e5a e201dc3 afa4e5a e201dc3 afa4e5a 0f711ba ac33c34 0f711ba afa4e5a e201dc3 afa4e5a a6b2d88 afa4e5a ab9170f afa4e5a ab9170f afa4e5a e201dc3 afa4e5a e201dc3 afa4e5a e201dc3 ab9170f afa4e5a e201dc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
import type { SerializedRenderResult } from "quicktype-core";
import { quicktype, InputData, JSONSchemaInput, FetchingJSONSchemaStore } from "quicktype-core";
import * as fs from "node:fs/promises";
import { existsSync as pathExists } from "node:fs";
import * as path from "node:path/posix";
import { fileURLToPath } from "node:url";
import ts from "typescript";
// Header text prepended to each generated `<task>/inference.ts` file.
// NOTE(review): this points at "src/scripts/inference-codegen", while the Python header below
// links to ".../packages/tasks/scripts/inference-codegen.ts" — confirm which path is correct.
const TYPESCRIPT_HEADER_FILE = `
/**
* Inference code generated from the JSON schema spec in ./spec
*
* Using src/scripts/inference-codegen
*/
`;
// Header text prepended to each generated Python module.
const PYTHON_HEADER_FILE = `
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
`;
// Output directory for the generated Python modules. This is a relative path, so it is resolved
// against the process's current working directory, not against the package root used for the specs.
const PYTHON_DIR = "./.python_generated";
/**
 * Locate the package root: the closest directory containing a `package.json`, searching
 * upwards from `startDir` (by default, the directory this script lives in).
 *
 * `import.meta.url` is a `file://` URL, not a filesystem path, so it is converted with
 * `fileURLToPath` first. Normalizing the raw URL string instead yields `file:/…`, which
 * collapses into relative paths (`.`, `..`, `../..`, …) that never equal `/`: the lookup
 * then depends on the current working directory and never terminates if no ancestor of
 * the cwd has a package.json.
 *
 * @param startDir Directory to start searching from. Defaults to this script's directory.
 * @returns The first directory containing `package.json`, or `"/"` if none was found.
 */
const rootDirFinder = function (startDir: string = path.dirname(fileURLToPath(import.meta.url))): string {
	let currentPath = path.resolve(startDir);
	while (currentPath !== "/") {
		if (pathExists(path.join(currentPath, "package.json"))) {
			return currentPath;
		}
		const parentPath = path.dirname(currentPath);
		// dirname() is a fixed point at the filesystem root; stop there so the walk always terminates,
		// even for roots that are not spelled "/".
		if (parentPath === currentPath) {
			break;
		}
		currentPath = parentPath;
	}
	return "/";
};
/**
 * Assemble the quicktype input for a single task from its JSON schema spec files.
 *
 * Each spec file is registered as its own named source: `<taskId>-input`, `<taskId>-output`,
 * and, for `text-generation` / `chat-completion` only, `<taskId>-stream-output`.
 *
 * @param taskId The ID of the task for which we are generating code
 * @param taskSpecDir The path to the directory where the input.json & output.json files are
 * @param allSpecFiles An array of paths to all the tasks specs. Allows resolving cross-file references ($ref).
 */
async function buildInputData(taskId: string, taskSpecDir: string, allSpecFiles: string[]): Promise<InputData> {
	const schemaInput = new JSONSchemaInput(new FetchingJSONSchemaStore(), [], allSpecFiles);

	// [source-name suffix, spec file name], registered in this order.
	const specSources: Array<[suffix: string, fileName: string]> = [
		["input", "input.json"],
		["output", "output.json"],
	];
	// Only these tasks get a streaming-output schema read from stream_output.json.
	const streamingTaskIds = ["text-generation", "chat-completion"];
	if (streamingTaskIds.includes(taskId)) {
		specSources.push(["stream-output", "stream_output.json"]);
	}

	for (const [suffix, fileName] of specSources) {
		const schemaText = await fs.readFile(`${taskSpecDir}/${fileName}`, { encoding: "utf-8" });
		await schemaInput.addSource({ name: `${taskId}-${suffix}`, schema: schemaText });
	}

	const result = new InputData();
	result.addInput(schemaInput);
	return result;
}
/**
 * Render TypeScript type declarations for `inputData` with quicktype.
 *
 * Emits plain types only (`just-types`, no runtime type-checking code), tab-indented with
 * alphabetized properties, and prefers unions, const values and `unknown` in the output.
 */
async function generateTypescript(inputData: InputData): Promise<SerializedRenderResult> {
	const rendererOptions = {
		"just-types": true,
		"nice-property-names": false,
		"prefer-unions": true,
		"prefer-const-values": true,
		"prefer-unknown": true,
		"explicit-unions": true,
		"runtime-typecheck": false,
	} as const;
	return quicktype({
		inputData,
		lang: "typescript",
		alphabetizeProperties: true,
		indentation: "\t",
		rendererOptions,
	});
}
/**
 * Render Python type definitions for `inputData` with quicktype.
 *
 * Emits plain types only (`just-types`) with alphabetized, Pythonic property names,
 * targeting Python 3.7.
 */
async function generatePython(inputData: InputData): Promise<SerializedRenderResult> {
	const rendererOptions = {
		"just-types": true,
		"nice-property-names": true,
		"python-version": "3.7",
	} as const;
	return quicktype({
		inputData,
		lang: "python",
		alphabetizeProperties: true,
		rendererOptions,
	});
}
/**
 * quicktype is unable to generate "top-level array types" that are defined in the output spec: https://github.com/glideapps/quicktype/issues/2481
 * We have to use the TypeScript API to generate those types when required.
 * This hacky function:
 * - looks for the generated interface for output types
 * - renames it with a `Element` suffix
 * - generates type alias in the form `export type <OutputType> = <OutputType>Element[];`
 *
 * And writes that to the `inference.ts` file
 *
 * @param path2generated Path of the generated TypeScript file; it is rewritten in place when an alias is inserted.
 * @param outputSpec Parsed output JSON schema of the task; only its `type` and `title` fields are read.
 */
async function postProcessOutput(path2generated: string, outputSpec: Record<string, unknown>): Promise<void> {
	const exportedName = outputSpec.title;
	// Only array-typed outputs with a string title need rewriting: decide that before reading/parsing the file.
	if (outputSpec.type !== "array" || typeof exportedName !== "string") {
		console.log(" Nothing to do");
		return;
	}

	const source = ts.createSourceFile(
		path.basename(path2generated),
		await fs.readFile(path2generated, { encoding: "utf-8" }),
		ts.ScriptTarget.ES2022
	);
	const topLevelNodes = source.statements;

	// quicktype already emitted the alias itself: leave the file untouched.
	const hasTypeAlias = topLevelNodes.some(
		(node) => ts.isTypeAliasDeclaration(node) && node.name.text === exportedName
	);
	if (hasTypeAlias) {
		return;
	}

	const interfaceDeclaration = topLevelNodes.find(
		(node): node is ts.InterfaceDeclaration => ts.isInterfaceDeclaration(node) && node.name.text === exportedName
	);
	if (!interfaceDeclaration) {
		console.log(" Nothing to do");
		return;
	}

	console.log(" Inserting top-level array type alias...");
	// Same interface, renamed from `<OutputType>` to `<OutputType>Element`.
	const updatedInterface = ts.factory.updateInterfaceDeclaration(
		interfaceDeclaration,
		interfaceDeclaration.modifiers,
		ts.factory.createIdentifier(`${exportedName}Element`),
		interfaceDeclaration.typeParameters,
		interfaceDeclaration.heritageClauses,
		interfaceDeclaration.members
	);
	// `export type <OutputType> = <OutputType>Element[];`
	const arrayDeclaration = ts.factory.createTypeAliasDeclaration(
		[ts.factory.createModifier(ts.SyntaxKind.ExportKeyword)],
		exportedName,
		undefined,
		ts.factory.createArrayTypeNode(ts.factory.createTypeReferenceNode(updatedInterface.name))
	);
	// Re-emit the file: the original interface is dropped from its position, and the alias followed by
	// the renamed interface is appended after the remaining top-level statements.
	const printer = ts.createPrinter();
	const newNodes = ts.factory.createNodeArray([
		...topLevelNodes.filter((node) => node !== interfaceDeclaration),
		arrayDeclaration,
		updatedInterface,
	]);
	await fs.writeFile(path2generated, printer.printList(ts.ListFormat.MultiLine, newNodes, source), {
		flag: "w+",
		encoding: "utf-8",
	});
}
const rootDir = rootDirFinder();
const tasksDir = path.join(rootDir, "src", "tasks");
// Every directory under src/tasks, except "placeholder", is treated as a task.
// Task paths are built from `tasksDir` rather than `Dirent.path`: that property is deprecated
// (DEP0178, superseded by `Dirent.parentPath`) and absent on older Node releases.
const allTasks = (await fs.readdir(tasksDir, { withFileTypes: true }))
	.filter((entry) => entry.isDirectory() && entry.name !== "placeholder")
	.map((entry) => ({ task: entry.name, dirPath: path.join(tasksDir, entry.name) }));
// Shared definitions plus every existing input/output spec, handed to quicktype so cross-file `$ref`s resolve.
const allSpecFiles = [
	path.join(tasksDir, "common-definitions.json"),
	...allTasks
		.flatMap(({ dirPath }) => [path.join(dirPath, "spec", "input.json"), path.join(dirPath, "spec", "output.json")])
		.filter((filepath) => pathExists(filepath)),
];
// The Python output directory is the same for every task: create it once instead of on each iteration.
await fs.mkdir(PYTHON_DIR, { recursive: true });
for (const { task, dirPath } of allTasks) {
	const taskSpecDir = path.join(dirPath, "spec");
	const inputSpecPath = path.join(taskSpecDir, "input.json");
	const outputSpecPath = path.join(taskSpecDir, "output.json");
	if (!(pathExists(inputSpecPath) && pathExists(outputSpecPath))) {
		console.debug(`No spec found for task ${task} - skipping`);
		continue;
	}
	console.debug(`✨ Generating types for task`, task);
	console.debug(" 📦 Building input data");
	const inputData = await buildInputData(task, taskSpecDir, allSpecFiles);
	console.debug(" 🏭 Generating typescript code");
	const inferenceTsPath = path.join(dirPath, "inference.ts");
	{
		const { lines } = await generateTypescript(inputData);
		await fs.writeFile(inferenceTsPath, [TYPESCRIPT_HEADER_FILE, ...lines].join(`\n`), {
			flag: "w+",
			encoding: "utf-8",
		});
	}
	const outputSpec = JSON.parse(await fs.readFile(outputSpecPath, { encoding: "utf-8" }));
	console.log(" 🩹 Post-processing the generated code");
	await postProcessOutput(inferenceTsPath, outputSpec);
	console.debug(" 🏭 Generating Python code");
	{
		const { lines } = await generatePython(inputData);
		// Python module names cannot contain dashes.
		const pythonFilename = `${task}.py`.replace(/-/g, "_");
		const pythonPath = path.join(PYTHON_DIR, pythonFilename);
		await fs.writeFile(pythonPath, [PYTHON_HEADER_FILE, ...lines].join(`\n`), {
			flag: "w+",
			encoding: "utf-8",
		});
	}
}
console.debug("✅ All done!");
|