DGSpitzer commited on
Commit
4b513b0
1 Parent(s): 94f21e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -40
app.py CHANGED
@@ -14,6 +14,7 @@ import torch
14
 
15
  from spectro import wav_bytes_from_spectrogram_image
16
  from diffusers import StableDiffusionPipeline
 
17
 
18
  import io
19
  from os import path
@@ -38,8 +39,10 @@ tips = {"en": "Tips: The input text will be translated into English for generati
38
 
39
  count = 0
40
 
 
41
  model_id = "runwayml/stable-diffusion-v1-5"
42
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
 
43
  pipe = pipe.to("cuda")
44
 
45
  model_id2 = "riffusion/riffusion-model-v1"
@@ -60,23 +63,23 @@ def translate_language(text_prompts):
60
  text_prompts = language_translation_model.translate(text_prompts, language_code, 'en')
61
  except Exception as e:
62
  error_text = str(e)
63
- return {status_text:error_text, language_tips_text:gr.update(visible=False)}
64
  if language_code in tips:
65
  tips_text = tips[language_code]
66
  else:
67
  tips_text = tips['en']
68
- if language_code == 'zh':
69
  return {language_tips_text:gr.update(visible=False), translated_language:text_prompts, trigger_component: gr.update(value=count, visible=False)}
70
  else:
71
  return {language_tips_text:gr.update(visible=True, value=tips_text), translated_language:text_prompts, trigger_component: gr.update(value=count, visible=False)}
72
 
73
 
74
 
75
- def get_result(text_prompts, style_indx, musicAI_indx):
76
  style = style_list_EN[style_indx]
77
  prompt = style + "," + text_prompts
78
 
79
- sdresult = pipe(prompt)
80
  image_output = sdresult.images[0] if not sdresult.nsfw_content_detected[0] else Image.open("nsfw_placeholder.jpg")
81
 
82
  print("Generated image with prompt " + prompt)
@@ -91,15 +94,18 @@ def get_result(text_prompts, style_indx, musicAI_indx):
91
 
92
  interrogate_prompt = img_to_text(imagefile, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
93
  print(interrogate_prompt)
94
- spec_image, music_output = get_music(interrogate_prompt + ", " + style_list_EN[style_indx], musicAI_indx)
95
 
96
  video_merged = merge_video(music_output, image_output)
97
- return {spec_result:spec_image, video_result:video_merged, status_text:'Success'}
98
-
99
 
100
- def get_music(prompt, musicAI_indx):
101
  if musicAI_indx == 0:
102
- spec = pipe2(prompt).images[0]
 
 
 
 
103
  print(spec)
104
  wav = wav_bytes_from_spectrogram_image(spec)
105
  with open("output.wav", "wb") as f:
@@ -148,7 +154,9 @@ def merge_video(mp3file_name, image):
148
  fps = 12
149
  slide_time = audio_length
150
  fourcc = cv2.VideoWriter.fourcc(*'MJPG')
151
- out = cv2.VideoWriter(file_name, fourcc, fps, (512, 512))
 
 
152
 
153
  # for image in img_list:
154
  # cv_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
@@ -188,6 +196,11 @@ def merge_video(mp3file_name, image):
188
  mergedclip.to_videofile('mergedvideo.mp4')
189
  return 'mergedvideo.mp4'
190
 
 
 
 
 
 
191
  title="文生图生音乐视频 Text to Image to Music to Video with Riffusion"
192
 
193
  description="An AI art generation pipeline, which supports text-to-image-to-music task."
@@ -263,6 +276,22 @@ css = """
263
  font-weight: bold;
264
  font-size: 115%;
265
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  """
267
 
268
  block = gr.Blocks(css=css)
@@ -429,6 +458,7 @@ with block:
429
  </div>
430
  """
431
  )
 
432
  with gr.Group():
433
  with gr.Box():
434
  with gr.Row().style(mobile_collapse=False, equal_height=True):
@@ -437,6 +467,7 @@ with block:
437
  show_label=False,
438
  max_lines=1,
439
  placeholder="Enter your prompt, multiple languages are supported now.",
 
440
  ).style(
441
  border=(True, False, True, True),
442
  rounded=(True, False, False, True),
@@ -453,6 +484,7 @@ with block:
453
  '像素风格(Pixel Style)', '概念艺术(Conceptual Art)', '未来主义(Futurism)', '赛博朋克(Cyberpunk)', '写实风格(Realistic style)',
454
  '洛丽塔风格(Lolita style)', '巴洛克风格(Baroque style)', '超现实主义(Surrealism)', '默认(Default)'], value='默认(Default)', type="index")
455
  musicAI = gr.Dropdown(label="音乐生成技术(AI Music Generator)", choices=['Riffusion', 'Mubert AI'], value='Riffusion', type="index")
 
456
  status_text = gr.Textbox(
457
  label="处理状态(Process status)",
458
  show_label=True,
@@ -460,35 +492,45 @@ with block:
460
  interactive=False
461
  )
462
 
463
- video_result = gr.Video(type=None, label='Final Merged video')
464
- spec_result = gr.Image()
465
-
466
- trigger_component = gr.Textbox(vaule="", visible=False) # This component is used for triggering inference funtion.
467
- translated_language = gr.Textbox(vaule="", visible=False)
468
-
469
-
470
- ex = gr.Examples(examples=examples, fn=translate_language_example, inputs=[text, styles], outputs=[language_tips_text, status_text, trigger_component, translated_language], cache_examples=False)
471
- ex.dataset.headers = [""]
472
 
473
-
474
- text.submit(translate_language, inputs=[text], outputs=[language_tips_text, status_text, trigger_component, translated_language])
475
- btn.click(translate_language, inputs=[text], outputs=[language_tips_text, status_text, trigger_component, translated_language])
476
- trigger_component.change(fn=get_result, inputs=[translated_language, styles, musicAI], outputs=[spec_result, video_result, status_text])
477
-
478
-
479
- gr.Markdown(
480
- """
481
- Space by [@DGSpitzer](https://www.youtube.com/channel/UCzzsYBF4qwtMwJaPJZ5SuPg)❤️ [@大谷的游戏创作小屋](https://space.bilibili.com/176003)
482
- [![Twitter Follow](https://img.shields.io/twitter/follow/DGSpitzer?label=%40DGSpitzer&style=social)](https://twitter.com/DGSpitzer)
483
- ![visitors](https://visitor-badge.glitch.me/badge?page_id=dgspitzer_txt2img2video)
484
- """
485
- )
486
- gr.HTML('''
487
- <div class="footer">
488
- <p>Model:<a href="https://huggingface.co/riffusion/riffusion-model-v1" style="text-decoration: underline;" target="_blank">Riffusion</a>
489
- </p>
490
- </div>
491
- ''')
 
 
 
 
 
492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
 
494
- block.queue(concurrency_count=128).launch()
 
14
 
15
  from spectro import wav_bytes_from_spectrogram_image
16
  from diffusers import StableDiffusionPipeline
17
+ from diffusers import EulerAncestralDiscreteScheduler
18
 
19
  import io
20
  from os import path
 
39
 
40
  count = 0
41
 
42
+
43
  model_id = "runwayml/stable-diffusion-v1-5"
44
+ eulera = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
45
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, scheduler=eulera)
46
  pipe = pipe.to("cuda")
47
 
48
  model_id2 = "riffusion/riffusion-model-v1"
 
63
  text_prompts = language_translation_model.translate(text_prompts, language_code, 'en')
64
  except Exception as e:
65
  error_text = str(e)
66
+ return {status_text:error_text, language_tips_text:gr.update(visible=False), translated_language:text_prompts, trigger_component: gr.update(value=count, visible=False)}
67
  if language_code in tips:
68
  tips_text = tips[language_code]
69
  else:
70
  tips_text = tips['en']
71
+ if language_code == 'en':
72
  return {language_tips_text:gr.update(visible=False), translated_language:text_prompts, trigger_component: gr.update(value=count, visible=False)}
73
  else:
74
  return {language_tips_text:gr.update(visible=True, value=tips_text), translated_language:text_prompts, trigger_component: gr.update(value=count, visible=False)}
75
 
76
 
77
 
78
+ def get_result(text_prompts, style_indx, musicAI_indx, duration):
79
  style = style_list_EN[style_indx]
80
  prompt = style + "," + text_prompts
81
 
82
+ sdresult = pipe(prompt, negative_prompt = "out of focus, scary, creepy, evil, disfigured, missing limbs, ugly, gross, missing fingers", num_inference_steps=50, guidance_scale=7, width=576, height=576)
83
  image_output = sdresult.images[0] if not sdresult.nsfw_content_detected[0] else Image.open("nsfw_placeholder.jpg")
84
 
85
  print("Generated image with prompt " + prompt)
 
94
 
95
  interrogate_prompt = img_to_text(imagefile, "ViT-L (best for Stable Diffusion 1.*)", "fast", fn_index=1)[0]
96
  print(interrogate_prompt)
97
+ spec_image, music_output = get_music(interrogate_prompt + ", " + style_list_EN[style_indx], musicAI_indx, duration)
98
 
99
  video_merged = merge_video(music_output, image_output)
100
+ return {spec_result:spec_image, video_result:video_merged, status_text:'Success', share_button:gr.update(visible=True), community_icon:gr.update(visible=True), loading_icon:gr.update(visible=True)}
 
101
 
102
+ def get_music(prompt, musicAI_indx, duration):
103
  if musicAI_indx == 0:
104
+ if duration == 5:
105
+ width_duration=512
106
+ else :
107
+ width_duration = 512 + ((int(duration)-5) * 128)
108
+ spec = pipe2(prompt, height=512, width=width_duration).images[0]
109
  print(spec)
110
  wav = wav_bytes_from_spectrogram_image(spec)
111
  with open("output.wav", "wb") as f:
 
154
  fps = 12
155
  slide_time = audio_length
156
  fourcc = cv2.VideoWriter.fourcc(*'MJPG')
157
+
158
+ #W, H should be the same as input image
159
+ out = cv2.VideoWriter(file_name, fourcc, fps, (576, 576))
160
 
161
  # for image in img_list:
162
  # cv_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
196
  mergedclip.to_videofile('mergedvideo.mp4')
197
  return 'mergedvideo.mp4'
198
 
199
+ def change_music_generator(current_choice):
200
+ if current_choice == 0:
201
+ return gr.update(visible=True)
202
+ return gr.update(visible=False)
203
+
204
  title="文生图生音乐视频 Text to Image to Music to Video with Riffusion"
205
 
206
  description="An AI art generation pipeline, which supports text-to-image-to-music task."
 
276
  font-weight: bold;
277
  font-size: 115%;
278
  }
279
+ #share-btn-container {
280
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
281
+ }
282
+ #share-btn {
283
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
284
+ }
285
+ #share-btn * {
286
+ all: unset;
287
+ }
288
+ #share-btn-container div:nth-child(-n+2){
289
+ width: auto !important;
290
+ min-height: 0px !important;
291
+ }
292
+ #share-btn-container .wrap {
293
+ display: none !important;
294
+ }
295
  """
296
 
297
  block = gr.Blocks(css=css)
 
458
  </div>
459
  """
460
  )
461
+
462
  with gr.Group():
463
  with gr.Box():
464
  with gr.Row().style(mobile_collapse=False, equal_height=True):
 
467
  show_label=False,
468
  max_lines=1,
469
  placeholder="Enter your prompt, multiple languages are supported now.",
470
+ elem_id="input-prompt",
471
  ).style(
472
  border=(True, False, True, True),
473
  rounded=(True, False, False, True),
 
484
  '像素风格(Pixel Style)', '概念艺术(Conceptual Art)', '未来主义(Futurism)', '赛博朋克(Cyberpunk)', '写实风格(Realistic style)',
485
  '洛丽塔风格(Lolita style)', '巴洛克风格(Baroque style)', '超现实主义(Surrealism)', '默认(Default)'], value='默认(Default)', type="index")
486
  musicAI = gr.Dropdown(label="音乐生成技术(AI Music Generator)", choices=['Riffusion', 'Mubert AI'], value='Riffusion', type="index")
487
+ duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=10, step=1, value=5, elem_id="duration-slider", visible=True)
488
  status_text = gr.Textbox(
489
  label="处理状态(Process status)",
490
  show_label=True,
 
492
  interactive=False
493
  )
494
 
 
 
 
 
 
 
 
 
 
495
 
496
+ with gr.Column(elem_id="col-container"):
497
+ with gr.Group(elem_id="share-btn-container"):
498
+ community_icon = gr.HTML(community_icon_html, visible=False)
499
+ loading_icon = gr.HTML(loading_icon_html, visible=False)
500
+ share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
501
+
502
+ share_button.click(None, [], [], _js=share_js)
503
+
504
+ video_result = gr.Video(type=None, label='Final Merged video', elem_id="output-video")
505
+ spec_result = gr.Image()
506
+
507
+ trigger_component = gr.Textbox(vaule="", visible=False) # This component is used for triggering inference funtion.
508
+ translated_language = gr.Textbox(vaule="", visible=False)
509
+
510
+
511
+ ex = gr.Examples(examples=examples, fn=translate_language_example, inputs=[text, styles], outputs=[language_tips_text, status_text, trigger_component, translated_language], cache_examples=False)
512
+ ex.dataset.headers = [""]
513
+
514
+
515
+ musicAI.change(fn=change_music_generator, inputs=[musicAI], outputs=[duration_input])
516
+ text.submit(translate_language, inputs=[text], outputs=[language_tips_text, status_text, trigger_component, translated_language])
517
+ btn.click(translate_language, inputs=[text], outputs=[language_tips_text, status_text, trigger_component, translated_language])
518
+ trigger_component.change(fn=get_result, inputs=[translated_language, styles, musicAI, duration_input], outputs=[spec_result, video_result, status_text, share_button, community_icon, loading_icon])
519
+
520
 
521
+ gr.Markdown(
522
+ """
523
+ Space by [@DGSpitzer](https://www.youtube.com/channel/UCzzsYBF4qwtMwJaPJZ5SuPg)❤️ [@大谷的游戏创作小屋](https://space.bilibili.com/176003)
524
+ [![Twitter Follow](https://img.shields.io/twitter/follow/DGSpitzer?label=%40DGSpitzer&style=social)](https://twitter.com/DGSpitzer)
525
+ ![visitors](https://visitor-badge.glitch.me/badge?page_id=dgspitzer_txt2img2video)
526
+ """
527
+ )
528
+ gr.HTML('''
529
+ <div class="footer">
530
+ <p>Model:<a href="https://huggingface.co/riffusion/riffusion-model-v1" style="text-decoration: underline;" target="_blank">Riffusion</a>
531
+ </p>
532
+ </div>
533
+ ''')
534
+
535
 
536
+ block.queue().launch()