File size: 5,339 Bytes
f8d016c
29e20e0
f8d016c
29e20e0
 
 
 
 
 
 
 
 
fd2180b
945bdba
29e20e0
 
 
 
945bdba
 
 
 
 
 
 
 
 
29e20e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317d5b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29e20e0
 
317d5b5
29e20e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd2180b
 
945bdba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd2180b
 
945bdba
fd2180b
 
945bdba
fd2180b
 
 
 
945bdba
 
 
 
fd2180b
 
29e20e0
945bdba
 
 
29e20e0
 
 
 
 
 
 
 
 
f8d016c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import { chromium, firefox, webkit, Browser } from "playwright";
import { createServer } from "vite";
import { getArg } from "../core/args.js";

// CLI for running browser benchmarks headlessly via Playwright

// CLI configuration, read once at module load.
// Positional args: argv[2] = model id, argv[3] = task; everything else comes from getArg.
const modelId = process.argv[2] || "Xenova/distilbert-base-uncased";
const task = process.argv[3] || "feature-extraction";

/**
 * Parses a positive-integer CLI value.
 *
 * @param raw - Raw flag value (may be missing/empty).
 * @param fallback - Value used when `raw` is missing or not numeric. Without this,
 *   `parseInt` yields `NaN` and `Math.max(1, NaN)` is still `NaN`.
 * @returns The parsed integer, clamped to a minimum of 1, or `fallback`.
 */
function parsePositiveInt(raw: string | null | undefined, fallback: number): number {
  const parsed = parseInt(raw ?? "", 10);
  return Number.isNaN(parsed) ? fallback : Math.max(1, parsed);
}

// NOTE: the string-union casts below are not validated at runtime; unexpected values
// pass through unchanged.
const mode = getArg("mode", "warm") as "warm" | "cold";
const repeats = parsePositiveInt(getArg("repeats", "3"), 3);
const device = getArg("device", "webgpu") as "webgpu" | "wasm";
const dtype = getArg("dtype"); // optional: fp32, fp16, q8, q4, etc.
const batchSize = parsePositiveInt(getArg("batch-size", "1"), 1);
const browserType = getArg("browser", "chromium") as "chromium" | "firefox" | "webkit";
const headed = getArg("headed") === "true";

/**
 * Chromium command-line flags for the requested device / headless combination.
 *
 * @param targetDevice - `"wasm"` turns GPU acceleration off; any other value takes the
 *   WebGPU path (WebGPU + Vulkan enabled, plus extra flags when headless).
 * @param isHeaded - When false, headless-specific WebGPU flags are appended.
 */
function chromiumArgs(targetDevice: string, isHeaded: boolean): string[] {
  if (targetDevice === "wasm") {
    return ["--disable-gpu", "--disable-software-rasterizer"];
  }

  const flags = [
    // Official WebGPU flags from Chrome team
    // https://developer.chrome.com/blog/supercharge-web-ai-testing#enable-webgpu
    "--enable-unsafe-webgpu",
    "--enable-features=Vulkan",
  ];

  // Add headless-specific flags only in headless mode
  if (!isHeaded) {
    flags.push(
      "--no-sandbox",
      "--headless=new",
      "--use-angle=vulkan",
      "--disable-vulkan-surface"
    );
  }
  return flags;
}

/**
 * Launches the requested Playwright browser.
 *
 * The flags from `chromiumArgs` are Chromium command-line switches, so they are only
 * passed to Chromium; Firefox and WebKit get Playwright defaults apart from `headless`.
 * Unrecognised browser names fall back to Chromium.
 */
async function launchBrowser(kind: string, isHeaded: boolean, targetDevice: string): Promise<Browser> {
  const headless = !isHeaded;
  switch (kind) {
    case "firefox":
      return firefox.launch({ headless });
    case "webkit":
      return webkit.launch({ headless });
    case "chromium":
    default:
      return chromium.launch({ headless, args: chromiumArgs(targetDevice, isHeaded) });
  }
}

/**
 * Probes `navigator.gpu` inside the page and logs the adapter found.
 *
 * @throws Error("WebGPU not available") when no adapter can be obtained.
 */
async function ensureWebGPU(page: Awaited<ReturnType<Browser["newPage"]>>): Promise<void> {
  // This callback is serialised into the page, so it must not reference Node-side names.
  const gpuInfo = await page.evaluate(async () => {
    if (!('gpu' in navigator)) {
      return { available: false, adapterInfo: "", features: [] as string[], error: "" };
    }
    try {
      const adapter = await (navigator as any).gpu.requestAdapter();
      if (!adapter) {
        return { available: false, adapterInfo: "", features: [] as string[], error: "" };
      }
      const info = adapter.info;
      // `description` is frequently an empty string, so fall back to vendor/architecture.
      const adapterInfo = info
        ? info.description || [info.vendor, info.architecture].filter(Boolean).join(" ") || 'Unknown'
        : 'Unknown';
      const features = Array.from(adapter.features || []) as string[];
      return { available: true, adapterInfo, features, error: "" };
    } catch (e) {
      return { available: false, adapterInfo: "", features: [] as string[], error: String(e) };
    }
  });

  if (!gpuInfo.available) {
    console.error("\n❌ WebGPU is not available in this browser!");
    console.error("Make sure to use --enable-unsafe-webgpu flag for Chromium.");
    if (gpuInfo.error) console.error("Error:", gpuInfo.error);
    throw new Error("WebGPU not available");
  }

  console.log("✓ WebGPU is available");
  console.log(`  Adapter: ${gpuInfo.adapterInfo}`);
  if (gpuInfo.features.length > 0) {
    const shown = gpuInfo.features.slice(0, 3).join(', ');
    console.log(`  Features: ${shown}${gpuInfo.features.length > 3 ? '...' : ''}`);
  }
}

/**
 * Runs one browser benchmark end to end.
 *
 * Starts a Vite dev server, opens the app in a Playwright browser and waits for `#run`.
 * For `device === "webgpu"` it checks WebGPU availability first. It then calls the page's
 * `window.runBenchmarkCLI` with the CLI options and prints the JSON result.
 * The browser and the dev server are always shut down, even on failure.
 */
async function main(): Promise<void> {
  console.log(`Model      : ${modelId}`);
  console.log(`Task       : ${task}`);
  console.log(`Mode       : ${mode}`);
  console.log(`Repeats    : ${repeats}`);
  console.log(`Device     : ${device}`);
  console.log(`DType      : ${dtype || 'auto'}`);
  console.log(`Batch Size : ${batchSize}`);
  console.log(`Browser    : ${browserType}`);
  console.log(`Headed     : ${headed}`);

  // Start Vite dev server. strictPort is false, so Vite may bind a different port
  // when 5173 is already taken.
  const server = await createServer({
    server: {
      port: 5173,
      strictPort: false,
    },
    logLevel: "error",
  });

  try {
    await server.listen();

    // Use the port actually bound; server.config.server.port only holds the requested one.
    const address = server.httpServer?.address();
    const port =
      address && typeof address === "object" ? address.port : server.config.server.port || 5173;
    const url = `http://localhost:${port}`;

    console.log(`Vite server started at ${url}`);

    const browser = await launchBrowser(browserType, headed, device);
    try {
      const context = await browser.newContext();
      const page = await context.newPage();

      // Forward browser-side errors and warnings to the terminal
      page.on("console", (msg) => {
        const type = msg.type();
        if (type === "error" || type === "warning") {
          console.log(`[browser ${type}]`, msg.text());
        }
      });

      await page.goto(url);

      // The app is ready once its run button is rendered
      await page.waitForSelector("#run");

      console.log("\nStarting benchmark...");

      if (device === "webgpu") {
        await ensureWebGPU(page);
      }

      // Use the exposed CLI function from main.ts
      const result = await page.evaluate(({ modelId, task, mode, repeats, device, dtype, batchSize }) => {
        return (window as any).runBenchmarkCLI({ modelId, task, mode, repeats, device, dtype, batchSize });
      }, { modelId, task, mode, repeats, device, dtype, batchSize });

      console.log("\n" + JSON.stringify(result, null, 2));
    } finally {
      await browser.close();
    }
  } finally {
    // Runs even if listen() or the browser launch fails, so the dev server never leaks.
    await server.close();
  }
}

/**
 * Returns true when `scriptPath` (normally `process.argv[1]`) refers to this CLI.
 * Backslashes are normalised to forward slashes first so Windows paths such as
 * `C:\repo\web\cli.ts` match too.
 */
function isCliEntryPoint(scriptPath: string | undefined): boolean {
  return scriptPath?.replace(/\\/g, '/').includes('web/cli') ?? false;
}

// Check if this module is being run directly (not imported)
const isMainModule = isCliEntryPoint(process.argv[1]);

if (isMainModule) {
  main().catch((e) => {
    console.error(e);
    process.exit(1);
  });
}

export { main };