jbilcke-hf HF staff committed on
Commit
8aa943e
•
1 Parent(s): e4fbf30

refactoring

Browse files
src/production/renderAnalysis.mts ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import { RenderedScene, RenderRequest } from "../types.mts"
3
+
4
+ import { renderImageAnalysis } from "./renderImageAnalysis.mts"
5
+
6
/**
 * Optional analysis step of the render pipeline.
 *
 * Does nothing unless `request.analyze` is truthy. When analysis is requested,
 * only requests that are not videos (i.e. `nbFrames` is not greater than 1)
 * are forwarded to `renderImageAnalysis`.
 *
 * @param request  the render request (reads `analyze` and `nbFrames`)
 * @param response the scene being rendered, passed through to the analysis step
 */
export async function renderAnalysis(request: RenderRequest, response: RenderedScene) {
  if (!request.analyze) {
    return
  }

  const isVideo = request?.nbFrames > 1
  if (isVideo) {
    // note: analysis only works on images for now,
    // but we could also analyze the first video frame to get ourselves an idea
    return
  }

  await renderImageAnalysis(request, response)
}
src/production/renderContent.mts ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import { RenderedScene, RenderRequest } from "../types.mts"
3
+
4
+ import { renderImage } from "./renderImage.mts"
5
+ import { renderVideo } from "./renderVideo.mts"
6
+
7
/**
 * Renders the main content of a scene: a video when the request asks for more
 * than one frame, an image otherwise.
 *
 * The chosen renderer is attempted up to three times. Failures of the first two
 * attempts are ignored; a failure of the third attempt is propagated to the caller.
 *
 * @param request  the render request (reads `nbFrames` to pick the renderer)
 * @param response the scene object handed to the renderer
 */
export async function renderContent(request: RenderRequest, response: RenderedScene) {
  const renderContentFn = request?.nbFrames > 1 ? renderVideo : renderImage

  const maxAttempts = 3

  for (let attempt = 1; ; attempt++) {
    try {
      await renderContentFn(request, response)
      return
    } catch (err) {
      // out of retries: let the last error bubble up
      if (attempt >= maxAttempts) {
        throw err
      }
      // console.log(`renderContent() failed, trying again..`)
    }
  }
}
src/production/renderPipeline.mts CHANGED
@@ -1,53 +1,19 @@
1
 
2
  import { RenderedScene, RenderRequest } from "../types.mts"
3
 
4
- import { renderImage } from "./renderImage.mts"
5
- import { renderVideo } from "./renderVideo.mts"
6
- import { renderImageSegmentation } from "./renderImageSegmentation.mts"
7
- import { renderVideoSegmentation } from "./renderVideoSegmentation.mts"
8
- import { renderImageUpscaling } from "./renderImageUpscaling.mts"
9
  import { saveRenderedSceneToCache } from "../utils/filesystem/saveRenderedSceneToCache.mts"
10
- import { renderImageAnalysis } from "./renderImageAnalysis.mts"
 
 
 
11
 
12
  export async function renderPipeline(request: RenderRequest, response: RenderedScene) {
13
- const isVideo = request?.nbFrames > 1
14
-
15
- const renderContent = isVideo ? renderVideo : renderImage
16
- const renderSegmentation = isVideo ? renderVideoSegmentation : renderImageSegmentation
17
-
18
- if (isVideo) {
19
- // console.log(`rendering a video..`)
20
- } else {
21
- // console.log(`rendering an image..`)
22
- }
23
-
24
- try {
25
- await renderContent(request, response)
26
- } catch (err) {
27
- // console.log(`renderContent() failed, trying a 2nd time..`)
28
- try {
29
- await renderContent(request, response)
30
- } catch (err2) {
31
- // console.log(`renderContent() failed, trying a 3th time..`)
32
- await renderContent(request, response)
33
- }
34
- }
35
-
36
- // we upscale images with esrgan
37
- // and for videos, well.. let's just skip this part,
38
- // but later we could use Zeroscope V2 XL maybe?
39
- const optionalUpscalingStep = isVideo
40
- ? Promise.resolve()
41
- : renderImageUpscaling(request, response)
42
-
43
- const optionalAnalysisStep = request.analyze
44
- ? renderImageAnalysis(request, response)
45
- : Promise.resolve()
46
 
47
  await Promise.all([
48
  renderSegmentation(request, response),
49
- optionalAnalysisStep,
50
- optionalUpscalingStep
51
  ])
52
 
53
  /*
 
1
 
2
  import { RenderedScene, RenderRequest } from "../types.mts"
3
 
 
 
 
 
 
4
  import { saveRenderedSceneToCache } from "../utils/filesystem/saveRenderedSceneToCache.mts"
5
+ import { renderSegmentation } from "./renderSegmentation.mts"
6
+ import { renderUpscaling } from "./renderUpscaling.mts"
7
+ import { renderContent } from "./renderContent.mts"
8
+ import { renderAnalysis } from "./renderAnalysis.mts"
9
 
10
  export async function renderPipeline(request: RenderRequest, response: RenderedScene) {
11
+ await renderContent(request, response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  await Promise.all([
14
  renderSegmentation(request, response),
15
+ renderAnalysis(request, response),
16
+ renderUpscaling(request, response)
17
  ])
18
 
19
  /*
src/production/renderSegmentation.mts ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import { RenderedScene, RenderRequest } from "../types.mts"
3
+
4
+ import { renderImageSegmentation } from "./renderImageSegmentation.mts"
5
+ import { renderVideoSegmentation } from "./renderVideoSegmentation.mts"
6
+
7
/**
 * Optional segmentation step of the render pipeline.
 *
 * Runs only when `request.segmentation` is `"firstframe"` or `"allframes"`.
 * Requests with more than one frame go through `renderVideoSegmentation`,
 * all others through `renderImageSegmentation`.
 *
 * @param request  the render request (reads `segmentation` and `nbFrames`)
 * @param response the scene object handed to the segmentation function
 */
export async function renderSegmentation(request: RenderRequest, response: RenderedScene) {
  const mode = request.segmentation
  if (mode !== "firstframe" && mode !== "allframes") {
    return
  }

  if (request?.nbFrames > 1) {
    await renderVideoSegmentation(request, response)
  } else {
    await renderImageSegmentation(request, response)
  }
}
src/production/renderUpscaling.mts ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import { RenderedScene, RenderRequest } from "../types.mts"
3
+
4
+ import { renderImageUpscaling } from "./renderImageUpscaling.mts"
5
+ import { renderVideoUpscaling } from "./renderVideoUpscaling.mts"
6
+
7
/**
 * Optional upscaling step of the render pipeline.
 *
 * Runs only when `request.upscalingFactor` is greater than 1. Requests with
 * more than one frame are sent to `renderVideoUpscaling`, all others to
 * `renderImageUpscaling`.
 *
 * @param request  the render request (reads `upscalingFactor` and `nbFrames`)
 * @param response the scene object handed to the upscaling function
 */
export async function renderUpscaling(request: RenderRequest, response: RenderedScene) {
  // written as a negated `>` so that a missing factor is skipped as well
  if (!(request.upscalingFactor > 1)) {
    return
  }

  // we upscale images with esrgan, and video with Zeroscope XL
  if (request?.nbFrames > 1) {
    await renderVideoUpscaling(request, response)
  } else {
    await renderImageUpscaling(request, response)
  }
}
src/production/renderVideoUpscaling.mts ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { upscaleVideoToBase64URL } from "../providers/video-upscaling/upscaleVideoToBase64URL.mts"
2
+ import { RenderedScene, RenderRequest } from "../types.mts"
3
+
4
/**
 * Tries to replace `response.assetUrl` with an upscaled version of the video.
 *
 * The upscaling call is attempted at most twice. If both attempts fail the
 * error is logged and `response.assetUrl` is left untouched, so the caller
 * keeps the original (low-res) video. This function never throws because of
 * an upscaling failure.
 *
 * NOTE(review): `upscaleVideoToBase64URL` treats its first argument as a file
 * name resolved against the pending files directory; confirm that
 * `response.assetUrl` holds such a file name at this point of the pipeline.
 *
 * @param request  the render request (its `prompt` is forwarded to the upscaler)
 * @param response the scene whose `assetUrl` is upscaled in place
 * @returns the same `response` object that was passed in
 */
export async function renderVideoUpscaling(
  request: RenderRequest,
  response: RenderedScene,
): Promise<RenderedScene> {
  const maxAttempts = 2

  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    try {
      response.assetUrl = await upscaleVideoToBase64URL(response.assetUrl, request.prompt)
      // console.log(`upscaling worked on attempt ${attempt}!`)
      break
    } catch (err) {
      if (attempt === maxAttempts) {
        // not a catastrophic failure: we still have the original (low-res) video
        // to work with, but keep the error in the log so the cause can be diagnosed
        console.error(`upscaling failed on the second attempt.. let's keep the low-res video then :|`, err)
      }
      // otherwise: upscaling failed the first time.. let's try again..
    }
  }

  return response
}
src/providers/image-segmentation/segmentImage.mts CHANGED
@@ -2,7 +2,7 @@ import puppeteer from "puppeteer"
2
 
3
  import { sleep } from "../../utils/misc/sleep.mts"
4
  import { ImageSegment } from "../../types.mts"
5
- import { downloadImageAsBase64 } from "../../utils/download/downloadFileAsBase64.mts"
6
  import { resizeBase64Image } from "../../utils/image/resizeBase64Image.mts"
7
 
8
  // we don't use replicas yet, because it ain't easy to get their hostname
@@ -78,7 +78,7 @@ export async function segmentImage(
78
  // const tmpMaskFileName = `${uuidv4()}.png`
79
  // await downloadFileToTmp(maskUrl, tmpMaskFileName)
80
 
81
- const rawPngInBase64 = await downloadImageAsBase64(tmpMaskDownloadUrl)
82
 
83
  const maskUrl = await resizeBase64Image(rawPngInBase64, width, height)
84
 
 
2
 
3
  import { sleep } from "../../utils/misc/sleep.mts"
4
  import { ImageSegment } from "../../types.mts"
5
+ import { downloadFileAsBase64 } from "../../utils/download/downloadFileAsBase64.mts"
6
  import { resizeBase64Image } from "../../utils/image/resizeBase64Image.mts"
7
 
8
  // we don't use replicas yet, because it ain't easy to get their hostname
 
78
  // const tmpMaskFileName = `${uuidv4()}.png`
79
  // await downloadFileToTmp(maskUrl, tmpMaskFileName)
80
 
81
+ const rawPngInBase64 = await downloadFileAsBase64(tmpMaskDownloadUrl)
82
 
83
  const maskUrl = await resizeBase64Image(rawPngInBase64, width, height)
84
 
src/providers/video-upscaling/upscaleVideo.mts CHANGED
@@ -32,7 +32,7 @@ export async function upscaleVideo(fileName: string, prompt: string) {
32
  })
33
 
34
  const secretField = await page.$('input[type=text]')
35
- await secretField.type(prompt)
36
 
37
  const promptField = await page.$('textarea')
38
  await promptField.type(prompt)
 
32
  })
33
 
34
  const secretField = await page.$('input[type=text]')
35
+ await secretField.type(secretToken)
36
 
37
  const promptField = await page.$('textarea')
38
  await promptField.type(prompt)
src/providers/video-upscaling/upscaleVideoToBase64URL.mts ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import path from "node:path"
2
+
3
+
4
+ import puppeteer from "puppeteer"
5
+
6
+ import { pendingFilesDirFilePath } from '../../config.mts'
7
+ import { downloadFileAsBase64URL } from "../../utils/download/downloadFileAsBase64URL.mts"
8
+
9
// URLs of the video upscaling instances, read from the environment.
// Unset or empty variables are dropped, so the list may be empty.
const instances: string[] = [process.env.VC_VIDEO_UPSCALE_SPACE_API_URL_1]
  .map((url) => `${url || ""}`)
  .filter((url) => url?.length > 0)

// Secret typed into the instance's text input before submitting a job
// (empty string when the environment variable is not set).
const secretToken = `${process.env.VC_MICROSERVICE_SECRET_TOKEN || ""}`
14
+
15
+ // TODO we should use an inference endpoint instead (or a space which bakes generation + upscale at the same time)
16
/**
 * Upscales a local video by driving a remote web UI with Puppeteer and returns
 * the result as a base64 data URL.
 *
 * Steps: pick the next instance (round-robin over `instances`), open it in a
 * headless browser, fill in the secret token and the prompt, upload the file
 * found at `pendingFilesDirFilePath/fileName`, submit, wait for the download
 * link named `xl_result.mp4` to appear, then download that file.
 *
 * @param fileName name of the video file, relative to `pendingFilesDirFilePath`
 * @param prompt   text typed into the page's textarea
 * @returns a `data:` URL containing the upscaled video
 * @throws Error when no instance is configured, when an expected page element
 *         is missing, when the wait times out, or when the download fails
 */
export async function upscaleVideoToBase64URL(fileName: string, prompt: string) {
  // fail fast, before paying for a browser launch; without this guard an
  // `undefined` entry would also be pushed into the instance pool
  if (instances.length === 0) {
    throw new Error(`no video upscaling instance configured (VC_VIDEO_UPSCALE_SPACE_API_URL_1 is empty)`)
  }

  // round-robin: take the first instance and put it back at the end of the queue
  const instance = instances.shift()
  instances.push(instance)

  const browser = await puppeteer.launch({
    // headless: true,
    protocolTimeout: 800000,
  })

  try {
    const page = await browser.newPage()

    await page.goto(instance, {
      waitUntil: 'networkidle2',
    })

    const secretField = await page.$('input[type=text]')
    if (!secretField) {
      throw new Error(`video upscaling page has no text input for the secret token`)
    }
    await secretField.type(secretToken)

    const promptField = await page.$('textarea')
    if (!promptField) {
      throw new Error(`video upscaling page has no textarea for the prompt`)
    }
    await promptField.type(prompt)

    const inputFilePath = path.join(pendingFilesDirFilePath, fileName)
    // console.log(`local file to upscale: ${inputFilePath}`)

    const fileField = await page.$('input[type=file]')
    if (!fileField) {
      throw new Error(`video upscaling page has no file input`)
    }
    // console.log(`uploading file..`)
    await fileField.uploadFile(inputFilePath)

    // console.log('looking for the button to submit')
    const submitButton = await page.$('button.lg')
    if (!submitButton) {
      throw new Error(`video upscaling page has no submit button`)
    }
    // console.log('clicking on the button')
    await submitButton.click()

    await page.waitForSelector('a[download="xl_result.mp4"]', {
      timeout: 800000, // needs to be large enough in case someone else attempts to use our space
    })

    const upscaledFileUrl = await page.$$eval('a[download="xl_result.mp4"]', el => el.map(x => x.getAttribute("href"))[0])
    if (!upscaledFileUrl) {
      throw new Error(`video upscaling result link has no href`)
    }

    // we download the whole file
    // it's only a few seconds of video, so it should be < 2MB
    return await downloadFileAsBase64URL(upscaledFileUrl)
  } finally {
    // always release the browser, whether the upscaling succeeded or not
    await browser.close()
  }
}
src/utils/download/downloadFileAsBase64.mts CHANGED
@@ -1,4 +1,4 @@
1
- export const downloadImageAsBase64 = async (remoteUrl: string): Promise<string> => {
2
  const controller = new AbortController()
3
 
4
  // download the image
 
1
+ export const downloadFileAsBase64 = async (remoteUrl: string): Promise<string> => {
2
  const controller = new AbortController()
3
 
4
  // download the image
src/utils/download/downloadFileAsBase64URL.mts ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/**
 * Downloads a remote file and returns it as a base64 `data:` URL.
 *
 * @param remoteUrl URL of the file to download
 * @returns `data:<content-type>;base64,<payload>`; when the server sends no
 *          `content-type` header, `application/octet-stream` is used
 * @throws Error when the server answers with a non-2xx status, so that an
 *         error page is never encoded as if it were the requested asset
 */
export const downloadFileAsBase64URL = async (remoteUrl: string): Promise<string> => {
  // download the file
  const response = await fetch(remoteUrl)

  if (!response.ok) {
    throw new Error(`failed to download ${remoteUrl}: ${response.status} ${response.statusText}`)
  }

  // get the body as a Buffer and convert it to base64
  const base64 = Buffer.from(await response.arrayBuffer()).toString('base64')

  // headers.get() returns null for a missing header; avoid emitting "data:null;..."
  const contentType = response.headers.get('content-type') || 'application/octet-stream'

  return `data:${contentType};base64,${base64}`
};