multimodalart HF Staff commited on
Commit
f555806
·
verified ·
1 Parent(s): 9e57aab

Upload 121 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ui/README.md +55 -0
  2. ui/cron/worker.ts +31 -0
  3. ui/next-env.d.ts +5 -0
  4. ui/next.config.ts +15 -0
  5. ui/package-lock.json +0 -0
  6. ui/package.json +51 -0
  7. ui/postcss.config.mjs +8 -0
  8. ui/prisma/schema.prisma +38 -0
  9. ui/public/file.svg +1 -0
  10. ui/public/globe.svg +1 -0
  11. ui/public/next.svg +1 -0
  12. ui/public/ostris_logo.png +0 -0
  13. ui/public/vercel.svg +1 -0
  14. ui/public/web-app-manifest-192x192.png +0 -0
  15. ui/public/web-app-manifest-512x512.png +0 -0
  16. ui/public/window.svg +1 -0
  17. ui/src/.DS_Store +0 -0
  18. ui/src/app/.DS_Store +0 -0
  19. ui/src/app/api/.DS_Store +0 -0
  20. ui/src/app/api/auth/hf/callback/route.ts +112 -0
  21. ui/src/app/api/auth/hf/login/route.ts +36 -0
  22. ui/src/app/api/auth/hf/validate/route.ts +22 -0
  23. ui/src/app/api/auth/route.ts +6 -0
  24. ui/src/app/api/caption/get/route.ts +46 -0
  25. ui/src/app/api/datasets/create/route.tsx +25 -0
  26. ui/src/app/api/datasets/delete/route.tsx +24 -0
  27. ui/src/app/api/datasets/list/route.ts +25 -0
  28. ui/src/app/api/datasets/listImages/route.ts +61 -0
  29. ui/src/app/api/datasets/upload/route.ts +57 -0
  30. ui/src/app/api/files/[...filePath]/route.ts +116 -0
  31. ui/src/app/api/gpu/route.ts +121 -0
  32. ui/src/app/api/hf-hub/route.ts +165 -0
  33. ui/src/app/api/hf-jobs/route.ts +761 -0
  34. ui/src/app/api/img/[...imagePath]/route.ts +78 -0
  35. ui/src/app/api/img/caption/route.ts +29 -0
  36. ui/src/app/api/img/delete/route.ts +34 -0
  37. ui/src/app/api/img/upload/route.ts +58 -0
  38. ui/src/app/api/jobs/[jobID]/delete/route.ts +32 -0
  39. ui/src/app/api/jobs/[jobID]/files/route.ts +48 -0
  40. ui/src/app/api/jobs/[jobID]/log/route.ts +35 -0
  41. ui/src/app/api/jobs/[jobID]/samples/route.ts +40 -0
  42. ui/src/app/api/jobs/[jobID]/start/route.ts +215 -0
  43. ui/src/app/api/jobs/[jobID]/stop/route.ts +23 -0
  44. ui/src/app/api/jobs/route.ts +67 -0
  45. ui/src/app/api/settings/route.ts +59 -0
  46. ui/src/app/api/zip/route.ts +78 -0
  47. ui/src/app/apple-icon.png +0 -0
  48. ui/src/app/dashboard/page.tsx +54 -0
  49. ui/src/app/datasets/[datasetName]/page.tsx +190 -0
  50. ui/src/app/datasets/page.tsx +217 -0
ui/README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
2
+
3
+ ## Getting Started
4
+
5
+ First, run the development server:
6
+
7
+ ```bash
8
+ npm run dev
9
+ # or
10
+ yarn dev
11
+ # or
12
+ pnpm dev
13
+ # or
14
+ bun dev
15
+ ```
16
+
17
+ Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
18
+
19
+ You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
20
+
21
+ This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
22
+
23
+ ## Database Modes
24
+
25
+ Use the `NEXT_PUBLIC_DB_MODE` environment variable to control how UI data is persisted:
26
+
27
+ - `server` (default): interacts with the shared SQLite database through Prisma. Supports local job orchestration.
28
+ - `browser`: stores jobs and settings in the user's browser (localStorage). This mode only supports Hugging Face Jobs workflows; local GPU training controls are disabled and the GPU monitor shows a cloud-mode status banner.
29
+
30
+ When running in browser mode every visitor sees only their own jobs, settings, and dataset catalog (all stored in their browser), making the UI safe to host for multiple users without sharing the SQLite file.
31
+
32
+ ## Hugging Face Authentication
33
+
34
+ Users can authenticate either by pasting a personal access token or via the Hugging Face OAuth flow. To enable OAuth set the following environment variables for the UI:
35
+
36
+ - `HF_OAUTH_CLIENT_ID` – the application client ID
37
+ - `HF_OAUTH_CLIENT_SECRET` – the application secret (server-side only)
38
+ - `NEXT_PUBLIC_HF_OAUTH_CLIENT_ID` – the client ID exposed to the browser (usually the same as `HF_OAUTH_CLIENT_ID`)
39
+
40
+ If these values are not provided the UI falls back to manual token entry. In multi-user/browser mode the authenticated token and namespace are stored per browser session.
41
+
42
+ ## Learn More
43
+
44
+ To learn more about Next.js, take a look at the following resources:
45
+
46
+ - [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
47
+ - [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
48
+
49
+ You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
50
+
51
+ ## Deploy on Vercel
52
+
53
+ The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
54
+
55
+ Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
ui/cron/worker.ts ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class CronWorker {
2
+ interval: number;
3
+ is_running: boolean;
4
+ intervalId: NodeJS.Timeout;
5
+ constructor() {
6
+ this.interval = 1000; // Default interval of 1 second
7
+ this.is_running = false;
8
+ this.intervalId = setInterval(() => {
9
+ this.run();
10
+ }, this.interval);
11
+ }
12
+ async run() {
13
+ if (this.is_running) {
14
+ return;
15
+ }
16
+ this.is_running = true;
17
+ try {
18
+ // Loop logic here
19
+ await this.loop();
20
+ } catch (error) {
21
+ console.error('Error in cron worker loop:', error);
22
+ }
23
+ this.is_running = false;
24
+ }
25
+
26
+ async loop() {}
27
+ }
28
+
29
+ // it automatically starts the loop
30
+ const cronWorker = new CronWorker();
31
+ console.log('Cron worker started with interval:', cronWorker.interval, 'ms');
ui/next-env.d.ts ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ /// <reference types="next" />
2
+ /// <reference types="next/image-types/global" />
3
+
4
+ // NOTE: This file should not be edited
5
+ // see https://nextjs.org/docs/app/api-reference/config/typescript for more information.
ui/next.config.ts ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { NextConfig } from 'next';
2
+
3
+ const nextConfig: NextConfig = {
4
+ typescript: {
5
+ // Remove this. Build fails because of route types
6
+ ignoreBuildErrors: true,
7
+ },
8
+ experimental: {
9
+ serverActions: {
10
+ bodySizeLimit: '100mb',
11
+ },
12
+ },
13
+ };
14
+
15
+ export default nextConfig;
ui/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
ui/package.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "ai-toolkit-ui",
3
+ "version": "0.1.0",
4
+ "private": true,
5
+ "scripts": {
6
+ "dev": "concurrently -k -n WORKER,UI \"ts-node-dev --respawn --watch cron --transpile-only cron/worker.ts\" \"next dev --turbopack\"",
7
+ "build": "tsc -p tsconfig.worker.json && next build",
8
+ "start": "concurrently --restart-tries -1 --restart-after 1000 -n WORKER,UI \"node dist/worker.js\" \"next start --port 8675\"",
9
+ "build_and_start": "npm install && npm run update_db && npm run build && npm run start",
10
+ "lint": "next lint",
11
+ "update_db": "npx prisma generate && npx prisma db push",
12
+ "format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
13
+ },
14
+ "dependencies": {
15
+ "@headlessui/react": "^2.2.0",
16
+ "@huggingface/hub": "^2.5.2",
17
+ "@monaco-editor/react": "^4.7.0",
18
+ "@prisma/client": "^6.3.1",
19
+ "archiver": "^7.0.1",
20
+ "axios": "^1.7.9",
21
+ "classnames": "^2.5.1",
22
+ "form-data": "^4.0.4",
23
+ "lucide-react": "^0.475.0",
24
+ "next": "15.1.7",
25
+ "node-cache": "^5.1.2",
26
+ "prisma": "^6.3.1",
27
+ "react": "^19.0.0",
28
+ "react-dom": "^19.0.0",
29
+ "react-dropzone": "^14.3.5",
30
+ "react-global-hooks": "^1.3.5",
31
+ "react-icons": "^5.5.0",
32
+ "react-select": "^5.10.1",
33
+ "sqlite3": "^5.1.7",
34
+ "uuid": "^11.1.0",
35
+ "yaml": "^2.7.0"
36
+ },
37
+ "devDependencies": {
38
+ "@types/archiver": "^6.0.3",
39
+ "@types/node": "^20",
40
+ "@types/react": "^19",
41
+ "@types/react-dom": "^19",
42
+ "concurrently": "^9.1.2",
43
+ "postcss": "^8",
44
+ "prettier": "^3.5.1",
45
+ "prettier-basic": "^1.0.0",
46
+ "tailwindcss": "^3.4.1",
47
+ "ts-node-dev": "^2.0.0",
48
+ "typescript": "^5"
49
+ },
50
+ "prettier": "prettier-basic"
51
+ }
ui/postcss.config.mjs ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ /** @type {import('postcss-load-config').Config} */
2
+ const config = {
3
+ plugins: {
4
+ tailwindcss: {},
5
+ },
6
+ };
7
+
8
+ export default config;
ui/prisma/schema.prisma ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator client {
2
+ provider = "prisma-client-js"
3
+ }
4
+
5
+ datasource db {
6
+ provider = "sqlite"
7
+ url = "file:../../aitk_db.db"
8
+ }
9
+
10
+ model Settings {
11
+ id Int @id @default(autoincrement())
12
+ key String @unique
13
+ value String
14
+ }
15
+
16
+ model Job {
17
+ id String @id @default(uuid())
18
+ name String @unique
19
+ gpu_ids String
20
+ job_config String // JSON string
21
+ created_at DateTime @default(now())
22
+ updated_at DateTime @updatedAt
23
+ status String @default("stopped")
24
+ stop Boolean @default(false)
25
+ step Int @default(0)
26
+ info String @default("")
27
+ speed_string String @default("")
28
+ }
29
+
30
+ model Queue {
31
+ id String @id @default(uuid())
32
+ channel String
33
+ job_id String
34
+ created_at DateTime @default(now())
35
+ updated_at DateTime @updatedAt
36
+ status String @default("waiting")
37
+ @@index([job_id, channel])
38
+ }
ui/public/file.svg ADDED
ui/public/globe.svg ADDED
ui/public/next.svg ADDED
ui/public/ostris_logo.png ADDED
ui/public/vercel.svg ADDED
ui/public/web-app-manifest-192x192.png ADDED
ui/public/web-app-manifest-512x512.png ADDED
ui/public/window.svg ADDED
ui/src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ui/src/app/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ui/src/app/api/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ui/src/app/api/auth/hf/callback/route.ts ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { cookies } from 'next/headers';
3
+
4
+ const TOKEN_ENDPOINT = 'https://huggingface.co/oauth/token';
5
+ const USERINFO_ENDPOINT = 'https://huggingface.co/oauth/userinfo';
6
+ const STATE_COOKIE = 'hf_oauth_state';
7
+
8
+ function htmlResponse(script: string) {
9
+ return new NextResponse(
10
+ `<!DOCTYPE html><html><body><script>${script}</script></body></html>`,
11
+ {
12
+ headers: { 'Content-Type': 'text/html; charset=utf-8' },
13
+ },
14
+ );
15
+ }
16
+
17
+ export async function GET(request: NextRequest) {
18
+ const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID;
19
+ const clientSecret = process.env.HF_OAUTH_CLIENT_SECRET;
20
+
21
+ if (!clientId || !clientSecret) {
22
+ return NextResponse.json({ error: 'OAuth application is not configured' }, { status: 500 });
23
+ }
24
+
25
+ const { searchParams } = new URL(request.url);
26
+ const code = searchParams.get('code');
27
+ const incomingState = searchParams.get('state');
28
+
29
+ const cookieStore = cookies();
30
+ const storedState = cookieStore.get(STATE_COOKIE)?.value;
31
+
32
+ cookieStore.delete(STATE_COOKIE);
33
+
34
+ const origin = request.nextUrl.origin;
35
+
36
+ if (!code || !incomingState || !storedState || incomingState !== storedState) {
37
+ const script = `
38
+ window.opener && window.opener.postMessage({
39
+ type: 'HF_OAUTH_ERROR',
40
+ payload: { message: 'Invalid or expired OAuth state.' }
41
+ }, '${origin}');
42
+ window.close();
43
+ `;
44
+ return htmlResponse(script.trim());
45
+ }
46
+
47
+ const redirectUri = `${origin}/api/auth/hf/callback`;
48
+
49
+ try {
50
+ const tokenResponse = await fetch(TOKEN_ENDPOINT, {
51
+ method: 'POST',
52
+ headers: {
53
+ 'Content-Type': 'application/x-www-form-urlencoded',
54
+ },
55
+ body: new URLSearchParams({
56
+ grant_type: 'authorization_code',
57
+ code,
58
+ redirect_uri: redirectUri,
59
+ client_id: clientId,
60
+ client_secret: clientSecret,
61
+ }),
62
+ });
63
+
64
+ if (!tokenResponse.ok) {
65
+ const errorPayload = await tokenResponse.json().catch(() => ({}));
66
+ throw new Error(errorPayload?.error_description || 'Failed to exchange code for token');
67
+ }
68
+
69
+ const tokenData = await tokenResponse.json();
70
+ const accessToken = tokenData?.access_token;
71
+ if (!accessToken) {
72
+ throw new Error('Access token missing in response');
73
+ }
74
+
75
+ const userResponse = await fetch(USERINFO_ENDPOINT, {
76
+ headers: {
77
+ Authorization: `Bearer ${accessToken}`,
78
+ },
79
+ });
80
+
81
+ if (!userResponse.ok) {
82
+ throw new Error('Failed to fetch user info');
83
+ }
84
+
85
+ const profile = await userResponse.json();
86
+ const namespace = profile?.preferred_username || profile?.name || 'user';
87
+
88
+ const script = `
89
+ window.opener && window.opener.postMessage({
90
+ type: 'HF_OAUTH_SUCCESS',
91
+ payload: {
92
+ token: ${JSON.stringify(accessToken)},
93
+ namespace: ${JSON.stringify(namespace)},
94
+ }
95
+ }, '${origin}');
96
+ window.close();
97
+ `;
98
+
99
+ return htmlResponse(script.trim());
100
+ } catch (error: any) {
101
+ const message = error?.message || 'OAuth flow failed';
102
+ const script = `
103
+ window.opener && window.opener.postMessage({
104
+ type: 'HF_OAUTH_ERROR',
105
+ payload: { message: ${JSON.stringify(message)} }
106
+ }, '${origin}');
107
+ window.close();
108
+ `;
109
+
110
+ return htmlResponse(script.trim());
111
+ }
112
+ }
ui/src/app/api/auth/hf/login/route.ts ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { randomUUID } from 'crypto';
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+
4
+ const HF_AUTHORIZE_URL = 'https://huggingface.co/oauth/authorize';
5
+ const STATE_COOKIE = 'hf_oauth_state';
6
+
7
+ export async function GET(request: NextRequest) {
8
+ const clientId = process.env.HF_OAUTH_CLIENT_ID || process.env.NEXT_PUBLIC_HF_OAUTH_CLIENT_ID;
9
+ if (!clientId) {
10
+ return NextResponse.json({ error: 'OAuth client ID not configured' }, { status: 500 });
11
+ }
12
+
13
+ const state = randomUUID();
14
+ const origin = request.nextUrl.origin;
15
+ const redirectUri = `${origin}/api/auth/hf/callback`;
16
+
17
+ const authorizeUrl = new URL(HF_AUTHORIZE_URL);
18
+ authorizeUrl.searchParams.set('response_type', 'code');
19
+ authorizeUrl.searchParams.set('client_id', clientId);
20
+ authorizeUrl.searchParams.set('redirect_uri', redirectUri);
21
+ authorizeUrl.searchParams.set('scope', 'openid profile read-repos');
22
+ authorizeUrl.searchParams.set('state', state);
23
+
24
+ const response = NextResponse.redirect(authorizeUrl.toString(), { status: 302 });
25
+ response.cookies.set({
26
+ name: STATE_COOKIE,
27
+ value: state,
28
+ httpOnly: true,
29
+ sameSite: 'lax',
30
+ secure: process.env.NODE_ENV === 'production',
31
+ maxAge: 60 * 5,
32
+ path: '/',
33
+ });
34
+
35
+ return response;
36
+ }
ui/src/app/api/auth/hf/validate/route.ts ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { whoAmI } from '@huggingface/hub';
3
+
4
+ export async function POST(request: NextRequest) {
5
+ try {
6
+ const body = await request.json().catch(() => ({}));
7
+ const token = (body?.token || '').trim();
8
+
9
+ if (!token) {
10
+ return NextResponse.json({ error: 'Token is required' }, { status: 400 });
11
+ }
12
+
13
+ const info = await whoAmI({ accessToken: token });
14
+ return NextResponse.json({
15
+ name: info?.name || info?.username || 'user',
16
+ email: info?.email || null,
17
+ orgs: info?.orgs || [],
18
+ });
19
+ } catch (error: any) {
20
+ return NextResponse.json({ error: error?.message || 'Invalid token' }, { status: 401 });
21
+ }
22
+ }
ui/src/app/api/auth/route.ts ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+
3
+ export async function GET() {
4
+ // if this gets hit, auth has already been verified
5
+ return NextResponse.json({ isAuthenticated: true });
6
+ }
ui/src/app/api/caption/get/route.ts ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import path from 'path';
5
+ import { getDatasetsRoot } from '@/server/settings';
6
+
7
+ export async function POST(request: NextRequest) {
8
+
9
+ const body = await request.json();
10
+ const { imgPath } = body;
11
+ console.log('Received POST request for caption:', imgPath);
12
+ try {
13
+ // Decode the path
14
+ const filepath = imgPath;
15
+ console.log('Decoded image path:', filepath);
16
+
17
+ // caption name is the filepath without extension but with .txt
18
+ const captionPath = filepath.replace(/\.[^/.]+$/, '') + '.txt';
19
+
20
+ // Get allowed directories
21
+ const allowedDir = await getDatasetsRoot();
22
+
23
+ // Security check: Ensure path is in allowed directory
24
+ const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..');
25
+
26
+ if (!isAllowed) {
27
+ console.warn(`Access denied: ${filepath} not in ${allowedDir}`);
28
+ return new NextResponse('Access denied', { status: 403 });
29
+ }
30
+
31
+ // Check if file exists
32
+ if (!fs.existsSync(captionPath)) {
33
+ // send back blank string if caption file does not exist
34
+ return new NextResponse('');
35
+ }
36
+
37
+ // Read caption file
38
+ const caption = fs.readFileSync(captionPath, 'utf-8');
39
+
40
+ // Return caption
41
+ return new NextResponse(caption);
42
+ } catch (error) {
43
+ console.error('Error getting caption:', error);
44
+ return new NextResponse('Error getting caption', { status: 500 });
45
+ }
46
+ }
ui/src/app/api/datasets/create/route.tsx ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import path from 'path';
4
+ import { getDatasetsRoot } from '@/server/settings';
5
+
6
+ export async function POST(request: Request) {
7
+ try {
8
+ const body = await request.json();
9
+ let { name } = body;
10
+ // clean name by making lower case, removing special characters, and replacing spaces with underscores
11
+ name = name.toLowerCase().replace(/[^a-z0-9]+/g, '_');
12
+
13
+ let datasetsPath = await getDatasetsRoot();
14
+ let datasetPath = path.join(datasetsPath, name);
15
+
16
+ // if folder doesnt exist, create it
17
+ if (!fs.existsSync(datasetPath)) {
18
+ fs.mkdirSync(datasetPath);
19
+ }
20
+
21
+ return NextResponse.json({ success: true, name: name, path: datasetPath });
22
+ } catch (error) {
23
+ return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
24
+ }
25
+ }
ui/src/app/api/datasets/delete/route.tsx ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import path from 'path';
4
+ import { getDatasetsRoot } from '@/server/settings';
5
+
6
+ export async function POST(request: Request) {
7
+ try {
8
+ const body = await request.json();
9
+ const { name } = body;
10
+ let datasetsPath = await getDatasetsRoot();
11
+ let datasetPath = path.join(datasetsPath, name);
12
+
13
+ // if folder doesnt exist, ignore
14
+ if (!fs.existsSync(datasetPath)) {
15
+ return NextResponse.json({ success: true });
16
+ }
17
+
18
+ // delete it and return success
19
+ fs.rmdirSync(datasetPath, { recursive: true });
20
+ return NextResponse.json({ success: true });
21
+ } catch (error) {
22
+ return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
23
+ }
24
+ }
ui/src/app/api/datasets/list/route.ts ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import { getDatasetsRoot } from '@/server/settings';
4
+
5
+ export async function GET() {
6
+ try {
7
+ let datasetsPath = await getDatasetsRoot();
8
+
9
+ // if folder doesnt exist, create it
10
+ if (!fs.existsSync(datasetsPath)) {
11
+ fs.mkdirSync(datasetsPath);
12
+ }
13
+
14
+ // find all the folders in the datasets folder
15
+ let folders = fs
16
+ .readdirSync(datasetsPath, { withFileTypes: true })
17
+ .filter(dirent => dirent.isDirectory())
18
+ .filter(dirent => !dirent.name.startsWith('.'))
19
+ .map(dirent => dirent.name);
20
+
21
+ return NextResponse.json(folders);
22
+ } catch (error) {
23
+ return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: 500 });
24
+ }
25
+ }
ui/src/app/api/datasets/listImages/route.ts ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import path from 'path';
4
+ import { getDatasetsRoot } from '@/server/settings';
5
+
6
+ export async function POST(request: Request) {
7
+ const datasetsPath = await getDatasetsRoot();
8
+ const body = await request.json();
9
+ const { datasetName } = body;
10
+ const datasetFolder = path.join(datasetsPath, datasetName);
11
+
12
+ try {
13
+ // Check if folder exists
14
+ if (!fs.existsSync(datasetFolder)) {
15
+ return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 });
16
+ }
17
+
18
+ // Find all images recursively
19
+ const imageFiles = findImagesRecursively(datasetFolder);
20
+
21
+ // Format response
22
+ const result = imageFiles.map(imgPath => ({
23
+ img_path: imgPath,
24
+ }));
25
+
26
+ return NextResponse.json({ images: result });
27
+ } catch (error) {
28
+ console.error('Error finding images:', error);
29
+ return NextResponse.json({ error: 'Failed to process request' }, { status: 500 });
30
+ }
31
+ }
32
+
33
+ /**
34
+ * Recursively finds all image files in a directory and its subdirectories
35
+ * @param dir Directory to search
36
+ * @returns Array of absolute paths to image files
37
+ */
38
+ function findImagesRecursively(dir: string): string[] {
39
+ const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp', '.mp4', '.avi', '.mov', '.mkv', '.wmv', '.m4v', '.flv'];
40
+ let results: string[] = [];
41
+
42
+ const items = fs.readdirSync(dir);
43
+
44
+ for (const item of items) {
45
+ const itemPath = path.join(dir, item);
46
+ const stat = fs.statSync(itemPath);
47
+
48
+ if (stat.isDirectory() && item !== '_controls' && !item.startsWith('.')) {
49
+ // If it's a directory, recursively search it
50
+ results = results.concat(findImagesRecursively(itemPath));
51
+ } else {
52
+ // If it's a file, check if it's an image
53
+ const ext = path.extname(itemPath).toLowerCase();
54
+ if (imageExtensions.includes(ext)) {
55
+ results.push(itemPath);
56
+ }
57
+ }
58
+ }
59
+
60
+ return results;
61
+ }
ui/src/app/api/datasets/upload/route.ts ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/app/api/datasets/upload/route.ts
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import { writeFile, mkdir } from 'fs/promises';
4
+ import { join } from 'path';
5
+ import { getDatasetsRoot } from '@/server/settings';
6
+
7
+ export async function POST(request: NextRequest) {
8
+ try {
9
+ const datasetsPath = await getDatasetsRoot();
10
+ if (!datasetsPath) {
11
+ return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 });
12
+ }
13
+ const formData = await request.formData();
14
+ const files = formData.getAll('files');
15
+ const datasetName = formData.get('datasetName') as string;
16
+
17
+ if (!files || files.length === 0) {
18
+ return NextResponse.json({ error: 'No files provided' }, { status: 400 });
19
+ }
20
+
21
+ // Create upload directory if it doesn't exist
22
+ const uploadDir = join(datasetsPath, datasetName);
23
+ await mkdir(uploadDir, { recursive: true });
24
+
25
+ const savedFiles: string[] = [];
26
+
27
+ // Process files sequentially to avoid overwhelming the system
28
+ for (let i = 0; i < files.length; i++) {
29
+ const file = files[i] as any;
30
+ const bytes = await file.arrayBuffer();
31
+ const buffer = Buffer.from(bytes);
32
+
33
+ // Clean filename and ensure it's unique
34
+ const fileName = file.name.replace(/[^a-zA-Z0-9.-]/g, '_');
35
+ const filePath = join(uploadDir, fileName);
36
+
37
+ await writeFile(filePath, buffer);
38
+ savedFiles.push(fileName);
39
+ }
40
+
41
+ return NextResponse.json({
42
+ message: 'Files uploaded successfully',
43
+ files: savedFiles,
44
+ });
45
+ } catch (error) {
46
+ console.error('Upload error:', error);
47
+ return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
48
+ }
49
+ }
50
+
51
+ // Increase payload size limit (default is 4mb)
52
+ export const config = {
53
+ api: {
54
+ bodyParser: false,
55
+ responseLimit: '50mb',
56
+ },
57
+ };
ui/src/app/api/files/[...filePath]/route.ts ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import path from 'path';
5
+ import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
6
+
7
+ export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
8
+ const { filePath } = await params;
9
+ try {
10
+ // Decode the path
11
+ const decodedFilePath = decodeURIComponent(filePath);
12
+
13
+ // Get allowed directories
14
+ const datasetRoot = await getDatasetsRoot();
15
+ const trainingRoot = await getTrainingFolder();
16
+ const allowedDirs = [datasetRoot, trainingRoot];
17
+
18
+ // Security check: Ensure path is in allowed directory
19
+ const isAllowed =
20
+ allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
21
+
22
+ if (!isAllowed) {
23
+ console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
24
+ return new NextResponse('Access denied', { status: 403 });
25
+ }
26
+
27
+ // Check if file exists
28
+ if (!fs.existsSync(decodedFilePath)) {
29
+ console.warn(`File not found: ${decodedFilePath}`);
30
+ return new NextResponse('File not found', { status: 404 });
31
+ }
32
+
33
+ // Get file info
34
+ const stat = fs.statSync(decodedFilePath);
35
+ if (!stat.isFile()) {
36
+ return new NextResponse('Not a file', { status: 400 });
37
+ }
38
+
39
+ // Get filename for Content-Disposition
40
+ const filename = path.basename(decodedFilePath);
41
+
42
+ // Determine content type
43
+ const ext = path.extname(decodedFilePath).toLowerCase();
44
+ const contentTypeMap: { [key: string]: string } = {
45
+ '.jpg': 'image/jpeg',
46
+ '.jpeg': 'image/jpeg',
47
+ '.png': 'image/png',
48
+ '.gif': 'image/gif',
49
+ '.webp': 'image/webp',
50
+ '.svg': 'image/svg+xml',
51
+ '.bmp': 'image/bmp',
52
+ '.safetensors': 'application/octet-stream',
53
+ '.zip': 'application/zip',
54
+ // Videos
55
+ '.mp4': 'video/mp4',
56
+ '.avi': 'video/x-msvideo',
57
+ '.mov': 'video/quicktime',
58
+ '.mkv': 'video/x-matroska',
59
+ '.wmv': 'video/x-ms-wmv',
60
+ '.m4v': 'video/x-m4v',
61
+ '.flv': 'video/x-flv'
62
+ };
63
+
64
+ const contentType = contentTypeMap[ext] || 'application/octet-stream';
65
+
66
+ // Get range header for partial content support
67
+ const range = request.headers.get('range');
68
+
69
+ // Common headers for better download handling
70
+ const commonHeaders = {
71
+ 'Content-Type': contentType,
72
+ 'Accept-Ranges': 'bytes',
73
+ 'Cache-Control': 'public, max-age=86400',
74
+ 'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
75
+ 'X-Content-Type-Options': 'nosniff',
76
+ };
77
+
78
+ if (range) {
79
+ // Parse range header
80
+ const parts = range.replace(/bytes=/, '').split('-');
81
+ const start = parseInt(parts[0], 10);
82
+ const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
83
+ const chunkSize = end - start + 1;
84
+
85
+ const fileStream = fs.createReadStream(decodedFilePath, {
86
+ start,
87
+ end,
88
+ highWaterMark: 64 * 1024, // 64KB buffer
89
+ });
90
+
91
+ return new NextResponse(fileStream as any, {
92
+ status: 206,
93
+ headers: {
94
+ ...commonHeaders,
95
+ 'Content-Range': `bytes ${start}-${end}/${stat.size}`,
96
+ 'Content-Length': String(chunkSize),
97
+ },
98
+ });
99
+ } else {
100
+ // For full file download, read directly without streaming wrapper
101
+ const fileStream = fs.createReadStream(decodedFilePath, {
102
+ highWaterMark: 64 * 1024, // 64KB buffer
103
+ });
104
+
105
+ return new NextResponse(fileStream as any, {
106
+ headers: {
107
+ ...commonHeaders,
108
+ 'Content-Length': String(stat.size),
109
+ },
110
+ });
111
+ }
112
+ } catch (error) {
113
+ console.error('Error serving file:', error);
114
+ return new NextResponse('Internal Server Error', { status: 500 });
115
+ }
116
+ }
ui/src/app/api/gpu/route.ts ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import { exec } from 'child_process';
3
+ import { promisify } from 'util';
4
+ import os from 'os';
5
+
6
+ const execAsync = promisify(exec);
7
+
8
+ export async function GET() {
9
+ try {
10
+ // Get platform
11
+ const platform = os.platform();
12
+ const isWindows = platform === 'win32';
13
+
14
+ // Check if nvidia-smi is available
15
+ const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
16
+
17
+ if (!hasNvidiaSmi) {
18
+ return NextResponse.json({
19
+ hasNvidiaSmi: false,
20
+ gpus: [],
21
+ error: 'nvidia-smi not found or not accessible',
22
+ });
23
+ }
24
+
25
+ // Get GPU stats
26
+ const gpuStats = await getGpuStats(isWindows);
27
+
28
+ return NextResponse.json({
29
+ hasNvidiaSmi: true,
30
+ gpus: gpuStats,
31
+ });
32
+ } catch (error) {
33
+ console.error('Error fetching NVIDIA GPU stats:', error);
34
+ return NextResponse.json(
35
+ {
36
+ hasNvidiaSmi: false,
37
+ gpus: [],
38
+ error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
39
+ },
40
+ { status: 500 },
41
+ );
42
+ }
43
+ }
44
+
45
+ async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
46
+ try {
47
+ if (isWindows) {
48
+ // Check if nvidia-smi is available on Windows
49
+ // It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe
50
+ // but we'll just try to run it directly as it may be in PATH
51
+ await execAsync('nvidia-smi -L');
52
+ } else {
53
+ // Linux/macOS check
54
+ await execAsync('which nvidia-smi');
55
+ }
56
+ return true;
57
+ } catch (error) {
58
+ return false;
59
+ }
60
+ }
61
+
62
+ async function getGpuStats(isWindows: boolean) {
63
+ // Command is the same for both platforms, but the path might be different
64
+ const command =
65
+ 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
66
+
67
+ // Execute command
68
+ const { stdout } = await execAsync(command);
69
+
70
+ // Parse CSV output
71
+ const gpus = stdout
72
+ .trim()
73
+ .split('\n')
74
+ .map(line => {
75
+ const [
76
+ index,
77
+ name,
78
+ driverVersion,
79
+ temperature,
80
+ gpuUtil,
81
+ memoryUtil,
82
+ memoryTotal,
83
+ memoryFree,
84
+ memoryUsed,
85
+ powerDraw,
86
+ powerLimit,
87
+ clockGraphics,
88
+ clockMemory,
89
+ fanSpeed,
90
+ ] = line.split(', ').map(item => item.trim());
91
+
92
+ return {
93
+ index: parseInt(index),
94
+ name,
95
+ driverVersion,
96
+ temperature: parseInt(temperature),
97
+ utilization: {
98
+ gpu: parseInt(gpuUtil),
99
+ memory: parseInt(memoryUtil),
100
+ },
101
+ memory: {
102
+ total: parseInt(memoryTotal),
103
+ free: parseInt(memoryFree),
104
+ used: parseInt(memoryUsed),
105
+ },
106
+ power: {
107
+ draw: parseFloat(powerDraw),
108
+ limit: parseFloat(powerLimit),
109
+ },
110
+ clocks: {
111
+ graphics: parseInt(clockGraphics),
112
+ memory: parseInt(clockMemory),
113
+ },
114
+ fan: {
115
+ speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0
116
+ },
117
+ };
118
+ });
119
+
120
+ return gpus;
121
+ }
ui/src/app/api/hf-hub/route.ts ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { whoAmI, createRepo, uploadFiles, datasetInfo } from '@huggingface/hub';
3
+ import { readdir, stat } from 'fs/promises';
4
+ import path from 'path';
5
+
6
+ export async function POST(request: NextRequest) {
7
+ try {
8
+ const body = await request.json();
9
+ const { action, token, namespace, datasetName, datasetPath, datasetId } = body;
10
+
11
+ if (!token) {
12
+ return NextResponse.json({ error: 'HF token is required' }, { status: 400 });
13
+ }
14
+
15
+ switch (action) {
16
+ case 'whoami':
17
+ try {
18
+ const user = await whoAmI({ accessToken: token });
19
+ return NextResponse.json({ user });
20
+ } catch (error) {
21
+ return NextResponse.json({ error: 'Invalid token or network error' }, { status: 401 });
22
+ }
23
+
24
+ case 'createDataset':
25
+ try {
26
+ if (!namespace || !datasetName) {
27
+ return NextResponse.json({ error: 'Namespace and dataset name required' }, { status: 400 });
28
+ }
29
+
30
+ const repoId = `datasets/${namespace}/${datasetName}`;
31
+
32
+ // Create repository
33
+ await createRepo({
34
+ repo: repoId,
35
+ accessToken: token,
36
+ private: false,
37
+ });
38
+
39
+ return NextResponse.json({ success: true, repoId });
40
+ } catch (error: any) {
41
+ if (error.message?.includes('already exists')) {
42
+ return NextResponse.json({ success: true, repoId: `${namespace}/${datasetName}`, exists: true });
43
+ }
44
+ return NextResponse.json({ error: error.message || 'Failed to create dataset' }, { status: 500 });
45
+ }
46
+
47
+ case 'uploadDataset':
48
+ try {
49
+ if (!namespace || !datasetName || !datasetPath) {
50
+ return NextResponse.json({ error: 'Missing required parameters' }, { status: 400 });
51
+ }
52
+
53
+ const repoId = `datasets/${namespace}/${datasetName}`;
54
+
55
+ // Check if directory exists
56
+ try {
57
+ await stat(datasetPath);
58
+ } catch {
59
+ return NextResponse.json({ error: 'Dataset path does not exist' }, { status: 400 });
60
+ }
61
+
62
+ // Read files from directory and upload them
63
+ const files = await readdir(datasetPath);
64
+ const filesToUpload = [];
65
+
66
+ for (const fileName of files) {
67
+ const filePath = path.join(datasetPath, fileName);
68
+ const fileStats = await stat(filePath);
69
+
70
+ if (fileStats.isFile()) {
71
+ filesToUpload.push({
72
+ path: fileName,
73
+ content: new URL(`file://${filePath}`)
74
+ });
75
+ }
76
+ }
77
+
78
+ if (filesToUpload.length > 0) {
79
+ await uploadFiles({
80
+ repo: repoId,
81
+ accessToken: token,
82
+ files: filesToUpload,
83
+ });
84
+ }
85
+
86
+ return NextResponse.json({ success: true, repoId });
87
+ } catch (error: any) {
88
+ console.error('Upload error:', error);
89
+ return NextResponse.json({ error: error.message || 'Failed to upload dataset' }, { status: 500 });
90
+ }
91
+
92
+ case 'listFiles':
93
+ try {
94
+ if (!datasetPath) {
95
+ return NextResponse.json({ error: 'Dataset path required' }, { status: 400 });
96
+ }
97
+
98
+ const files = await readdir(datasetPath, { withFileTypes: true });
99
+ const imageExtensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp'];
100
+
101
+ const imageFiles = files
102
+ .filter(file => file.isFile())
103
+ .filter(file => imageExtensions.some(ext => file.name.toLowerCase().endsWith(ext)))
104
+ .map(file => ({
105
+ name: file.name,
106
+ path: path.join(datasetPath, file.name),
107
+ }));
108
+
109
+ const captionFiles = files
110
+ .filter(file => file.isFile())
111
+ .filter(file => file.name.endsWith('.txt'))
112
+ .map(file => ({
113
+ name: file.name,
114
+ path: path.join(datasetPath, file.name),
115
+ }));
116
+
117
+ return NextResponse.json({
118
+ images: imageFiles,
119
+ captions: captionFiles,
120
+ total: imageFiles.length
121
+ });
122
+ } catch (error: any) {
123
+ return NextResponse.json({ error: error.message || 'Failed to list files' }, { status: 500 });
124
+ }
125
+
126
+ case 'validateDataset':
127
+ try {
128
+ if (!datasetId) {
129
+ return NextResponse.json({ error: 'Dataset ID required' }, { status: 400 });
130
+ }
131
+
132
+ // Try to get dataset info to validate it exists and is accessible
133
+ const dataset = await datasetInfo({
134
+ name: datasetId,
135
+ accessToken: token,
136
+ });
137
+
138
+ return NextResponse.json({
139
+ exists: true,
140
+ dataset: {
141
+ id: dataset.id,
142
+ author: dataset.author,
143
+ downloads: dataset.downloads,
144
+ likes: dataset.likes,
145
+ private: dataset.private,
146
+ }
147
+ });
148
+ } catch (error: any) {
149
+ if (error.message?.includes('404') || error.message?.includes('not found')) {
150
+ return NextResponse.json({ exists: false }, { status: 200 });
151
+ }
152
+ if (error.message?.includes('401') || error.message?.includes('403')) {
153
+ return NextResponse.json({ error: 'Dataset not accessible with current token' }, { status: 403 });
154
+ }
155
+ return NextResponse.json({ error: error.message || 'Failed to validate dataset' }, { status: 500 });
156
+ }
157
+
158
+ default:
159
+ return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
160
+ }
161
+ } catch (error: any) {
162
+ console.error('HF Hub API error:', error);
163
+ return NextResponse.json({ error: error.message || 'Internal server error' }, { status: 500 });
164
+ }
165
+ }
ui/src/app/api/hf-jobs/route.ts ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { spawn } from 'child_process';
3
+ import { writeFile } from 'fs/promises';
4
+ import path from 'path';
5
+ import { tmpdir } from 'os';
6
+
7
+ export async function POST(request: NextRequest) {
8
+ try {
9
+ const body = await request.json();
10
+ const { action, token, hardware, namespace, jobConfig, datasetRepo } = body;
11
+
12
+ switch (action) {
13
+ case 'checkStatus':
14
+ try {
15
+ if (!token || !jobConfig?.hf_job_id) {
16
+ return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 });
17
+ }
18
+
19
+ const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id);
20
+ return NextResponse.json({ status: jobStatus });
21
+ } catch (error: any) {
22
+ console.error('Job status check error:', error);
23
+ return NextResponse.json({ error: error.message }, { status: 500 });
24
+ }
25
+
26
+ case 'generateScript':
27
+ try {
28
+ const uvScript = generateUVScript({
29
+ jobConfig,
30
+ datasetRepo,
31
+ namespace,
32
+ token: token || 'YOUR_HF_TOKEN',
33
+ });
34
+
35
+ return NextResponse.json({
36
+ script: uvScript,
37
+ filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py`
38
+ });
39
+ } catch (error: any) {
40
+ return NextResponse.json({ error: error.message }, { status: 500 });
41
+ }
42
+
43
+ case 'submitJob':
44
+ try {
45
+ if (!token || !hardware) {
46
+ return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 });
47
+ }
48
+
49
+ // Generate UV script
50
+ const uvScript = generateUVScript({
51
+ jobConfig,
52
+ datasetRepo,
53
+ namespace,
54
+ token,
55
+ });
56
+
57
+ // Write script to temporary file
58
+ const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`);
59
+ await writeFile(scriptPath, uvScript);
60
+
61
+ // Submit HF job using uv run
62
+ const jobId = await submitHFJobUV(token, hardware, scriptPath);
63
+
64
+ return NextResponse.json({
65
+ success: true,
66
+ jobId,
67
+ message: `Job submitted successfully with ID: ${jobId}`
68
+ });
69
+ } catch (error: any) {
70
+ console.error('Job submission error:', error);
71
+ return NextResponse.json({ error: error.message }, { status: 500 });
72
+ }
73
+
74
+ default:
75
+ return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
76
+ }
77
+ } catch (error: any) {
78
+ console.error('HF Jobs API error:', error);
79
+ return NextResponse.json({ error: error.message }, { status: 500 });
80
+ }
81
+ }
82
+
83
+ function generateUVScript({ jobConfig, datasetRepo, namespace, token }: {
84
+ jobConfig: any;
85
+ datasetRepo: string;
86
+ namespace: string;
87
+ token: string;
88
+ }) {
89
+ const config = jobConfig.config;
90
+ const process = config.process[0];
91
+
92
+ return `# /// script
93
+ # dependencies = [
94
+ # "torch>=2.0.0",
95
+ # "torchvision",
96
+ # "torchao==0.10.0",
97
+ # "safetensors",
98
+ # "diffusers @ git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63",
99
+ # "transformers==4.52.4",
100
+ # "lycoris-lora==1.8.3",
101
+ # "flatten_json",
102
+ # "pyyaml",
103
+ # "oyaml",
104
+ # "tensorboard",
105
+ # "kornia",
106
+ # "invisible-watermark",
107
+ # "einops",
108
+ # "accelerate",
109
+ # "toml",
110
+ # "albumentations==1.4.15",
111
+ # "albucore==0.0.16",
112
+ # "pydantic",
113
+ # "omegaconf",
114
+ # "k-diffusion",
115
+ # "open_clip_torch",
116
+ # "timm",
117
+ # "prodigyopt",
118
+ # "controlnet_aux==0.0.10",
119
+ # "python-dotenv",
120
+ # "bitsandbytes",
121
+ # "hf_transfer",
122
+ # "lpips",
123
+ # "pytorch_fid",
124
+ # "optimum-quanto==0.2.4",
125
+ # "sentencepiece",
126
+ # "huggingface_hub",
127
+ # "peft",
128
+ # "python-slugify",
129
+ # "opencv-python-headless",
130
+ # "pytorch-wavelets==1.3.0",
131
+ # "matplotlib==3.10.1",
132
+ # "setuptools==69.5.1",
133
+ # "datasets==4.0.0",
134
+ # "pyarrow==20.0.0",
135
+ # "pillow",
136
+ # "ftfy",
137
+ # ]
138
+ # ///
139
+
140
+ import os
141
+ import sys
142
+ import subprocess
143
+ import argparse
144
+ import oyaml as yaml
145
+ from datasets import load_dataset
146
+ from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download
147
+ import tempfile
148
+ import shutil
149
+ import glob
150
+ from PIL import Image
151
+
152
+ def setup_ai_toolkit():
153
+ """Clone and setup ai-toolkit repository"""
154
+ repo_dir = "ai-toolkit"
155
+ if not os.path.exists(repo_dir):
156
+ print("Cloning ai-toolkit repository...")
157
+ subprocess.run(
158
+ ["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir],
159
+ check=True
160
+ )
161
+ sys.path.insert(0, os.path.abspath(repo_dir))
162
+ return repo_dir
163
+
164
+ def download_dataset(dataset_repo: str, local_path: str):
165
+ """Download dataset from HF Hub as files"""
166
+ print(f"Downloading dataset from {dataset_repo}...")
167
+
168
+ # Create local dataset directory
169
+ os.makedirs(local_path, exist_ok=True)
170
+
171
+ # Use snapshot_download to get the dataset files directly
172
+ from huggingface_hub import snapshot_download
173
+
174
+ try:
175
+ # First try to download as a structured dataset
176
+ dataset = load_dataset(dataset_repo, split="train")
177
+
178
+ # Download images and captions from structured dataset
179
+ for i, item in enumerate(dataset):
180
+ # Save image
181
+ if "image" in item:
182
+ image_path = os.path.join(local_path, f"image_{i:06d}.jpg")
183
+ image = item["image"]
184
+
185
+ # Convert RGBA to RGB if necessary (for JPEG compatibility)
186
+ if image.mode == 'RGBA':
187
+ # Create a white background and paste the RGBA image on it
188
+ background = Image.new('RGB', image.size, (255, 255, 255))
189
+ background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask
190
+ image = background
191
+ elif image.mode not in ['RGB', 'L']:
192
+ # Convert any other mode to RGB
193
+ image = image.convert('RGB')
194
+
195
+ image.save(image_path, 'JPEG')
196
+
197
+ # Save caption
198
+ if "text" in item:
199
+ caption_path = os.path.join(local_path, f"image_{i:06d}.txt")
200
+ with open(caption_path, "w", encoding="utf-8") as f:
201
+ f.write(item["text"])
202
+
203
+ print(f"Downloaded {len(dataset)} items to {local_path}")
204
+
205
+ except Exception as e:
206
+ print(f"Failed to load as structured dataset: {e}")
207
+ print("Attempting to download raw files...")
208
+
209
+ # Download the dataset repository as files
210
+ temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset")
211
+
212
+ # Copy all image and text files to the local path
213
+ import glob
214
+ import shutil
215
+
216
+ print(f"Downloaded repo to: {temp_repo_path}")
217
+ print(f"Contents: {os.listdir(temp_repo_path)}")
218
+
219
+ # Find all image files
220
+ image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG']
221
+ image_files = []
222
+ for ext in image_extensions:
223
+ pattern = os.path.join(temp_repo_path, "**", ext)
224
+ found_files = glob.glob(pattern, recursive=True)
225
+ image_files.extend(found_files)
226
+ print(f"Pattern {pattern} found {len(found_files)} files")
227
+
228
+ # Find all text files
229
+ text_files = glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True)
230
+
231
+ print(f"Found {len(image_files)} image files and {len(text_files)} text files")
232
+
233
+ # Copy image files
234
+ for i, img_file in enumerate(image_files):
235
+ dest_path = os.path.join(local_path, f"image_{i:06d}.jpg")
236
+
237
+ # Load and convert image if needed
238
+ try:
239
+ with Image.open(img_file) as image:
240
+ if image.mode == 'RGBA':
241
+ background = Image.new('RGB', image.size, (255, 255, 255))
242
+ background.paste(image, mask=image.split()[-1])
243
+ image = background
244
+ elif image.mode not in ['RGB', 'L']:
245
+ image = image.convert('RGB')
246
+
247
+ image.save(dest_path, 'JPEG')
248
+ except Exception as img_error:
249
+ print(f"Error processing image {img_file}: {img_error}")
250
+ continue
251
+
252
+ # Copy text files (captions)
253
+ for i, txt_file in enumerate(text_files[:len(image_files)]): # Match number of images
254
+ dest_path = os.path.join(local_path, f"image_{i:06d}.txt")
255
+ try:
256
+ shutil.copy2(txt_file, dest_path)
257
+ except Exception as txt_error:
258
+ print(f"Error copying text file {txt_file}: {txt_error}")
259
+ continue
260
+
261
+ print(f"Downloaded {len(image_files)} images and {len(text_files)} captions to {local_path}")
262
+
263
+ def create_config(dataset_path: str, output_path: str):
264
+ """Create training configuration"""
265
+ import json
266
+
267
+ # Load config from JSON string and fix boolean/null values for Python
268
+ config_str = """${JSON.stringify(jobConfig, null, 2)}"""
269
+ config_str = config_str.replace('true', 'True').replace('false', 'False').replace('null', 'None')
270
+ config = eval(config_str)
271
+
272
+ # Update paths for cloud environment
273
+ config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_path
274
+ config["config"]["process"][0]["training_folder"] = output_path
275
+
276
+ # Remove sqlite_db_path as it's not needed for cloud training
277
+ if "sqlite_db_path" in config["config"]["process"][0]:
278
+ del config["config"]["process"][0]["sqlite_db_path"]
279
+
280
+ # Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies
281
+ if config["config"]["process"][0]["type"] == "ui_trainer":
282
+ config["config"]["process"][0]["type"] = "sd_trainer"
283
+
284
+ return config
285
+
286
+ def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict):
287
+ """Upload trained model to HF Hub with README generation and proper file organization"""
288
+ import tempfile
289
+ import shutil
290
+ import glob
291
+ import re
292
+ import yaml
293
+ from datetime import datetime
294
+ from huggingface_hub import create_repo, upload_file, HfApi
295
+
296
+ try:
297
+ repo_id = f"{namespace}/{model_name}"
298
+
299
+ # Create repository
300
+ create_repo(repo_id=repo_id, token=token, exist_ok=True)
301
+
302
+ print(f"Uploading model to {repo_id}...")
303
+
304
+ # Create temporary directory for organized upload
305
+ with tempfile.TemporaryDirectory() as temp_upload_dir:
306
+ api = HfApi()
307
+
308
+ # 1. Find and upload model files to root directory
309
+ safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True)
310
+ json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True)
311
+ txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True)
312
+
313
+ uploaded_files = []
314
+
315
+ # Upload .safetensors files to root
316
+ for file_path in safetensors_files:
317
+ filename = os.path.basename(file_path)
318
+ print(f"Uploading {filename} to repository root...")
319
+ api.upload_file(
320
+ path_or_fileobj=file_path,
321
+ path_in_repo=filename,
322
+ repo_id=repo_id,
323
+ token=token
324
+ )
325
+ uploaded_files.append(filename)
326
+
327
+ # Upload relevant JSON config files to root (skip metadata.json and other internal files)
328
+ config_files_uploaded = []
329
+ for file_path in json_files:
330
+ filename = os.path.basename(file_path)
331
+ # Only upload important config files, skip internal metadata
332
+ if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']):
333
+ print(f"Uploading {filename} to repository root...")
334
+ api.upload_file(
335
+ path_or_fileobj=file_path,
336
+ path_in_repo=filename,
337
+ repo_id=repo_id,
338
+ token=token
339
+ )
340
+ uploaded_files.append(filename)
341
+ config_files_uploaded.append(filename)
342
+
343
+ # 2. Handle sample images
344
+ samples_uploaded = []
345
+ samples_dir = os.path.join(output_path, "samples")
346
+ if os.path.isdir(samples_dir):
347
+ print("Uploading sample images...")
348
+ # Create samples directory in repo
349
+ for filename in os.listdir(samples_dir):
350
+ if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
351
+ file_path = os.path.join(samples_dir, filename)
352
+ repo_path = f"samples/{filename}"
353
+ api.upload_file(
354
+ path_or_fileobj=file_path,
355
+ path_in_repo=repo_path,
356
+ repo_id=repo_id,
357
+ token=token
358
+ )
359
+ samples_uploaded.append(repo_path)
360
+
361
+ # 3. Generate and upload README.md
362
+ readme_content = generate_model_card_readme(
363
+ repo_id=repo_id,
364
+ config=config,
365
+ model_name=model_name,
366
+ samples_dir=samples_dir if os.path.isdir(samples_dir) else None,
367
+ uploaded_files=uploaded_files
368
+ )
369
+
370
+ # Create README.md file and upload to root
371
+ readme_path = os.path.join(temp_upload_dir, "README.md")
372
+ with open(readme_path, "w", encoding="utf-8") as f:
373
+ f.write(readme_content)
374
+
375
+ print("Uploading README.md to repository root...")
376
+ api.upload_file(
377
+ path_or_fileobj=readme_path,
378
+ path_in_repo="README.md",
379
+ repo_id=repo_id,
380
+ token=token
381
+ )
382
+
383
+ print(f"Model uploaded successfully to https://huggingface.co/{repo_id}")
384
+ print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md")
385
+
386
+ except Exception as e:
387
+ print(f"Failed to upload model: {e}")
388
+ raise e
389
+
390
+ def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str:
391
+ """Generate README.md content for the model card based on AI Toolkit's implementation"""
392
+ import re
393
+ import yaml
394
+ import os
395
+
396
+ try:
397
+ # Extract configuration details
398
+ process_config = config.get("config", {}).get("process", [{}])[0]
399
+ model_config = process_config.get("model", {})
400
+ train_config = process_config.get("train", {})
401
+ sample_config = process_config.get("sample", {})
402
+
403
+ # Gather model info
404
+ base_model = model_config.get("name_or_path", "unknown")
405
+ trigger_word = process_config.get("trigger_word")
406
+ arch = model_config.get("arch", "")
407
+
408
+ # Determine license based on base model
409
+ if "FLUX.1-schnell" in base_model:
410
+ license_info = {"license": "apache-2.0"}
411
+ elif "FLUX.1-dev" in base_model:
412
+ license_info = {
413
+ "license": "other",
414
+ "license_name": "flux-1-dev-non-commercial-license",
415
+ "license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
416
+ }
417
+ else:
418
+ license_info = {"license": "creativeml-openrail-m"}
419
+
420
+ # Generate tags based on model architecture
421
+ tags = ["text-to-image"]
422
+
423
+ if "xl" in arch.lower():
424
+ tags.append("stable-diffusion-xl")
425
+ if "flux" in arch.lower():
426
+ tags.append("flux")
427
+ if "lumina" in arch.lower():
428
+ tags.append("lumina2")
429
+ if "sd3" in arch.lower() or "v3" in arch.lower():
430
+ tags.append("sd3")
431
+
432
+ # Add LoRA-specific tags
433
+ tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
434
+
435
+ # Generate widgets from sample images and prompts
436
+ widgets = []
437
+ if samples_dir and os.path.isdir(samples_dir):
438
+ sample_prompts = sample_config.get("samples", [])
439
+ if not sample_prompts:
440
+ # Fallback to old format
441
+ sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])]
442
+
443
+ # Get sample image files
444
+ sample_files = []
445
+ if os.path.isdir(samples_dir):
446
+ for filename in os.listdir(samples_dir):
447
+ if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
448
+ # Parse filename pattern: timestamp__steps_index.jpg
449
+ match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
450
+ if match:
451
+ steps, index = int(match.group(1)), int(match.group(2))
452
+ # Only use samples from final training step
453
+ final_steps = train_config.get("steps", 1000)
454
+ if steps == final_steps:
455
+ sample_files.append((index, f"samples/{filename}"))
456
+
457
+ # Sort by index and create widgets
458
+ sample_files.sort(key=lambda x: x[0])
459
+
460
+ for i, prompt_obj in enumerate(sample_prompts):
461
+ prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj)
462
+ if i < len(sample_files):
463
+ _, image_path = sample_files[i]
464
+ widgets.append({
465
+ "text": prompt,
466
+ "output": {"url": image_path}
467
+ })
468
+
469
+ # Determine torch dtype based on model
470
+ dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16"
471
+
472
+ # Find the main safetensors file for usage example
473
+ main_safetensors = f"{model_name}.safetensors"
474
+ if uploaded_files:
475
+ safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')]
476
+ if safetensors_files:
477
+ main_safetensors = safetensors_files[0]
478
+
479
+ # Construct YAML frontmatter
480
+ frontmatter = {
481
+ "tags": tags,
482
+ "base_model": base_model,
483
+ **license_info
484
+ }
485
+
486
+ if widgets:
487
+ frontmatter["widget"] = widgets
488
+
489
+ if trigger_word:
490
+ frontmatter["instance_prompt"] = trigger_word
491
+
492
+ # Get first prompt for usage example
493
+ usage_prompt = trigger_word or "a beautiful landscape"
494
+ if widgets:
495
+ usage_prompt = widgets[0]["text"]
496
+ elif trigger_word:
497
+ usage_prompt = trigger_word
498
+
499
+ # Construct README content
500
+ trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined."
501
+
502
+ # Build YAML frontmatter string
503
+ frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip()
504
+
505
+ readme_content = f"""---
506
+ {frontmatter_yaml}
507
+ ---
508
+
509
+ # {model_name}
510
+
511
+ Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
512
+
513
+ <Gallery />
514
+
515
+ ## Trigger words
516
+
517
+ {trigger_section}
518
+
519
+ ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.
520
+
521
+ Weights for this model are available in Safetensors format.
522
+
523
+ [Download]({repo_id}/tree/main) them in the Files & versions tab.
524
+
525
+ ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
526
+
527
+ \`\`\`py
528
+ from diffusers import AutoPipelineForText2Image
529
+ import torch
530
+
531
+ pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
532
+ pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}')
533
+ image = pipeline('{usage_prompt}').images[0]
534
+ image.save("my_image.png")
535
+ \`\`\`
536
+
537
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
538
+
539
+ """
540
+ return readme_content
541
+
542
+ except Exception as e:
543
+ print(f"Error generating README: {e}")
544
+ # Fallback simple README
545
+ return f"""# {model_name}
546
+
547
+ Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
548
+
549
+ ## Download model
550
+
551
+ Weights for this model are available in Safetensors format.
552
+
553
+ [Download]({repo_id}/tree/main) them in the Files & versions tab.
554
+ """
555
+
556
+ def main():
557
+ # Setup environment - token comes from HF Jobs secrets
558
+ if "HF_TOKEN" not in os.environ:
559
+ raise ValueError("HF_TOKEN environment variable not set")
560
+
561
+ # Install system dependencies for headless operation
562
+ print("Installing system dependencies...")
563
+ try:
564
+ subprocess.run(["apt-get", "update"], check=True, capture_output=True)
565
+ subprocess.run([
566
+ "apt-get", "install", "-y",
567
+ "libgl1-mesa-glx",
568
+ "libglib2.0-0",
569
+ "libsm6",
570
+ "libxext6",
571
+ "libxrender-dev",
572
+ "libgomp1",
573
+ "ffmpeg"
574
+ ], check=True, capture_output=True)
575
+ print("System dependencies installed successfully")
576
+ except subprocess.CalledProcessError as e:
577
+ print(f"Failed to install system dependencies: {e}")
578
+ print("Continuing without system dependencies...")
579
+
580
+ # Setup ai-toolkit
581
+ toolkit_dir = setup_ai_toolkit()
582
+
583
+ # Create temporary directories
584
+ with tempfile.TemporaryDirectory() as temp_dir:
585
+ dataset_path = os.path.join(temp_dir, "dataset")
586
+ output_path = os.path.join(temp_dir, "output")
587
+
588
+ # Download dataset
589
+ download_dataset("${datasetRepo}", dataset_path)
590
+
591
+ # Create config
592
+ config = create_config(dataset_path, output_path)
593
+ config_path = os.path.join(temp_dir, "config.yaml")
594
+
595
+ with open(config_path, "w") as f:
596
+ yaml.dump(config, f, default_flow_style=False)
597
+
598
+ # Run training
599
+ print("Starting training...")
600
+ os.chdir(toolkit_dir)
601
+
602
+ subprocess.run([
603
+ sys.executable, "run.py",
604
+ config_path
605
+ ], check=True)
606
+
607
+ print("Training completed!")
608
+
609
+ # Upload results
610
+ model_name = f"${jobConfig.config.name}-lora"
611
+ upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config)
612
+
613
+ if __name__ == "__main__":
614
+ main()
615
+ `;
616
+ }
617
+
618
+ async function submitHFJobUV(token: string, hardware: string, scriptPath: string): Promise<string> {
619
+ return new Promise((resolve, reject) => {
620
+ // Ensure token is available
621
+ if (!token) {
622
+ reject(new Error('HF_TOKEN is required'));
623
+ return;
624
+ }
625
+
626
+ console.log('Setting up environment with HF_TOKEN for job submission');
627
+ console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach ${scriptPath}`);
628
+
629
+ // Use hf jobs uv run command with timeout and detach to get job ID
630
+ const childProcess = spawn('hf', [
631
+ 'jobs', 'uv', 'run',
632
+ '--flavor', hardware,
633
+ '--timeout', '5h',
634
+ '--secrets', 'HF_TOKEN',
635
+ '--detach',
636
+ scriptPath
637
+ ], {
638
+ env: {
639
+ ...process.env,
640
+ HF_TOKEN: token
641
+ }
642
+ });
643
+
644
+ let output = '';
645
+ let error = '';
646
+
647
+ childProcess.stdout.on('data', (data) => {
648
+ const text = data.toString();
649
+ output += text;
650
+ console.log('HF Jobs stdout:', text);
651
+ });
652
+
653
+ childProcess.stderr.on('data', (data) => {
654
+ const text = data.toString();
655
+ error += text;
656
+ console.log('HF Jobs stderr:', text);
657
+ });
658
+
659
+ childProcess.on('close', (code) => {
660
+ console.log('HF Jobs process closed with code:', code);
661
+ console.log('Full output:', output);
662
+ console.log('Full error:', error);
663
+
664
+ if (code === 0) {
665
+ // With --detach flag, the output should be just the job ID
666
+ const fullText = (output + ' ' + error).trim();
667
+
668
+ // Updated patterns to handle variable-length hex job IDs (16-24+ characters)
669
+ const jobIdPatterns = [
670
+ /Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac"
671
+ /job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac"
672
+ /Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac"
673
+ /created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac"
674
+ /submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac"
675
+ /https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern
676
+ /([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string
677
+ ];
678
+
679
+ let jobId = 'unknown';
680
+
681
+ for (const pattern of jobIdPatterns) {
682
+ const match = fullText.match(pattern);
683
+ if (match && match[1] && match[1] !== 'started') {
684
+ jobId = match[1];
685
+ console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`);
686
+ break;
687
+ }
688
+ }
689
+
690
+ resolve(jobId);
691
+ } else {
692
+ reject(new Error(error || output || 'Failed to submit job'));
693
+ }
694
+ });
695
+
696
+ childProcess.on('error', (err) => {
697
+ console.error('HF Jobs process error:', err);
698
+ reject(new Error(`Process error: ${err.message}`));
699
+ });
700
+ });
701
+ }
702
+
703
+ async function checkHFJobStatus(token: string, jobId: string): Promise<any> {
704
+ return new Promise((resolve, reject) => {
705
+ console.log(`Checking HF Job status for: ${jobId}`);
706
+
707
+ const childProcess = spawn('hf', [
708
+ 'jobs', 'inspect', jobId
709
+ ], {
710
+ env: {
711
+ ...process.env,
712
+ HF_TOKEN: token
713
+ }
714
+ });
715
+
716
+ let output = '';
717
+ let error = '';
718
+
719
+ childProcess.stdout.on('data', (data) => {
720
+ const text = data.toString();
721
+ output += text;
722
+ });
723
+
724
+ childProcess.stderr.on('data', (data) => {
725
+ const text = data.toString();
726
+ error += text;
727
+ });
728
+
729
+ childProcess.on('close', (code) => {
730
+ if (code === 0) {
731
+ try {
732
+ // Parse the JSON output from hf jobs inspect
733
+ const jobInfo = JSON.parse(output);
734
+ if (Array.isArray(jobInfo) && jobInfo.length > 0) {
735
+ const job = jobInfo[0];
736
+ resolve({
737
+ id: job.id,
738
+ status: job.status?.stage || 'UNKNOWN',
739
+ message: job.status?.message,
740
+ created_at: job.created_at,
741
+ flavor: job.flavor,
742
+ url: job.url,
743
+ });
744
+ } else {
745
+ reject(new Error('Invalid job info response'));
746
+ }
747
+ } catch (parseError: any) {
748
+ console.error('Failed to parse job status:', parseError, output);
749
+ reject(new Error('Failed to parse job status'));
750
+ }
751
+ } else {
752
+ reject(new Error(error || output || 'Failed to check job status'));
753
+ }
754
+ });
755
+
756
+ childProcess.on('error', (err) => {
757
+ console.error('HF Jobs inspect process error:', err);
758
+ reject(new Error(`Process error: ${err.message}`));
759
+ });
760
+ });
761
+ }
ui/src/app/api/img/[...imagePath]/route.ts ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import path from 'path';
5
+ import { getDatasetsRoot, getTrainingFolder, getDataRoot } from '@/server/settings';
6
+
7
+ export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
8
+ const { imagePath } = await params;
9
+ try {
10
+ // Decode the path
11
+ const filepath = decodeURIComponent(imagePath);
12
+
13
+ // Get allowed directories
14
+ const datasetRoot = await getDatasetsRoot();
15
+ const trainingRoot = await getTrainingFolder();
16
+ const dataRoot = await getDataRoot();
17
+
18
+ const allowedDirs = [datasetRoot, trainingRoot, dataRoot];
19
+
20
+ // Security check: Ensure path is in allowed directory
21
+ const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');
22
+
23
+ if (!isAllowed) {
24
+ console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`);
25
+ return new NextResponse('Access denied', { status: 403 });
26
+ }
27
+
28
+ // Check if file exists
29
+ if (!fs.existsSync(filepath)) {
30
+ console.warn(`File not found: ${filepath}`);
31
+ return new NextResponse('File not found', { status: 404 });
32
+ }
33
+
34
+ // Get file info
35
+ const stat = fs.statSync(filepath);
36
+ if (!stat.isFile()) {
37
+ return new NextResponse('Not a file', { status: 400 });
38
+ }
39
+
40
+ // Determine content type
41
+ const ext = path.extname(filepath).toLowerCase();
42
+ const contentTypeMap: { [key: string]: string } = {
43
+ // Images
44
+ '.jpg': 'image/jpeg',
45
+ '.jpeg': 'image/jpeg',
46
+ '.png': 'image/png',
47
+ '.gif': 'image/gif',
48
+ '.webp': 'image/webp',
49
+ '.svg': 'image/svg+xml',
50
+ '.bmp': 'image/bmp',
51
+ // Videos
52
+ '.mp4': 'video/mp4',
53
+ '.avi': 'video/x-msvideo',
54
+ '.mov': 'video/quicktime',
55
+ '.mkv': 'video/x-matroska',
56
+ '.wmv': 'video/x-ms-wmv',
57
+ '.m4v': 'video/x-m4v',
58
+ '.flv': 'video/x-flv'
59
+ };
60
+
61
+ const contentType = contentTypeMap[ext] || 'application/octet-stream';
62
+
63
+ // Read file as buffer
64
+ const fileBuffer = fs.readFileSync(filepath);
65
+
66
+ // Return file with appropriate headers
67
+ return new NextResponse(fileBuffer, {
68
+ headers: {
69
+ 'Content-Type': contentType,
70
+ 'Content-Length': String(stat.size),
71
+ 'Cache-Control': 'public, max-age=86400',
72
+ },
73
+ });
74
+ } catch (error) {
75
+ console.error('Error serving image:', error);
76
+ return new NextResponse('Internal Server Error', { status: 500 });
77
+ }
78
+ }
ui/src/app/api/img/caption/route.ts ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import { getDatasetsRoot } from '@/server/settings';
4
+
5
+ export async function POST(request: Request) {
6
+ try {
7
+ const body = await request.json();
8
+ const { imgPath, caption } = body;
9
+ let datasetsPath = await getDatasetsRoot();
10
+ // make sure the dataset path is in the image path
11
+ if (!imgPath.startsWith(datasetsPath)) {
12
+ return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
13
+ }
14
+
15
+ // if img doesnt exist, ignore
16
+ if (!fs.existsSync(imgPath)) {
17
+ return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
18
+ }
19
+
20
+ // check for caption
21
+ const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
22
+ // save caption to file
23
+ fs.writeFileSync(captionPath, caption);
24
+
25
+ return NextResponse.json({ success: true });
26
+ } catch (error) {
27
+ return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
28
+ }
29
+ }
ui/src/app/api/img/delete/route.ts ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import fs from 'fs';
3
+ import { getDatasetsRoot } from '@/server/settings';
4
+
5
+ export async function POST(request: Request) {
6
+ try {
7
+ const body = await request.json();
8
+ const { imgPath } = body;
9
+ let datasetsPath = await getDatasetsRoot();
10
+ // make sure the dataset path is in the image path
11
+ if (!imgPath.startsWith(datasetsPath)) {
12
+ return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
13
+ }
14
+
15
+ // if img doesnt exist, ignore
16
+ if (!fs.existsSync(imgPath)) {
17
+ return NextResponse.json({ success: true });
18
+ }
19
+
20
+ // delete it and return success
21
+ fs.unlinkSync(imgPath);
22
+
23
+ // check for caption
24
+ const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
25
+ if (fs.existsSync(captionPath)) {
26
+ // delete caption file
27
+ fs.unlinkSync(captionPath);
28
+ }
29
+
30
+ return NextResponse.json({ success: true });
31
+ } catch (error) {
32
+ return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
33
+ }
34
+ }
ui/src/app/api/img/upload/route.ts ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/app/api/datasets/upload/route.ts
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import { writeFile, mkdir } from 'fs/promises';
4
+ import { join } from 'path';
5
+ import { getDataRoot } from '@/server/settings';
6
+ import {v4 as uuidv4} from 'uuid';
7
+
8
+ export async function POST(request: NextRequest) {
9
+ try {
10
+ const dataRoot = await getDataRoot();
11
+ if (!dataRoot) {
12
+ return NextResponse.json({ error: 'Data root path not found' }, { status: 500 });
13
+ }
14
+ const imgRoot = join(dataRoot, 'images');
15
+
16
+
17
+ const formData = await request.formData();
18
+ const files = formData.getAll('files');
19
+
20
+ if (!files || files.length === 0) {
21
+ return NextResponse.json({ error: 'No files provided' }, { status: 400 });
22
+ }
23
+
24
+ // make it recursive if it doesn't exist
25
+ await mkdir(imgRoot, { recursive: true });
26
+ const savedFiles = await Promise.all(
27
+ files.map(async (file: any) => {
28
+ const bytes = await file.arrayBuffer();
29
+ const buffer = Buffer.from(bytes);
30
+
31
+ const extension = file.name.split('.').pop() || 'jpg';
32
+
33
+ // Clean filename and ensure it's unique
34
+ const fileName = `${uuidv4()}`; // Use UUID for unique file names
35
+ const filePath = join(imgRoot, `${fileName}.${extension}`);
36
+
37
+ await writeFile(filePath, buffer);
38
+ return filePath;
39
+ }),
40
+ );
41
+
42
+ return NextResponse.json({
43
+ message: 'Files uploaded successfully',
44
+ files: savedFiles,
45
+ });
46
+ } catch (error) {
47
+ console.error('Upload error:', error);
48
+ return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
49
+ }
50
+ }
51
+
52
+ // Increase payload size limit (default is 4mb)
53
+ export const config = {
54
+ api: {
55
+ bodyParser: false,
56
+ responseLimit: '50mb',
57
+ },
58
+ };
ui/src/app/api/jobs/[jobID]/delete/route.ts ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+ import { getTrainingFolder } from '@/server/settings';
4
+ import path from 'path';
5
+ import fs from 'fs';
6
+
7
+ const prisma = new PrismaClient();
8
+
9
+ export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
10
+ const { jobID } = await params;
11
+
12
+ const job = await prisma.job.findUnique({
13
+ where: { id: jobID },
14
+ });
15
+
16
+ if (!job) {
17
+ return NextResponse.json({ error: 'Job not found' }, { status: 404 });
18
+ }
19
+
20
+ const trainingRoot = await getTrainingFolder();
21
+ const trainingFolder = path.join(trainingRoot, job.name);
22
+
23
+ if (fs.existsSync(trainingFolder)) {
24
+ fs.rmdirSync(trainingFolder, { recursive: true });
25
+ }
26
+
27
+ await prisma.job.delete({
28
+ where: { id: jobID },
29
+ });
30
+
31
+ return NextResponse.json(job);
32
+ }
ui/src/app/api/jobs/[jobID]/files/route.ts ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+ import path from 'path';
4
+ import fs from 'fs';
5
+ import { getTrainingFolder } from '@/server/settings';
6
+
7
+ const prisma = new PrismaClient();
8
+
9
+ export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
10
+ const { jobID } = await params;
11
+
12
+ const job = await prisma.job.findUnique({
13
+ where: { id: jobID },
14
+ });
15
+
16
+ if (!job) {
17
+ return NextResponse.json({ error: 'Job not found' }, { status: 404 });
18
+ }
19
+
20
+ const trainingFolder = await getTrainingFolder();
21
+ const jobFolder = path.join(trainingFolder, job.name);
22
+
23
+ if (!fs.existsSync(jobFolder)) {
24
+ return NextResponse.json({ files: [] });
25
+ }
26
+
27
+ // find all safetensors files in the job folder
28
+ let files = fs
29
+ .readdirSync(jobFolder)
30
+ .filter(file => {
31
+ return file.endsWith('.safetensors');
32
+ })
33
+ .map(file => {
34
+ return path.join(jobFolder, file);
35
+ })
36
+ .sort();
37
+
38
+ // get the file size for each file
39
+ const fileObjects = files.map(file => {
40
+ const stats = fs.statSync(file);
41
+ return {
42
+ path: file,
43
+ size: stats.size,
44
+ };
45
+ });
46
+
47
+ return NextResponse.json({ files: fileObjects });
48
+ }
ui/src/app/api/jobs/[jobID]/log/route.ts ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+ import path from 'path';
4
+ import fs from 'fs';
5
+ import { getTrainingFolder } from '@/server/settings';
6
+
7
+ const prisma = new PrismaClient();
8
+
9
+ export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
10
+ const { jobID } = await params;
11
+
12
+ const job = await prisma.job.findUnique({
13
+ where: { id: jobID },
14
+ });
15
+
16
+ if (!job) {
17
+ return NextResponse.json({ error: 'Job not found' }, { status: 404 });
18
+ }
19
+
20
+ const trainingFolder = await getTrainingFolder();
21
+ const jobFolder = path.join(trainingFolder, job.name);
22
+ const logPath = path.join(jobFolder, 'log.txt');
23
+
24
+ if (!fs.existsSync(logPath)) {
25
+ return NextResponse.json({ log: '' });
26
+ }
27
+ let log = '';
28
+ try {
29
+ log = fs.readFileSync(logPath, 'utf-8');
30
+ } catch (error) {
31
+ console.error('Error reading log file:', error);
32
+ log = 'Error reading log file';
33
+ }
34
+ return NextResponse.json({ log: log });
35
+ }
ui/src/app/api/jobs/[jobID]/samples/route.ts ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+ import path from 'path';
4
+ import fs from 'fs';
5
+ import { getTrainingFolder } from '@/server/settings';
6
+
7
+ const prisma = new PrismaClient();
8
+
9
+ export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
10
+ const { jobID } = await params;
11
+
12
+ const job = await prisma.job.findUnique({
13
+ where: { id: jobID },
14
+ });
15
+
16
+ if (!job) {
17
+ return NextResponse.json({ error: 'Job not found' }, { status: 404 });
18
+ }
19
+
20
+ // setup the training
21
+ const trainingFolder = await getTrainingFolder();
22
+
23
+ const samplesFolder = path.join(trainingFolder, job.name, 'samples');
24
+ if (!fs.existsSync(samplesFolder)) {
25
+ return NextResponse.json({ samples: [] });
26
+ }
27
+
28
+ // find all img (png, jpg, jpeg) files in the samples folder
29
+ const samples = fs
30
+ .readdirSync(samplesFolder)
31
+ .filter(file => {
32
+ return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp');
33
+ })
34
+ .map(file => {
35
+ return path.join(samplesFolder, file);
36
+ })
37
+ .sort();
38
+
39
+ return NextResponse.json({ samples });
40
+ }
ui/src/app/api/jobs/[jobID]/start/route.ts ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+ import { TOOLKIT_ROOT } from '@/paths';
4
+ import { spawn } from 'child_process';
5
+ import path from 'path';
6
+ import fs from 'fs';
7
+ import os from 'os';
8
+ import { getTrainingFolder, getHFToken } from '@/server/settings';
9
+ const isWindows = process.platform === 'win32';
10
+
11
+ const prisma = new PrismaClient();
12
+
13
+ export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
14
+ const { jobID } = await params;
15
+
16
+ const job = await prisma.job.findUnique({
17
+ where: { id: jobID },
18
+ });
19
+
20
+ if (!job) {
21
+ return NextResponse.json({ error: 'Job not found' }, { status: 404 });
22
+ }
23
+
24
+ // update job status to 'running'
25
+ await prisma.job.update({
26
+ where: { id: jobID },
27
+ data: {
28
+ status: 'running',
29
+ stop: false,
30
+ info: 'Starting job...',
31
+ },
32
+ });
33
+
34
+ // setup the training
35
+ const trainingRoot = await getTrainingFolder();
36
+
37
+ const trainingFolder = path.join(trainingRoot, job.name);
38
+ if (!fs.existsSync(trainingFolder)) {
39
+ fs.mkdirSync(trainingFolder, { recursive: true });
40
+ }
41
+
42
+ // make the config file
43
+ const configPath = path.join(trainingFolder, '.job_config.json');
44
+
45
+ //log to path
46
+ const logPath = path.join(trainingFolder, 'log.txt');
47
+
48
+ try {
49
+ // if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num
50
+ // if the log path does not exist, create it
51
+ if (fs.existsSync(logPath)) {
52
+ const logsFolder = path.join(trainingFolder, 'logs');
53
+ if (!fs.existsSync(logsFolder)) {
54
+ fs.mkdirSync(logsFolder, { recursive: true });
55
+ }
56
+
57
+ let num = 0;
58
+ while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) {
59
+ num++;
60
+ }
61
+
62
+ fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`));
63
+ }
64
+ } catch (e) {
65
+ console.error('Error moving log file:', e);
66
+ }
67
+
68
+ // update the config dataset path
69
+ const jobConfig = JSON.parse(job.job_config);
70
+ jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
71
+
72
+ // write the config file
73
+ fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
74
+
75
+ let pythonPath = 'python';
76
+ // use .venv or venv if it exists
77
+ if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
78
+ if (isWindows) {
79
+ pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
80
+ } else {
81
+ pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
82
+ }
83
+ } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
84
+ if (isWindows) {
85
+ pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
86
+ } else {
87
+ pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
88
+ }
89
+ }
90
+
91
+ const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
92
+ if (!fs.existsSync(runFilePath)) {
93
+ return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
94
+ }
95
+
96
+ const additionalEnv: any = {
97
+ AITK_JOB_ID: jobID,
98
+ CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
99
+ IS_AI_TOOLKIT_UI: '1'
100
+ };
101
+
102
+ // HF_TOKEN
103
+ const hfToken = await getHFToken();
104
+ if (hfToken && hfToken.trim() !== '') {
105
+ additionalEnv.HF_TOKEN = hfToken;
106
+ }
107
+
108
+ // Add the --log argument to the command
109
+ const args = [runFilePath, configPath, '--log', logPath];
110
+
111
+ try {
112
+ let subprocess;
113
+
114
+ if (isWindows) {
115
+ // For Windows, use 'cmd.exe' to open a new command window
116
+ subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, ...args], {
117
+ env: {
118
+ ...process.env,
119
+ ...additionalEnv,
120
+ },
121
+ cwd: TOOLKIT_ROOT,
122
+ windowsHide: false,
123
+ });
124
+ } else {
125
+ // For non-Windows platforms
126
+ subprocess = spawn(pythonPath, args, {
127
+ detached: true,
128
+ stdio: ['ignore', 'pipe', 'pipe'], // Changed from 'ignore' to capture output
129
+ env: {
130
+ ...process.env,
131
+ ...additionalEnv,
132
+ },
133
+ cwd: TOOLKIT_ROOT,
134
+ });
135
+ }
136
+
137
+ // Start monitoring in the background without blocking the response
138
+ const monitorProcess = async () => {
139
+ const startTime = Date.now();
140
+ let errorOutput = '';
141
+ let stdoutput = '';
142
+
143
+ if (subprocess.stderr) {
144
+ subprocess.stderr.on('data', data => {
145
+ errorOutput += data.toString();
146
+ });
147
+ subprocess.stdout.on('data', data => {
148
+ stdoutput += data.toString();
149
+ // truncate to only get the last 500 characters
150
+ if (stdoutput.length > 500) {
151
+ stdoutput = stdoutput.substring(stdoutput.length - 500);
152
+ }
153
+ });
154
+ }
155
+
156
+ subprocess.on('exit', async code => {
157
+ const currentTime = Date.now();
158
+ const duration = (currentTime - startTime) / 1000;
159
+ console.log(`Job ${jobID} exited with code ${code} after ${duration} seconds.`);
160
+ // wait for 5 seconds to give it time to stop itself. It id still has a status of running in the db, update it to stopped
161
+ await new Promise(resolve => setTimeout(resolve, 5000));
162
+ const updatedJob = await prisma.job.findUnique({
163
+ where: { id: jobID },
164
+ });
165
+ if (updatedJob?.status === 'running') {
166
+ let errorString = errorOutput;
167
+ if (errorString.trim() === '') {
168
+ errorString = stdoutput;
169
+ }
170
+ await prisma.job.update({
171
+ where: { id: jobID },
172
+ data: {
173
+ status: 'error',
174
+ info: `Error launching job: ${errorString.substring(0, 500)}`,
175
+ },
176
+ });
177
+ }
178
+ });
179
+
180
+ // Wait 30 seconds before releasing the process
181
+ await new Promise(resolve => setTimeout(resolve, 30000));
182
+ // Detach the process for non-Windows systems
183
+ if (!isWindows && subprocess.unref) {
184
+ subprocess.unref();
185
+ }
186
+ };
187
+
188
+ // Start the monitoring without awaiting it
189
+ monitorProcess().catch(err => {
190
+ console.error(`Error in process monitoring for job ${jobID}:`, err);
191
+ });
192
+
193
+ // Return the response immediately
194
+ return NextResponse.json(job);
195
+ } catch (error: any) {
196
+ // Handle any exceptions during process launch
197
+ console.error('Error launching process:', error);
198
+
199
+ await prisma.job.update({
200
+ where: { id: jobID },
201
+ data: {
202
+ status: 'error',
203
+ info: `Error launching job: ${error?.message || 'Unknown error'}`,
204
+ },
205
+ });
206
+
207
+ return NextResponse.json(
208
+ {
209
+ error: 'Failed to launch job process',
210
+ details: error?.message || 'Unknown error',
211
+ },
212
+ { status: 500 },
213
+ );
214
+ }
215
+ }
ui/src/app/api/jobs/[jobID]/stop/route.ts ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextRequest, NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+
4
+ const prisma = new PrismaClient();
5
+
6
+ export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
7
+ const { jobID } = await params;
8
+
9
+ const job = await prisma.job.findUnique({
10
+ where: { id: jobID },
11
+ });
12
+
13
+ // update job status to 'running'
14
+ await prisma.job.update({
15
+ where: { id: jobID },
16
+ data: {
17
+ stop: true,
18
+ info: 'Stopping job...',
19
+ },
20
+ });
21
+
22
+ return NextResponse.json(job);
23
+ }
ui/src/app/api/jobs/route.ts ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+
4
+ const prisma = new PrismaClient();
5
+
6
+ export async function GET(request: Request) {
7
+ const { searchParams } = new URL(request.url);
8
+ const id = searchParams.get('id');
9
+
10
+ try {
11
+ if (id) {
12
+ const job = await prisma.job.findUnique({
13
+ where: { id },
14
+ });
15
+ return NextResponse.json(job);
16
+ }
17
+
18
+ const jobs = await prisma.job.findMany({
19
+ orderBy: { created_at: 'desc' },
20
+ });
21
+ return NextResponse.json({ jobs: jobs });
22
+ } catch (error) {
23
+ console.error(error);
24
+ return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
25
+ }
26
+ }
27
+
28
+ export async function POST(request: Request) {
29
+ try {
30
+ const body = await request.json();
31
+ const { id, name, job_config, gpu_ids } = body;
32
+
33
+ // Ensure gpu_ids is never null/undefined - provide default value
34
+ const safeGpuIds = gpu_ids || '0';
35
+
36
+ if (id) {
37
+ // Update existing training
38
+ const training = await prisma.job.update({
39
+ where: { id },
40
+ data: {
41
+ name,
42
+ gpu_ids: safeGpuIds,
43
+ job_config: JSON.stringify(job_config),
44
+ },
45
+ });
46
+ return NextResponse.json(training);
47
+ } else {
48
+ // Create new training
49
+ const training = await prisma.job.create({
50
+ data: {
51
+ name,
52
+ gpu_ids: safeGpuIds,
53
+ job_config: JSON.stringify(job_config),
54
+ },
55
+ });
56
+ return NextResponse.json(training);
57
+ }
58
+ } catch (error: any) {
59
+ if (error.code === 'P2002') {
60
+ // Handle unique constraint violation, 409=Conflict
61
+ return NextResponse.json({ error: 'Job name already exists' }, { status: 409 });
62
+ }
63
+ console.error(error);
64
+ // Handle other errors
65
+ return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 });
66
+ }
67
+ }
ui/src/app/api/settings/route.ts ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { NextResponse } from 'next/server';
2
+ import { PrismaClient } from '@prisma/client';
3
+ import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
4
+ import { flushCache } from '@/server/settings';
5
+
6
+ const prisma = new PrismaClient();
7
+
8
+ export async function GET() {
9
+ try {
10
+ const settings = await prisma.settings.findMany();
11
+ const settingsObject = settings.reduce((acc: any, setting) => {
12
+ acc[setting.key] = setting.value;
13
+ return acc;
14
+ }, {});
15
+ // if TRAINING_FOLDER is not set, use default
16
+ if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
17
+ settingsObject.TRAINING_FOLDER = defaultTrainFolder;
18
+ }
19
+ // if DATASETS_FOLDER is not set, use default
20
+ if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') {
21
+ settingsObject.DATASETS_FOLDER = defaultDatasetsFolder;
22
+ }
23
+ return NextResponse.json(settingsObject);
24
+ } catch (error) {
25
+ return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
26
+ }
27
+ }
28
+
29
+ export async function POST(request: Request) {
30
+ try {
31
+ const body = await request.json();
32
+ const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body;
33
+
34
+ // Upsert both settings
35
+ await Promise.all([
36
+ prisma.settings.upsert({
37
+ where: { key: 'HF_TOKEN' },
38
+ update: { value: HF_TOKEN },
39
+ create: { key: 'HF_TOKEN', value: HF_TOKEN },
40
+ }),
41
+ prisma.settings.upsert({
42
+ where: { key: 'TRAINING_FOLDER' },
43
+ update: { value: TRAINING_FOLDER },
44
+ create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER },
45
+ }),
46
+ prisma.settings.upsert({
47
+ where: { key: 'DATASETS_FOLDER' },
48
+ update: { value: DATASETS_FOLDER },
49
+ create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER },
50
+ }),
51
+ ]);
52
+
53
+ flushCache();
54
+
55
+ return NextResponse.json({ success: true });
56
+ } catch (error) {
57
+ return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });
58
+ }
59
+ }
ui/src/app/api/zip/route.ts ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* eslint-disable */
2
+ import { NextRequest, NextResponse } from 'next/server';
3
+ import fs from 'fs';
4
+ import fsp from 'fs/promises';
5
+ import path from 'path';
6
+ import archiver from 'archiver';
7
+ import { getTrainingFolder } from '@/server/settings';
8
+
9
+ export const runtime = 'nodejs'; // ensure Node APIs are available
10
+ export const dynamic = 'force-dynamic'; // long-running, non-cached
11
+
12
+ type PostBody = {
13
+ zipTarget: 'samples'; //only samples for now
14
+ jobName: string;
15
+ };
16
+
17
+ async function resolveSafe(p: string) {
18
+ // resolve symlinks + normalize
19
+ return await fsp.realpath(p);
20
+ }
21
+
22
+ export async function POST(request: NextRequest) {
23
+ try {
24
+ const body = (await request.json()) as PostBody;
25
+ if (!body || !body.jobName) {
26
+ return NextResponse.json({ error: 'jobName is required' }, { status: 400 });
27
+ }
28
+
29
+ const trainingRoot = await resolveSafe(await getTrainingFolder());
30
+ const folderPath = await resolveSafe(path.join(trainingRoot, body.jobName, 'samples'));
31
+ const outputPath = path.resolve(trainingRoot, body.jobName, 'samples.zip');
32
+
33
+ // Must be a directory
34
+ let stat: fs.Stats;
35
+ try {
36
+ stat = await fsp.stat(folderPath);
37
+ } catch {
38
+ return new NextResponse('Folder not found', { status: 404 });
39
+ }
40
+ if (!stat.isDirectory()) {
41
+ return new NextResponse('Not a directory', { status: 400 });
42
+ }
43
+
44
+ // delete current one if it exists
45
+ if (fs.existsSync(outputPath)) {
46
+ await fsp.unlink(outputPath);
47
+ }
48
+
49
+ // Create write stream & archive
50
+ await new Promise<void>((resolve, reject) => {
51
+ const output = fs.createWriteStream(outputPath);
52
+ const archive = archiver('zip', { zlib: { level: 9 } });
53
+
54
+ output.on('close', () => resolve());
55
+ output.on('error', reject);
56
+ archive.on('error', reject);
57
+
58
+ archive.pipe(output);
59
+
60
+ // Add the directory contents (place them under the folder's base name in the zip)
61
+ const rootName = path.basename(folderPath);
62
+ archive.directory(folderPath, rootName);
63
+
64
+ archive.finalize().catch(reject);
65
+ });
66
+
67
+ // Return the absolute path so your existing /api/files/[...filePath] can serve it
68
+ // Example download URL (client-side): `/api/files/${encodeURIComponent(resolvedOutPath)}`
69
+ return NextResponse.json({
70
+ ok: true,
71
+ zipPath: outputPath,
72
+ fileName: path.basename(outputPath),
73
+ });
74
+ } catch (err) {
75
+ console.error('Zip error:', err);
76
+ return new NextResponse('Internal Server Error', { status: 500 });
77
+ }
78
+ }
ui/src/app/apple-icon.png ADDED
ui/src/app/dashboard/page.tsx ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client';
2
+
3
+ import GpuMonitor from '@/components/GPUMonitor';
4
+ import JobsTable from '@/components/JobsTable';
5
+ import { TopBar, MainContent } from '@/components/layout';
6
+ import Link from 'next/link';
7
+ import { useAuth } from '@/contexts/AuthContext';
8
+ import HFLoginButton from '@/components/HFLoginButton';
9
+
10
+ export default function Dashboard() {
11
+ const { status: authStatus, namespace } = useAuth();
12
+ const isAuthenticated = authStatus === 'authenticated';
13
+
14
+ return (
15
+ <>
16
+ <TopBar>
17
+ <div>
18
+ <h1 className="text-lg">Dashboard</h1>
19
+ </div>
20
+ <div className="flex-1 flex items-center justify-end gap-3 pr-2 text-sm text-gray-400">
21
+ {isAuthenticated ? (
22
+ <span>Welcome, {namespace || 'user'}</span>
23
+ ) : (
24
+ <>
25
+ <span>Welcome, Guest</span>
26
+ <HFLoginButton size="sm" />
27
+ <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
28
+ Settings
29
+ </Link>
30
+ </>
31
+ )}
32
+ </div>
33
+ </TopBar>
34
+ <MainContent>
35
+ <GpuMonitor />
36
+ <div className="w-full mt-4">
37
+ <div className="flex justify-between items-center mb-2">
38
+ <h1 className="text-md">Active Jobs</h1>
39
+ <div className="text-xs text-gray-500">
40
+ <Link href="/jobs">View All</Link>
41
+ </div>
42
+ </div>
43
+ {isAuthenticated ? (
44
+ <JobsTable onlyActive />
45
+ ) : (
46
+ <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm">
47
+ Sign in with Hugging Face or add an access token in Settings to view and manage jobs.
48
+ </div>
49
+ )}
50
+ </div>
51
+ </MainContent>
52
+ </>
53
+ );
54
+ }
ui/src/app/datasets/[datasetName]/page.tsx ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client';
2
+
3
+ import { useEffect, useState, use, useMemo } from 'react';
4
+ import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu';
5
+ import { FaChevronLeft } from 'react-icons/fa';
6
+ import DatasetImageCard from '@/components/DatasetImageCard';
7
+ import { Button } from '@headlessui/react';
8
+ import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
9
+ import { TopBar, MainContent } from '@/components/layout';
10
+ import { apiClient } from '@/utils/api';
11
+ import FullscreenDropOverlay from '@/components/FullscreenDropOverlay';
12
+ import { useRouter } from 'next/navigation';
13
+ import { usingBrowserDb } from '@/utils/env';
14
+ import { hasUserDataset } from '@/utils/storage/datasetStorage';
15
+ import { useAuth } from '@/contexts/AuthContext';
16
+ import HFLoginButton from '@/components/HFLoginButton';
17
+ import Link from 'next/link';
18
+
19
+ export default function DatasetPage({ params }: { params: { datasetName: string } }) {
20
+ const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
21
+ const usableParams = use(params as any) as { datasetName: string };
22
+ const datasetName = usableParams.datasetName;
23
+ const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
24
+ const router = useRouter();
25
+ const { status: authStatus } = useAuth();
26
+ const isAuthenticated = authStatus === 'authenticated';
27
+ const hasDatasetEntry = !usingBrowserDb || hasUserDataset(datasetName);
28
+ const allowAccess = hasDatasetEntry && isAuthenticated;
29
+
30
+ const refreshImageList = (dbName: string) => {
31
+ setStatus('loading');
32
+ console.log('Fetching images for dataset:', dbName);
33
+ apiClient
34
+ .post('/api/datasets/listImages', { datasetName: dbName })
35
+ .then((res: any) => {
36
+ const data = res.data;
37
+ console.log('Images:', data.images);
38
+ // sort
39
+ data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
40
+ setImgList(data.images);
41
+ setStatus('success');
42
+ })
43
+ .catch(error => {
44
+ console.error('Error fetching images:', error);
45
+ setStatus('error');
46
+ });
47
+ };
48
+ useEffect(() => {
49
+ if (!datasetName) {
50
+ return;
51
+ }
52
+
53
+ if (!isAuthenticated) {
54
+ return;
55
+ }
56
+
57
+ if (!hasDatasetEntry) {
58
+ setImgList([]);
59
+ setStatus('error');
60
+ router.replace('/datasets');
61
+ return;
62
+ }
63
+
64
+ refreshImageList(datasetName);
65
+ }, [datasetName, hasDatasetEntry, isAuthenticated, router]);
66
+
67
+ if (!allowAccess) {
68
+ return (
69
+ <>
70
+ <TopBar>
71
+ <div>
72
+ <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
73
+ <FaChevronLeft />
74
+ </Button>
75
+ </div>
76
+ <div>
77
+ <h1 className="text-lg">Dataset: {datasetName}</h1>
78
+ </div>
79
+ <div className="flex-1"></div>
80
+ </TopBar>
81
+ <MainContent>
82
+ <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
83
+ <p>You need to sign in with Hugging Face or provide a valid token to view this dataset.</p>
84
+ <div className="flex items-center gap-3">
85
+ <HFLoginButton size="sm" />
86
+ <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
87
+ Manage authentication in Settings
88
+ </Link>
89
+ </div>
90
+ </div>
91
+ </MainContent>
92
+ </>
93
+ );
94
+ }
95
+
96
+ const PageInfoContent = useMemo(() => {
97
+ let icon = null;
98
+ let text = '';
99
+ let subtitle = '';
100
+ let showIt = false;
101
+ let bgColor = '';
102
+ let textColor = '';
103
+ let iconColor = '';
104
+
105
+ if (status == 'loading') {
106
+ icon = <LuLoader className="animate-spin w-8 h-8" />;
107
+ text = 'Loading Images';
108
+ subtitle = 'Please wait while we fetch your dataset images...';
109
+ showIt = true;
110
+ bgColor = 'bg-gray-50 dark:bg-gray-800/50';
111
+ textColor = 'text-gray-900 dark:text-gray-100';
112
+ iconColor = 'text-gray-500 dark:text-gray-400';
113
+ }
114
+ if (status == 'error') {
115
+ icon = <LuBan className="w-8 h-8" />;
116
+ text = 'Error Loading Images';
117
+ subtitle = 'There was a problem fetching the images. Please try refreshing the page.';
118
+ showIt = true;
119
+ bgColor = 'bg-red-50 dark:bg-red-950/20';
120
+ textColor = 'text-red-900 dark:text-red-100';
121
+ iconColor = 'text-red-600 dark:text-red-400';
122
+ }
123
+ if (status == 'success' && imgList.length === 0) {
124
+ icon = <LuImageOff className="w-8 h-8" />;
125
+ text = 'No Images Found';
126
+ subtitle = 'This dataset is empty. Click "Add Images" to get started.';
127
+ showIt = true;
128
+ bgColor = 'bg-gray-50 dark:bg-gray-800/50';
129
+ textColor = 'text-gray-900 dark:text-gray-100';
130
+ iconColor = 'text-gray-500 dark:text-gray-400';
131
+ }
132
+
133
+ if (!showIt) return null;
134
+
135
+ return (
136
+ <div
137
+ className={`mt-10 flex flex-col items-center justify-center py-16 px-8 rounded-xl border-2 border-gray-700 border-dashed ${bgColor} ${textColor} mx-auto max-w-md text-center`}
138
+ >
139
+ <div className={`${iconColor} mb-4`}>{icon}</div>
140
+ <h3 className="text-lg font-semibold mb-2">{text}</h3>
141
+ <p className="text-sm opacity-75 leading-relaxed">{subtitle}</p>
142
+ </div>
143
+ );
144
+ }, [status, imgList.length]);
145
+
146
+ return (
147
+ <>
148
+ {/* Fixed top bar */}
149
+ <TopBar>
150
+ <div>
151
+ <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
152
+ <FaChevronLeft />
153
+ </Button>
154
+ </div>
155
+ <div>
156
+ <h1 className="text-lg">Dataset: {datasetName}</h1>
157
+ </div>
158
+ <div className="flex-1"></div>
159
+ <div>
160
+ <Button
161
+ className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
162
+ onClick={() => openImagesModal(datasetName, () => refreshImageList(datasetName))}
163
+ >
164
+ Add Images
165
+ </Button>
166
+ </div>
167
+ </TopBar>
168
+ <MainContent>
169
+ {PageInfoContent}
170
+ {status === 'success' && imgList.length > 0 && (
171
+ <div className="grid grid-cols-1 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-4">
172
+ {imgList.map(img => (
173
+ <DatasetImageCard
174
+ key={img.img_path}
175
+ alt="image"
176
+ imageUrl={img.img_path}
177
+ onDelete={() => refreshImageList(datasetName)}
178
+ />
179
+ ))}
180
+ </div>
181
+ )}
182
+ </MainContent>
183
+ <AddImagesModal />
184
+ <FullscreenDropOverlay
185
+ datasetName={datasetName}
186
+ onComplete={() => refreshImageList(datasetName)}
187
+ />
188
+ </>
189
+ );
190
+ }
ui/src/app/datasets/page.tsx ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'use client';
2
+
3
+ import { useState } from 'react';
4
+ import { Modal } from '@/components/Modal';
5
+ import Link from 'next/link';
6
+ import { TextInput } from '@/components/formInputs';
7
+ import useDatasetList from '@/hooks/useDatasetList';
8
+ import { Button } from '@headlessui/react';
9
+ import { FaRegTrashAlt } from 'react-icons/fa';
10
+ import { openConfirm } from '@/components/ConfirmModal';
11
+ import { TopBar, MainContent } from '@/components/layout';
12
+ import UniversalTable, { TableColumn } from '@/components/UniversalTable';
13
+ import { apiClient } from '@/utils/api';
14
+ import { useRouter } from 'next/navigation';
15
+ import { usingBrowserDb } from '@/utils/env';
16
+ import { addUserDataset, removeUserDataset } from '@/utils/storage/datasetStorage';
17
+ import { useAuth } from '@/contexts/AuthContext';
18
+ import HFLoginButton from '@/components/HFLoginButton';
19
+
20
+ export default function Datasets() {
21
+ const router = useRouter();
22
+ const { datasets, status, refreshDatasets } = useDatasetList();
23
+ const [newDatasetName, setNewDatasetName] = useState('');
24
+ const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false);
25
+ const { status: authStatus } = useAuth();
26
+ const isAuthenticated = authStatus === 'authenticated';
27
+
28
+ // Transform datasets array into rows with objects
29
+ const tableRows = datasets.map(dataset => ({
30
+ name: dataset,
31
+ actions: dataset, // Pass full dataset name for actions
32
+ }));
33
+
34
+ const columns: TableColumn[] = [
35
+ {
36
+ title: 'Dataset Name',
37
+ key: 'name',
38
+ render: row => (
39
+ <Link href={`/datasets/${row.name}`} className="text-gray-200 hover:text-gray-100">
40
+ {row.name}
41
+ </Link>
42
+ ),
43
+ },
44
+ {
45
+ title: 'Actions',
46
+ key: 'actions',
47
+ className: 'w-20 text-right',
48
+ render: row => (
49
+ <button
50
+ className="text-gray-200 hover:bg-red-600 p-2 rounded-full transition-colors"
51
+ onClick={() => handleDeleteDataset(row.name)}
52
+ >
53
+ <FaRegTrashAlt />
54
+ </button>
55
+ ),
56
+ },
57
+ ];
58
+
59
+ const handleDeleteDataset = (datasetName: string) => {
60
+ openConfirm({
61
+ title: 'Delete Dataset',
62
+ message: `Are you sure you want to delete the dataset "${datasetName}"? This action cannot be undone.`,
63
+ type: 'warning',
64
+ confirmText: 'Delete',
65
+ onConfirm: () => {
66
+ apiClient
67
+ .post('/api/datasets/delete', { name: datasetName })
68
+ .then(() => {
69
+ console.log('Dataset deleted:', datasetName);
70
+ if (usingBrowserDb) {
71
+ removeUserDataset(datasetName);
72
+ }
73
+ refreshDatasets();
74
+ })
75
+ .catch(error => {
76
+ console.error('Error deleting dataset:', error);
77
+ });
78
+ },
79
+ });
80
+ };
81
+
82
+ const handleCreateDataset = async (e: React.FormEvent) => {
83
+ e.preventDefault();
84
+ if (!isAuthenticated) {
85
+ return;
86
+ }
87
+ try {
88
+ const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data);
89
+ console.log('New dataset created:', data);
90
+ if (usingBrowserDb && data?.name) {
91
+ addUserDataset(data.name, data?.path || '');
92
+ }
93
+ refreshDatasets();
94
+ setNewDatasetName('');
95
+ setIsNewDatasetModalOpen(false);
96
+ } catch (error) {
97
+ console.error('Error creating new dataset:', error);
98
+ }
99
+ };
100
+
101
+ const openNewDatasetModal = () => {
102
+ if (!isAuthenticated) {
103
+ return;
104
+ }
105
+ openConfirm({
106
+ title: 'New Dataset',
107
+ message: 'Enter the name of the new dataset:',
108
+ type: 'info',
109
+ confirmText: 'Create',
110
+ inputTitle: 'Dataset Name',
111
+ onConfirm: async (name?: string) => {
112
+ if (!name) {
113
+ console.error('Dataset name is required.');
114
+ return;
115
+ }
116
+ if (!isAuthenticated) {
117
+ return;
118
+ }
119
+ try {
120
+ const data = await apiClient.post('/api/datasets/create', { name }).then(res => res.data);
121
+ console.log('New dataset created:', data);
122
+ if (usingBrowserDb && data?.name) {
123
+ addUserDataset(data.name, data?.path || '');
124
+ }
125
+ if (data.name) {
126
+ router.push(`/datasets/${data.name}`);
127
+ } else {
128
+ refreshDatasets();
129
+ }
130
+ } catch (error) {
131
+ console.error('Error creating new dataset:', error);
132
+ }
133
+ },
134
+ });
135
+ };
136
+
137
+ return (
138
+ <>
139
+ <TopBar>
140
+ <div>
141
+ <h1 className="text-2xl font-semibold text-gray-100">Datasets</h1>
142
+ </div>
143
+ <div className="flex-1"></div>
144
+ <div>
145
+ {isAuthenticated ? (
146
+ <Button
147
+ className="text-gray-200 bg-slate-600 px-4 py-2 rounded-md hover:bg-slate-500 transition-colors"
148
+ onClick={() => openNewDatasetModal()}
149
+ >
150
+ New Dataset
151
+ </Button>
152
+ ) : (
153
+ <span className="text-gray-600 bg-gray-900 px-3 py-1 rounded-md border border-gray-800">
154
+ Sign in to add datasets
155
+ </span>
156
+ )}
157
+ </div>
158
+ </TopBar>
159
+
160
+ <MainContent>
161
+ {isAuthenticated ? (
162
+ <UniversalTable
163
+ columns={columns}
164
+ rows={tableRows}
165
+ isLoading={status === 'loading'}
166
+ onRefresh={refreshDatasets}
167
+ />
168
+ ) : (
169
+ <div className="border border-gray-800 rounded-lg p-6 bg-gray-900 text-gray-400 text-sm flex flex-col gap-4">
170
+ <p>Sign in with Hugging Face or add an access token to manage datasets.</p>
171
+ <div className="flex items-center gap-3">
172
+ <HFLoginButton size="sm" />
173
+ <Link href="/settings" className="text-xs text-blue-400 hover:text-blue-300">
174
+ Manage authentication in Settings
175
+ </Link>
176
+ </div>
177
+ </div>
178
+ )}
179
+ </MainContent>
180
+
181
+ <Modal
182
+ isOpen={isNewDatasetModalOpen}
183
+ onClose={() => setIsNewDatasetModalOpen(false)}
184
+ title="New Dataset"
185
+ size="md"
186
+ >
187
+ <div className="space-y-4 text-gray-200">
188
+ <form onSubmit={handleCreateDataset}>
189
+ <div className="text-sm text-gray-400">
190
+ This will create a new folder with the name below in your dataset folder.
191
+ </div>
192
+ <div className="mt-4">
193
+ <TextInput label="Dataset Name" value={newDatasetName} onChange={value => setNewDatasetName(value)} />
194
+ </div>
195
+
196
+ <div className="mt-6 flex justify-end space-x-3">
197
+ <button
198
+ type="button"
199
+ className="rounded-md bg-gray-700 px-4 py-2 text-gray-200 hover:bg-gray-600 focus:outline-none focus:ring-2 focus:ring-gray-500"
200
+ onClick={() => setIsNewDatasetModalOpen(false)}
201
+ >
202
+ Cancel
203
+ </button>
204
+ <button
205
+ type="submit"
206
+ className="rounded-md bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 disabled:opacity-50 disabled:cursor-not-allowed"
207
+ disabled={!isAuthenticated}
208
+ >
209
+ Confirm
210
+ </button>
211
+ </div>
212
+ </form>
213
+ </div>
214
+ </Modal>
215
+ </>
216
+ );
217
+ }