// Spaces: Running on CPU Upgrade
import type { SerializedRenderResult } from "quicktype-core";
import { quicktype, InputData, JSONSchemaInput, FetchingJSONSchemaStore } from "quicktype-core";
import * as fs from "node:fs/promises";
import { existsSync as pathExists } from "node:fs";
import * as path from "node:path/posix";
import { fileURLToPath } from "node:url";
import ts from "typescript";
// Banner prepended to every generated `inference.ts` file.
const TYPESCRIPT_HEADER_FILE = `
/**
* Inference code generated from the JSON schema spec in ./spec
*
* Using src/scripts/inference-codegen
*/
`;
// Banner prepended to every generated Python module.
const PYTHON_HEADER_FILE = `
# Inference code generated from the JSON schema spec in @huggingface/tasks.
#
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
`;
// Directory that receives the generated Python files; it is relative, so it resolves against the current working directory.
const PYTHON_DIR = "./.python_generated";
/**
 * Locates the root directory of the package this script belongs to.
 *
 * Starts from this module's own location and walks up the directory tree until
 * a path containing a `package.json` is found.
 *
 * @returns The first path (walking upwards) that contains a `package.json`,
 * or `"/"` when no such path exists.
 */
const rootDirFinder = function (): string {
	// `import.meta.url` is a `file://` URL, not a filesystem path: convert it
	// before handing it to the `path` / `fs` helpers.
	let currentPath = path.normalize(fileURLToPath(import.meta.url));
	for (;;) {
		if (pathExists(path.join(currentPath, "package.json"))) {
			return currentPath;
		}
		const parentPath = path.dirname(currentPath);
		// `dirname` stops changing once the top of the tree is reached, which
		// guarantees termination even when no `package.json` exists anywhere.
		if (parentPath === currentPath) {
			return "/";
		}
		currentPath = parentPath;
	}
};
/**
 * Builds the quicktype input for one task from its JSON schema spec files.
 *
 * `input.json` and `output.json` are always registered (as `<taskId>-input` and
 * `<taskId>-output`). For the `text-generation` and `chat-completion` tasks,
 * `stream_output.json` is registered as well (as `<taskId>-stream-output`).
 *
 * @param taskId The ID of the task for which we are generating code
 * @param taskSpecDir The path to the directory where the input.json & output.json files are
 * @param allSpecFiles An array of paths to all the tasks specs. Allows resolving cross-file references ($ref).
 * @returns An `InputData` holding the schema sources of the task
 */
async function buildInputData(taskId: string, taskSpecDir: string, allSpecFiles: string[]): Promise<InputData> {
	const schema = new JSONSchemaInput(new FetchingJSONSchemaStore(), [], allSpecFiles);

	const specSources = [
		{ suffix: "input", fileName: "input.json" },
		{ suffix: "output", fileName: "output.json" },
	];
	if (taskId === "text-generation" || taskId === "chat-completion") {
		specSources.push({ suffix: "stream-output", fileName: "stream_output.json" });
	}

	// Sources are registered one at a time so they are always added in the same order.
	for (const { suffix, fileName } of specSources) {
		await schema.addSource({
			name: `${taskId}-${suffix}`,
			schema: await fs.readFile(`${taskSpecDir}/${fileName}`, { encoding: "utf-8" }),
		});
	}

	const inputData = new InputData();
	inputData.addInput(schema);
	return inputData;
}
/**
 * Runs quicktype on the given input data to produce TypeScript type declarations.
 *
 * Only types are requested (`just-types`, no runtime type-checking code), the
 * schema property names are left untouched, and the output is tab-indented.
 *
 * @param inputData The schema sources to generate types from
 * @returns The quicktype render result
 */
async function generateTypescript(inputData: InputData): Promise<SerializedRenderResult> {
	return await quicktype({
		inputData,
		lang: "typescript",
		alphabetizeProperties: true,
		indentation: "\t",
		rendererOptions: {
			"just-types": true,
			"nice-property-names": false,
			"prefer-unions": true,
			"prefer-const-values": true,
			"prefer-unknown": true,
			"explicit-unions": true,
			"runtime-typecheck": false,
		},
	});
}
/**
 * Runs quicktype on the given input data to produce Python type definitions.
 *
 * Only types are requested (`just-types`), property names are converted by
 * quicktype (`nice-property-names`), and the `python-version` option is set to 3.7.
 *
 * @param inputData The schema sources to generate types from
 * @returns The quicktype render result
 */
async function generatePython(inputData: InputData): Promise<SerializedRenderResult> {
	return await quicktype({
		inputData,
		lang: "python",
		alphabetizeProperties: true,
		rendererOptions: {
			"just-types": true,
			"nice-property-names": true,
			"python-version": "3.7",
		},
	});
}
/**
 * quicktype is unable to generate "top-level array types" that are defined in the output spec: https://github.com/glideapps/quicktype/issues/2481
 * We have to use the TypeScript API to generate those types when required.
 * This hacky function:
 * - looks for the generated interface for output types
 * - renames it with an `Element` suffix
 * - generates a type alias in the form `export type <OutputType> = <OutputType>Element[];`
 *
 * And writes that back to the generated file.
 *
 * Nothing is read or written when the output spec is not an array with a string
 * `title`. The file is left untouched when it already declares the type alias,
 * or when no interface named after the title is found.
 *
 * @param path2generated Path to the generated TypeScript file, rewritten in place
 * @param outputSpec The parsed content of the task's output JSON schema
 */
async function postProcessOutput(path2generated: string, outputSpec: Record<string, unknown>): Promise<void> {
	const exportedName = outputSpec.title;
	// Check the spec first: most tasks need no post-processing, so there is no
	// point in reading and parsing the generated file for them.
	if (outputSpec.type !== "array" || typeof exportedName !== "string") {
		console.log(" Nothing to do");
		return;
	}

	const source = ts.createSourceFile(
		path.basename(path2generated),
		await fs.readFile(path2generated, { encoding: "utf-8" }),
		ts.ScriptTarget.ES2022
	);
	const topLevelNodes = source.getChildAt(0).getChildren();

	// The alias is already there (e.g. the file was post-processed before): leave it alone.
	const hasTypeAlias = topLevelNodes.some(
		(node) => ts.isTypeAliasDeclaration(node) && node.name.escapedText === exportedName
	);
	if (hasTypeAlias) {
		return;
	}

	const interfaceDeclaration = topLevelNodes.find(
		(node): node is ts.InterfaceDeclaration =>
			ts.isInterfaceDeclaration(node) && node.name.getText(source) === exportedName
	);
	if (!interfaceDeclaration) {
		console.log(" Nothing to do");
		return;
	}

	console.log(" Inserting top-level array type alias...");
	// Same interface, renamed to `<OutputType>Element`
	const updatedInterface = ts.factory.updateInterfaceDeclaration(
		interfaceDeclaration,
		interfaceDeclaration.modifiers,
		ts.factory.createIdentifier(interfaceDeclaration.name.getText(source) + "Element"),
		interfaceDeclaration.typeParameters,
		interfaceDeclaration.heritageClauses,
		interfaceDeclaration.members
	);
	// `export type <OutputType> = <OutputType>Element[];`
	const arrayDeclaration = ts.factory.createTypeAliasDeclaration(
		[ts.factory.createModifier(ts.SyntaxKind.ExportKeyword)],
		exportedName,
		undefined,
		ts.factory.createArrayTypeNode(ts.factory.createTypeReferenceNode(updatedInterface.name))
	);

	const printer = ts.createPrinter();
	// Every other top-level node is kept in its original order; the alias and the
	// renamed interface are appended at the end of the file.
	const newNodes = ts.factory.createNodeArray([
		...topLevelNodes.filter((node) => node !== interfaceDeclaration),
		arrayDeclaration,
		updatedInterface,
	]);
	await fs.writeFile(path2generated, printer.printList(ts.ListFormat.MultiLine, newNodes, source), {
		flag: "w+",
		encoding: "utf-8",
	});
}
// ---- Script entry point ----
// For every task directory under `<package root>/src/tasks` that ships a spec,
// generate `inference.ts` next to the spec and a Python module in PYTHON_DIR.
const rootDir = rootDirFinder();
const tasksDir = path.join(rootDir, "src", "tasks");
// `readdir` is not recursive here, so every entry lives directly under
// `tasksDir`: build the path from it rather than from the deprecated `Dirent.path`.
const allTasks = (await fs.readdir(tasksDir, { withFileTypes: true }))
	.filter((entry) => entry.isDirectory() && entry.name !== "placeholder")
	.map((entry) => ({ task: entry.name, dirPath: path.join(tasksDir, entry.name) }));
// All the spec files that exist, passed to quicktype to resolve cross-file `$ref`s.
const allSpecFiles = [
	path.join(tasksDir, "common-definitions.json"),
	...allTasks
		.flatMap(({ dirPath }) => [path.join(dirPath, "spec", "input.json"), path.join(dirPath, "spec", "output.json")])
		.filter((filepath) => pathExists(filepath)),
];

for (const { task, dirPath } of allTasks) {
	const taskSpecDir = path.join(dirPath, "spec");
	// Both the input and the output specs are required to generate code for a task
	if (!(pathExists(path.join(taskSpecDir, "input.json")) && pathExists(path.join(taskSpecDir, "output.json")))) {
		console.debug(`No spec found for task ${task} - skipping`);
		continue;
	}
	console.debug(`✨ Generating types for task`, task);

	console.debug(" 📦 Building input data");
	const inputData = await buildInputData(task, taskSpecDir, allSpecFiles);

	console.debug(" 🏭 Generating typescript code");
	{
		const { lines } = await generateTypescript(inputData);
		await fs.writeFile(`${dirPath}/inference.ts`, [TYPESCRIPT_HEADER_FILE, ...lines].join(`\n`), {
			flag: "w+",
			encoding: "utf-8",
		});
	}

	const outputSpec: Record<string, unknown> = JSON.parse(
		await fs.readFile(`${taskSpecDir}/output.json`, { encoding: "utf-8" })
	);
	console.log(" 🩹 Post-processing the generated code");
	await postProcessOutput(`${dirPath}/inference.ts`, outputSpec);

	console.debug(" 🏭 Generating Python code");
	{
		const { lines } = await generatePython(inputData);
		// Python module names cannot contain dashes: `text-generation` -> `text_generation.py`
		const pythonFilename = `${task}.py`.replace(/-/g, "_");
		const pythonPath = `${PYTHON_DIR}/${pythonFilename}`;
		await fs.mkdir(PYTHON_DIR, { recursive: true });
		await fs.writeFile(pythonPath, [PYTHON_HEADER_FILE, ...lines].join(`\n`), {
			flag: "w+",
			encoding: "utf-8",
		});
	}
}
console.debug("✅ All done!");