jbilcke-hf HF staff commited on
Commit
07d10ce
·
1 Parent(s): 5c099f3

various improvements

Browse files
package.json CHANGED
@@ -8,7 +8,7 @@
8
  "test:submitVideo": "node --loader ts-node/esm src/tests/submitVideo.mts",
9
  "test:checkStatus": "node --loader ts-node/esm src/tests/checkStatus.mts",
10
  "test:downloadFileToTmp": "node --loader ts-node/esm src/tests/downloadFileToTmp.mts",
11
- "test:stuff": "node --loader ts-node/esm src/stuff.mts",
12
  "docker": "npm run docker:build && npm run docker:run",
13
  "docker:build": "docker build -t videochain-api .",
14
  "docker:run": "docker run -it -p 7860:7860 videochain-api"
 
8
  "test:submitVideo": "node --loader ts-node/esm src/tests/submitVideo.mts",
9
  "test:checkStatus": "node --loader ts-node/esm src/tests/checkStatus.mts",
10
  "test:downloadFileToTmp": "node --loader ts-node/esm src/tests/downloadFileToTmp.mts",
11
+ "test:stuff": "node --loader ts-node/esm src/utils/segmentImage.mts",
12
  "docker": "npm run docker:build && npm run docker:run",
13
  "docker:build": "docker build -t videochain-api .",
14
  "docker:run": "docker run -it -p 7860:7860 videochain-api"
src/index.mts CHANGED
@@ -4,7 +4,7 @@ import path from "node:path"
4
  import { validate as uuidValidate } from "uuid"
5
  import express from "express"
6
 
7
- import { Video, VideoStatus, VideoAPIRequest } from "./types.mts"
8
  import { parseVideoRequest } from "./utils/parseVideoRequest.mts"
9
  import { savePendingVideo } from "./scheduler/savePendingVideo.mts"
10
  import { getVideo } from "./scheduler/getVideo.mts"
@@ -38,9 +38,9 @@ let isRendering = false
38
  // a "fast track" pipeline
39
  app.post("/render", async (req, res) => {
40
 
41
- const prompt = req.body.prompt as string
42
- console.log(`/render: "${prompt}"`)
43
- if (!prompt) {
44
  console.log("Invalid prompt")
45
  res.status(400)
46
  res.write(JSON.stringify({ url: "", error: "invalid prompt" }))
@@ -48,9 +48,15 @@ app.post("/render", async (req, res) => {
48
  return
49
  }
50
 
51
- let result = { url: "", error: "" }
 
 
 
 
 
 
52
  try {
53
- result = await renderScene(prompt)
54
  } catch (err) {
55
  // console.log("failed to render scene!")
56
  result.error = `failed to render scene: ${err}`
 
4
  import { validate as uuidValidate } from "uuid"
5
  import express from "express"
6
 
7
+ import { Video, VideoStatus, VideoAPIRequest, RenderRequest, RenderAPIResponse } from "./types.mts"
8
  import { parseVideoRequest } from "./utils/parseVideoRequest.mts"
9
  import { savePendingVideo } from "./scheduler/savePendingVideo.mts"
10
  import { getVideo } from "./scheduler/getVideo.mts"
 
38
  // a "fast track" pipeline
39
  app.post("/render", async (req, res) => {
40
 
41
+ const request = req.body as RenderRequest
42
+ console.log(req.body)
43
+ if (!request.prompt) {
44
  console.log("Invalid prompt")
45
  res.status(400)
46
  res.write(JSON.stringify({ url: "", error: "invalid prompt" }))
 
48
  return
49
  }
50
 
51
+ let result: RenderAPIResponse = {
52
+ videoUrl: "",
53
+ maskBase64: "",
54
+ error: "",
55
+ segments: []
56
+ }
57
+
58
  try {
59
+ result = await renderScene(request)
60
  } catch (err) {
61
  // console.log("failed to render scene!")
62
  result.error = `failed to render scene: ${err}`
src/production/generateAudioLegacy.mts CHANGED
@@ -18,7 +18,9 @@ export const generateAudio = async (prompt: string, options?: {
18
  const instance = instances.shift()
19
  instances.push(instance)
20
 
21
- const api = await client(instance)
 
 
22
 
23
  const rawResponse = await api.predict('/run', [
24
  prompt, // string in 'Prompt' Textbox component
 
18
  const instance = instances.shift()
19
  instances.push(instance)
20
 
21
+ const api = await client(instance, {
22
+ hf_token: `${process.env.VC_HF_API_TOKEN}` as any
23
+ })
24
 
25
  const rawResponse = await api.predict('/run', [
26
  prompt, // string in 'Prompt' Textbox component
src/production/generateVideo.mts CHANGED
@@ -3,8 +3,9 @@ import { client } from "@gradio/client"
3
  import { generateSeed } from "../utils/generateSeed.mts"
4
 
5
  const instances: string[] = [
6
- `${process.env.VC_VIDEO_GENERATION_SPACE_API_URL || ""}`,
7
- `${process.env.VC_RENDERING_ENGINE_SPACE_API_URL || ""}`,
 
8
  ].filter(instance => instance?.length > 0)
9
 
10
  export const generateVideo = async (prompt: string, options?: {
@@ -19,7 +20,9 @@ export const generateVideo = async (prompt: string, options?: {
19
  const instance = instances.shift()
20
  instances.push(instance)
21
 
22
- const api = await client(instance)
 
 
23
 
24
  const rawResponse = await api.predict('/run', [
25
  prompt, // string in 'Prompt' Textbox component
 
3
  import { generateSeed } from "../utils/generateSeed.mts"
4
 
5
  const instances: string[] = [
6
+ `${process.env.VC_ZEROSCOPE_SPACE_API_URL_1 || ""}`,
7
+ // `${process.env.VC_ZEROSCOPE_SPACE_API_URL_2 || ""}`,
8
+ // `${process.env.VC_ZEROSCOPE_SPACE_API_URL_3 || ""}`,
9
  ].filter(instance => instance?.length > 0)
10
 
11
  export const generateVideo = async (prompt: string, options?: {
 
20
  const instance = instances.shift()
21
  instances.push(instance)
22
 
23
+ const api = await client(instance, {
24
+ hf_token: `${process.env.VC_HF_API_TOKEN}` as any
25
+ })
26
 
27
  const rawResponse = await api.predict('/run', [
28
  prompt, // string in 'Prompt' Textbox component
src/production/interpolateVideoLegacy.mts CHANGED
@@ -18,7 +18,9 @@ export const interpolateVideo = async (fileName: string, steps: number, fps: num
18
  const instance = instances.shift()
19
  instances.push(instance)
20
 
21
- const api = await client(instance)
 
 
22
 
23
  const video = await fs.readFile(inputFilePath)
24
 
 
18
  const instance = instances.shift()
19
  instances.push(instance)
20
 
21
+ const api = await client(instance, {
22
+ hf_token: `${process.env.VC_HF_API_TOKEN}` as any
23
+ })
24
 
25
  const video = await fs.readFile(inputFilePath)
26
 
src/production/renderScene.mts CHANGED
@@ -1,5 +1,12 @@
 
 
 
 
1
  import { generateSeed } from "../utils/generateSeed.mts"
 
2
  import { generateVideo } from "./generateVideo.mts"
 
 
3
 
4
  const state = {
5
  isRendering: false
@@ -7,15 +14,22 @@ const state = {
7
 
8
  const seed = generateSeed()
9
 
10
- export async function renderScene(prompt: string) {
11
  // console.log("renderScene")
 
 
 
 
12
  if (state.isRendering) {
13
  // console.log("renderScene: isRendering")
14
  return {
15
- url: "",
16
- error: "already rendering"
 
 
17
  }
18
  }
 
19
 
20
  // onsole.log("marking as isRendering")
21
  state.isRendering = true
@@ -24,11 +38,10 @@ export async function renderScene(prompt: string) {
24
  let error = ""
25
 
26
  try {
27
- url = await generateVideo(prompt, {
28
- seed: generateSeed(),
29
- // seed,
30
- nbFrames: 16,
31
- nbSteps: 10,
32
  })
33
  // console.log("successfull generation")
34
  error = ""
@@ -36,12 +49,56 @@ export async function renderScene(prompt: string) {
36
  error = `failed to render scene: ${err}`
37
  }
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  // console.log("marking as not rendering anymore")
40
  state.isRendering = false
41
  error = ""
42
 
43
  return {
44
- url,
45
- error
46
- }
 
 
47
  }
 
1
+ import { v4 as uuidv4 } from "uuid"
2
+
3
+ import { ImageSegment, RenderAPIResponse, RenderRequest } from "../types.mts"
4
+ import { downloadFileToTmp } from "../utils/downloadFileToTmp.mts"
5
  import { generateSeed } from "../utils/generateSeed.mts"
6
+ import { getValidNumber } from "../utils/getValidNumber.mts"
7
  import { generateVideo } from "./generateVideo.mts"
8
+ import { getFirstVideoFrame } from "../utils/getFirstVideoFrame.mts"
9
+ import { segmentImage } from "../utils/segmentImage.mts"
10
 
11
  const state = {
12
  isRendering: false
 
14
 
15
  const seed = generateSeed()
16
 
17
+ export async function renderScene(scene: RenderRequest): Promise<RenderAPIResponse> {
18
  // console.log("renderScene")
19
+
20
+ // let's disable this for now
21
+ // this is only reliable if nothing crashes anyway..
22
+ /*
23
  if (state.isRendering) {
24
  // console.log("renderScene: isRendering")
25
  return {
26
+ videoUrl: "",
27
+ error: "already rendering",
28
+ maskBase64: "",
29
+ segments: [],
30
  }
31
  }
32
+ */
33
 
34
  // onsole.log("marking as isRendering")
35
  state.isRendering = true
 
38
  let error = ""
39
 
40
  try {
41
+ url = await generateVideo(scene.prompt, {
42
+ seed: getValidNumber(scene.seed, 0, 4294967295, generateSeed()),
43
+ nbFrames: getValidNumber(scene.nbFrames, 8, 24, 16), // 2 seconds by default
44
+ nbSteps: getValidNumber(scene.nbSteps, 1, 50, 10), // use 10 by default to go fast, but not too sloppy
 
45
  })
46
  // console.log("successfull generation")
47
  error = ""
 
49
  error = `failed to render scene: ${err}`
50
  }
51
 
52
+
53
+
54
+ // TODO add segmentation here
55
+ const actionnables = Array.isArray(scene.actionnables) ? scene.actionnables : []
56
+
57
+ let mask = ""
58
+ let segments: ImageSegment[] = []
59
+
60
+ if (actionnables.length > 0) {
61
+ console.log("we have some actionnables:", actionnables)
62
+ if (scene.segmentation === "firstframe") {
63
+ console.log("going to grab the first frame")
64
+ const tmpVideoFilePath = await downloadFileToTmp(url, `${uuidv4()}`)
65
+ console.log("downloaded the first frame to ", tmpVideoFilePath)
66
+ const firstFrameFilePath = await getFirstVideoFrame(tmpVideoFilePath)
67
+ console.log("downloaded the first frame to ", firstFrameFilePath)
68
+
69
+ if (!firstFrameFilePath) {
70
+ console.error("failed to get the image")
71
+ error = "failed to segment the image"
72
+ } else {
73
+ console.log("got the first frame! segmenting..")
74
+ const result = await segmentImage(firstFrameFilePath, actionnables)
75
+ mask = result.pngInBase64
76
+ segments = result.segments
77
+ // console.log("success!", { segments })
78
+ }
79
+ /*
80
+ const jpgBase64 = await getFirstVideoFrame(tmpVideoFileName)
81
+ if (!jpgBase64) {
82
+ console.error("failed to get the image")
83
+ error = "failed to segment the image"
84
+ } else {
85
+ console.log(`got the first frame (${jpgBase64.length})`)
86
+
87
+ console.log("TODO: call segmentImage with the base64 image")
88
+ await segmentImage()
89
+ }
90
+ */
91
+ }
92
+ }
93
+
94
  // console.log("marking as not rendering anymore")
95
  state.isRendering = false
96
  error = ""
97
 
98
  return {
99
+ videoUrl: url,
100
+ error,
101
+ maskBase64: mask,
102
+ segments
103
+ } as RenderAPIResponse
104
  }
src/types.mts CHANGED
@@ -269,8 +269,46 @@ export type Video = VideoSequence & {
269
  shots: VideoShot[]
270
  }
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  export interface ImageSegmentationRequest {
274
  image: string // in base64
275
  keywords: string[]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  }
 
269
  shots: VideoShot[]
270
  }
271
 
272
+ export interface RenderRequest {
273
+ prompt: string
274
+
275
+ // whether to use video segmentation
276
+ // disabled (default)
277
+ // firstframe: we only analyze the first frame
278
+ // allframes: we analyze all the frames
279
+ segmentation: 'disabled' | 'firstframe' | 'allframes'
280
+
281
+ // segmentation will only be executed if we have a non-empty list of actionnables
282
+ // actionnables are names of things like "chest", "key", "tree", "chair" etc
283
+ actionnables: string[]
284
+
285
+ // note: this is the number of frames for Zeroscope,
286
+ // which is currently configured to only output 3 seconds, so:
287
+ // nbFrames=8 -> 1 sec
288
+ // nbFrames=16 -> 2 sec
289
+ // nbFrames=24 -> 3 sec
290
+ nbFrames: number // min: 8, max: 24
291
+
292
+ nbSteps: number // min: 1, max: 50
293
+
294
+ seed: number
295
+ }
296
 
297
  export interface ImageSegmentationRequest {
298
  image: string // in base64
299
  keywords: string[]
300
+ }
301
+
302
+ export interface ImageSegment {
303
+ id: number
304
+ box: number[]
305
+ label: string
306
+ score: number
307
+ }
308
+
309
+ export interface RenderAPIResponse {
310
+ videoUrl: string
311
+ error: string
312
+ maskBase64: string
313
+ segments: ImageSegment[]
314
  }
src/utils/downloadFileAsBase64.mts ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export const downloadImageAsBase64 = async (remoteUrl: string): Promise<string> => {
2
+ const controller = new AbortController()
3
+
4
+ // download the image
5
+ const response = await fetch(remoteUrl, {
6
+ signal: controller.signal
7
+ })
8
+
9
+ // get as Buffer
10
+ const arrayBuffer = await response.arrayBuffer()
11
+ const buffer = Buffer.from(arrayBuffer)
12
+
13
+ // convert it to base64
14
+ const base64 = buffer.toString('base64')
15
+
16
+ return base64
17
+ };
src/utils/downloadFileToTmp.mts CHANGED
@@ -24,4 +24,6 @@ export const downloadFileToTmp = async (remoteUrl: string, fileName: string) =>
24
  filePath,
25
  Buffer.from(arrayBuffer)
26
  )
 
 
27
  }
 
24
  filePath,
25
  Buffer.from(arrayBuffer)
26
  )
27
+
28
+ return filePath
29
  }
src/utils/getFirstVideoFrame.mts ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import path from "node:path"
2
+
3
+ import ffmpeg from "fluent-ffmpeg"
4
+ import { v4 as uuidv4 } from "uuid"
5
+ import tmpDir from "temp-dir"
6
+
7
+ export async function getFirstVideoFrame(videoFilePath: string): Promise<string | void> {
8
+ const tmpFileName = `${uuidv4()}.jpg`
9
+
10
+ const tmpFilePath = path.resolve(tmpDir, tmpFileName)
11
+
12
+ return new Promise((resolve, reject) => {
13
+ ffmpeg(videoFilePath)
14
+ .outputOptions("-vframes 1")
15
+ .output(tmpFilePath)
16
+ .on("end", async () => {
17
+ resolve(tmpFilePath)
18
+ })
19
+ .on("error", reject)
20
+ .run()
21
+ })
22
+ }
src/utils/getFirstVideoFrameAsBase64.mts ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fs from "node:fs"
2
+ import util from "node:util"
3
+ import path from "node:path"
4
+
5
+ import ffmpeg from "fluent-ffmpeg"
6
+ import { v4 as uuidv4 } from "uuid"
7
+ import tmpDir from "temp-dir"
8
+
9
+ const unlinkAsync = util.promisify(fs.unlink)
10
+
11
+ export async function getFirstVideoFrameAsBase64(videoPath: string): Promise<string | void> {
12
+ const tmpFileName = `${uuidv4()}.jpg`
13
+
14
+ const tmpFilePath = path.resolve(tmpDir, tmpFileName)
15
+
16
+ return new Promise((resolve, reject) => {
17
+ ffmpeg(videoPath)
18
+ .outputOptions("-vframes 1")
19
+ .output(tmpFilePath)
20
+ .on("end", async () => {
21
+ let base64;
22
+ try {
23
+ base64 = await fs.promises.readFile(tmpFilePath, { encoding: "base64" });
24
+ await unlinkAsync(tmpFilePath)
25
+ } catch(err) {
26
+ return reject(err)
27
+ }
28
+ resolve(base64)
29
+ })
30
+ .on("error", reject)
31
+ .run()
32
+ })
33
+ }
src/utils/segmentImage.mts CHANGED
@@ -1,15 +1,97 @@
1
- import { client } from "@gradio/client"
2
-
3
- const response_0 = await fetch("https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png")
4
- const exampleImage = await response_0.blob()
5
-
6
- const app = await client("https://jbilcke-hf-grounded-segment-anything.hf.space/")
7
- const result = await app.predict(0, [
8
- exampleImage, // blob in 'Upload' Image component
9
- "Howdy!", // string in 'Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]' Textbox component
10
- 0, // number (numeric value between 0.0 and 1.0) in 'Box Threshold' Slider component
11
- 0, // number (numeric value between 0.0 and 1.0) in 'Text Threshold' Slider component
12
- 0, // number (numeric value between 0.0 and 1.0) in 'IOU Threshold' Slider component
13
- ]) as any
14
-
15
- console.log(result.data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import puppeteer from "puppeteer"
2
+
3
+ import { sleep } from "./sleep.mts"
4
+ import { ImageSegment } from "../types.mts"
5
+ import { downloadImageAsBase64 } from "./downloadFileAsBase64.mts"
6
+
7
+ const instances: string[] = [
8
+ `${process.env.VC_SEGMENTATION_MODULE_SPACE_API_URL_1 || ""}`,
9
+ `${process.env.VC_SEGMENTATION_MODULE_SPACE_API_URL_2 || ""}`,
10
+ ]
11
+
12
+ // TODO we should use an inference endpoint instead
13
+
14
+ // note: on a large T4 (8 vCPU)
15
+ // it takes about 30 seconds to compute
16
+ export async function segmentImage(
17
+ inputImageFilePath: string,
18
+ actionnables: string[]
19
+ ): Promise<{
20
+ pngInBase64: string
21
+ segments: ImageSegment[]
22
+ }> {
23
+
24
+ console.log(`segmenting image..`)
25
+
26
+ const instance = instances.shift()
27
+ instances.push(instance)
28
+
29
+ const browser = await puppeteer.launch({
30
+ headless: true,
31
+ protocolTimeout: 70000,
32
+ })
33
+
34
+ const page = await browser.newPage()
35
+ await page.goto(instance, { waitUntil: 'networkidle2' })
36
+
37
+ await new Promise(r => setTimeout(r, 3000))
38
+
39
+ const fileField = await page.$('input[type="file"]')
40
+
41
+ // console.log(`uploading file..`)
42
+ await fileField.uploadFile(inputImageFilePath)
43
+
44
+ await sleep(500)
45
+
46
+ const firstTextarea = await page.$('textarea[data-testid="textbox"]')
47
+
48
+ const conceptsToDetect = actionnables.join(" . ")
49
+ await firstTextarea.type(conceptsToDetect)
50
+
51
+ // console.log('looking for the button to submit')
52
+ const submitButton = await page.$('button.lg')
53
+
54
+ await sleep(500)
55
+
56
+ // console.log('clicking on the button')
57
+ await submitButton.click()
58
+
59
+ await page.waitForSelector('img[data-testid="detailed-image"]', {
60
+ timeout: 70000, // need to be large enough in case someone else attemps to use our space
61
+ })
62
+
63
+ const maskUrl = await page.$$eval('img[data-testid="detailed-image"]', el => el.map(x => x.getAttribute("src"))[0])
64
+
65
+ let segments: ImageSegment[] = []
66
+
67
+ try {
68
+ segments = JSON.parse(await page.$$eval('textarea', el => el.map(x => x.value)[1]))
69
+ } catch (err) {
70
+ console.log(`failed to parse JSON: ${err}`)
71
+ segments = []
72
+ }
73
+
74
+ // const tmpMaskFileName = `${uuidv4()}.png`
75
+ // await downloadFileToTmp(maskUrl, tmpMaskFileName)
76
+
77
+ const pngInBase64 = await downloadImageAsBase64(maskUrl)
78
+ return {
79
+ pngInBase64,
80
+ segments,
81
+ }
82
+ }
83
+
84
+ /*
85
+
86
+ If you want to try:
87
+
88
+ / note: must be a jpg and not jpeg it seems
89
+ // (probably a playwright bug)
90
+ const results = await segmentImage("./barn.jpg", [
91
+ "roof",
92
+ "door",
93
+ "window"
94
+ ])
95
+
96
+ console.log("results:", results)
97
+ */
src/utils/segmentImageApi.mts ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { client } from "@gradio/client"
2
+
3
+
4
+ const response_0 = await fetch("https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png")
5
+ const exampleImage = await response_0.blob()
6
+
7
+ const app = await client("https://jbilcke-hf-image-segmentation.hf.space", {
8
+ hf_token: `${process.env.VC_HF_API_TOKEN}` as any
9
+ })
10
+ const result = await app.predict(0, [
11
+ exampleImage, // "", // blob in 'Upload' Image component
12
+ "Howdy!", // string in 'Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]' Textbox component
13
+ 0.3, // number (numeric value between 0.0 and 1.0) in 'Box Threshold' Slider component
14
+ 0.25, // number (numeric value between 0.0 and 1.0) in 'Text Threshold' Slider component
15
+ 0.8, // number (numeric value between 0.0 and 1.0) in 'IOU Threshold' Slider component
16
+ ]) as any
17
+
18
+ console.log(result)