Spaces:
Runtime error
Runtime error
File size: 5,339 Bytes
f8d016c 29e20e0 f8d016c 29e20e0 fd2180b 945bdba 29e20e0 945bdba 29e20e0 317d5b5 29e20e0 317d5b5 29e20e0 fd2180b 945bdba fd2180b 945bdba fd2180b 945bdba fd2180b 945bdba fd2180b 29e20e0 945bdba 29e20e0 f8d016c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import { chromium, firefox, webkit, Browser } from "playwright";
import { createServer } from "vite";
import { getArg } from "../core/args.js";
// CLI for running browser benchmarks headlessly via Playwright.
//
// Positional arguments: argv[2] = model id, argv[3] = task. Named options are
// read through getArg from ../core/args (presumably `--name value` or
// `--name=value` — confirm against that module).
//
// NOTE: the `as` casts below are type assertions only; they do not validate the
// value. An unrecognised --browser falls back to Chromium in main(), and any
// --device other than "wasm" is treated like "webgpu" when choosing launch flags.

/**
 * Parse `raw` as a base-10 integer and clamp it to at least 1.
 *
 * A missing, empty, or non-numeric value yields `fallback`. This prevents a typo
 * such as `--repeats=abc` from producing `NaN`: `Math.max(1, NaN)` is `NaN`, not 1.
 *
 * @param raw - Raw option text, or null/undefined when the option is absent.
 * @param fallback - Value used when `raw` does not start with an integer.
 * @returns An integer >= 1 (assuming `fallback` >= 1).
 */
function parsePositiveInt(raw: string | null | undefined, fallback: number): number {
  const parsed = Number.parseInt(raw ?? "", 10);
  return Math.max(1, Number.isNaN(parsed) ? fallback : parsed);
}

const modelId = process.argv[2] || "Xenova/distilbert-base-uncased";
const task = process.argv[3] || "feature-extraction";
const mode = getArg("mode", "warm") as "warm" | "cold";
const repeats = parsePositiveInt(getArg("repeats", "3"), 3);
const device = getArg("device", "webgpu") as "webgpu" | "wasm";
const dtype = getArg("dtype"); // optional: fp32, fp16, q8, q4, etc.
const batchSize = parsePositiveInt(getArg("batch-size", "1"), 1);
const browserType = getArg("browser", "chromium") as "chromium" | "firefox" | "webkit";
// Only the exact string "true" enables headed mode.
const headed = getArg("headed") === "true";
/**
 * Chromium command-line switches for the selected device / headless combination.
 *
 * These are Chromium switches, so they are only ever passed to the Chromium engine.
 */
function buildChromiumArgs(): string[] {
  const args =
    device === "wasm"
      ? ["--disable-gpu", "--disable-software-rasterizer"]
      : [
          // Official WebGPU flags from Chrome team
          // https://developer.chrome.com/blog/supercharge-web-ai-testing#enable-webgpu
          "--enable-unsafe-webgpu",
          "--enable-features=Vulkan",
        ];
  // Add headless-specific flags only in headless WebGPU mode
  if (!headed && device !== "wasm") {
    args.push(
      "--no-sandbox",
      "--headless=new",
      "--use-angle=vulkan",
      "--disable-vulkan-surface"
    );
  }
  return args;
}

/**
 * Launch the Playwright engine selected by `--browser`.
 *
 * Any value other than "firefox" or "webkit" launches Chromium (the `default`
 * branch). Chromium switches are not forwarded to Firefox or WebKit, which do
 * not understand them.
 */
async function launchBrowser(): Promise<Browser> {
  switch (browserType) {
    case "firefox":
      return firefox.launch({ headless: !headed });
    case "webkit":
      return webkit.launch({ headless: !headed });
    case "chromium":
    default:
      return chromium.launch({ headless: !headed, args: buildChromiumArgs() });
  }
}

/**
 * Run one benchmark end to end.
 *
 * The steps are:
 * 1. Start a Vite dev server.
 * 2. Open the app in a Playwright browser and wait for `#run`.
 * 3. Optionally verify that WebGPU is usable.
 * 4. Call the page's `window.runBenchmarkCLI`, then print its JSON result.
 *
 * The browser and the Vite server are always shut down, even when a step throws.
 *
 * @throws Error("WebGPU not available") when `--device webgpu` is requested but
 *   the browser exposes no WebGPU adapter; otherwise rethrows launch,
 *   navigation, or benchmark errors.
 */
async function main(): Promise<void> {
  console.log(`Model : ${modelId}`);
  console.log(`Task : ${task}`);
  console.log(`Mode : ${mode}`);
  console.log(`Repeats : ${repeats}`);
  console.log(`Device : ${device}`);
  console.log(`DType : ${dtype || 'auto'}`);
  console.log(`Batch Size : ${batchSize}`);
  console.log(`Browser : ${browserType}`);
  console.log(`Headed : ${headed}`);

  // Start Vite dev server
  const server = await createServer({
    server: {
      port: 5173,
      strictPort: false,
    },
    logLevel: "error",
  });

  // Outer try/finally: the dev server is closed even if listen(), the
  // browser launch, or browser.close() throws.
  try {
    await server.listen();

    // strictPort is false, so Vite silently moves to another port when 5173 is
    // busy. config.server.port still holds the *requested* port, so ask the
    // HTTP server which port it actually bound to.
    const address = server.httpServer?.address();
    const port =
      typeof address === "object" && address !== null
        ? address.port
        : server.config.server.port || 5173;
    const url = `http://localhost:${port}`;
    console.log(`Vite server started at ${url}`);

    const browser = await launchBrowser();
    try {
      const context = await browser.newContext();
      const page = await context.newPage();

      // Forward only browser errors and warnings to the terminal
      page.on("console", (msg) => {
        const type = msg.type();
        if (type === "error" || type === "warning") {
          console.log(`[browser ${type}]`, msg.text());
        }
      });

      // Navigate to the app and wait until the page has rendered its run control
      await page.goto(url);
      await page.waitForSelector("#run");
      console.log("\nStarting benchmark...");

      // Check WebGPU availability if using webgpu device
      if (device === "webgpu") {
        // This callback is serialized and executed inside the browser page.
        const gpuInfo = await page.evaluate(async () => {
          if (!('gpu' in navigator)) {
            return { available: false, adapter: null, features: null };
          }
          try {
            // eslint-disable-next-line @typescript-eslint/no-explicit-any -- WebGPU types are not in the default DOM lib
            const adapter = await (navigator as any).gpu.requestAdapter();
            if (!adapter) {
              return { available: false, adapter: null, features: null };
            }
            const features = Array.from(adapter.features || []);
            const limits = adapter.limits ? {
              maxTextureDimension2D: adapter.limits.maxTextureDimension2D,
              maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX,
            } : null;
            // GPUAdapterInfo.description is allowed to be an empty string (and
            // often is), so fall back to vendor/architecture before 'Unknown'.
            const info = adapter.info;
            const adapterInfo = info
              ? info.description ||
                [info.vendor, info.architecture].filter(Boolean).join(" ") ||
                'Unknown'
              : 'Unknown';
            return {
              available: true,
              adapterInfo,
              features,
              limits
            };
          } catch (e) {
            return { available: false, adapter: null, error: String(e) };
          }
        });

        if (!gpuInfo.available) {
          console.error("\n❌ WebGPU is not available in this browser!");
          console.error("Make sure to use --enable-unsafe-webgpu flag for Chromium.");
          if (gpuInfo.error) console.error("Error:", gpuInfo.error);
          throw new Error("WebGPU not available");
        }

        console.log("✓ WebGPU is available");
        console.log(` Adapter: ${gpuInfo.adapterInfo}`);
        if (gpuInfo.features && gpuInfo.features.length > 0) {
          console.log(` Features: ${gpuInfo.features.slice(0, 3).join(', ')}${gpuInfo.features.length > 3 ? '...' : ''}`);
        }
      }

      // Use the exposed CLI function from main.ts
      const result = await page.evaluate(({ modelId, task, mode, repeats, device, dtype, batchSize }) => {
        return (window as any).runBenchmarkCLI({ modelId, task, mode, repeats, device, dtype, batchSize });
      }, { modelId, task, mode, repeats, device, dtype, batchSize });

      console.log("\n" + JSON.stringify(result, null, 2));
    } finally {
      await browser.close();
    }
  } finally {
    await server.close();
  }
}
/**
 * Whether `scriptPath` (normally `process.argv[1]`) refers to this CLI entry file.
 *
 * Backslashes are normalised to `/` first, so Windows paths such as
 * `C:\repo\web\cli.ts` are recognised as well as POSIX paths.
 *
 * @param scriptPath - Path of the script Node was asked to run; may be undefined.
 * @returns true when the normalised path contains `web/cli`.
 */
function isCliEntryPoint(scriptPath: string | undefined): boolean {
  return scriptPath?.replace(/\\/g, "/").includes("web/cli") ?? false;
}

// Check if this module is being run directly (not imported)
const isMainModule = isCliEntryPoint(process.argv[1]);
if (isMainModule) {
  main().catch((e: unknown) => {
    console.error(e);
    process.exit(1);
  });
}

export { main };
|