File size: 2,457 Bytes
7718032
 
 
 
 
 
450bd2c
 
 
 
 
 
 
 
 
 
 
 
 
7718032
b7d6c4c
 
 
 
 
 
 
 
 
 
 
7718032
 
b7d6c4c
 
7718032
 
b7d6c4c
6db627d
7718032
 
 
e0dcf02
 
cdb7851
1b3bbfe
cdb7851
e0dcf02
 
 
 
 
 
 
 
 
 
 
 
0f45386
 
85e7dfb
5178b9b
0f45386
 
 
e0dcf02
0f45386
 
 
7718032
0f45386
f946a20
7718032
 
 
5178b9b
450bd2c
7718032
450bd2c
7718032
 
 
 
 
 
 
 
 
5178b9b
7718032
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Imports grouped at top of file (PEP 8); the setup calls below are
# mutually independent, so ordering them after all imports is safe.
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from gradio_client import Client
from pydub import AudioSegment

# Remote Space that produces text captions for music clips.
lpmc_client = gr.load("seungheondoh/LP-Music-Caps-demo", src="spaces")

# Llama-v2 chat endpoint used for summarization and prompt generation.
client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")

# SDXL base pipeline in fp16 on GPU for the final image generation.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
pipe.to("cuda")

# if using torch < 2.0
# pipe.enable_xformers_memory_efficient_attention()

def cut_audio(input_path, output_path, max_duration=30000):
    """Trim an audio file to at most ``max_duration`` milliseconds.

    The (possibly shortened) audio is re-encoded as MP3 and written to
    ``output_path``, which is returned for caller convenience.
    """
    segment = AudioSegment.from_file(input_path)

    # pydub slicing is millisecond-based; leave short clips untouched.
    truncated = segment[:max_duration] if len(segment) > max_duration else segment

    truncated.export(output_path, format="mp3")
    return output_path

def infer(audio_file):
    """Run the full pipeline: audio -> caption -> summary -> image prompt -> image.

    Parameters
    ----------
    audio_file : str
        Filesystem path of the uploaded audio clip.

    Returns
    -------
    tuple
        (LP-Music-Caps caption, Llama-generated image prompt, generated PIL image).
    """
    # The remote caption model only needs the first 30 s of audio.
    truncated_audio = cut_audio(audio_file, "trunc_audio.mp3")

    cap_result = lpmc_client(
        truncated_audio,  # str (filepath or URL to file) in 'audio_path' Audio component
        api_name="predict",
    )
    print(cap_result)

    # First LLM pass: collapse the per-segment captions into one overall summary.
    # (Typo fixed in the prompt: "processs" -> "process".)
    summarize_q = f"""

    I'll give you a list of music descriptions. Create a summary reflecting the musical ambiance. 
    Do not process each segment, but provide a summary for the whole instead.
    
    Here's the list:

    {cap_result}
    """

    summary_result = client.predict(
        summarize_q,  # str in 'Message' Textbox component
        api_name="/chat_1",
    )

    print(f"SUMMARY: {summary_result}")

    # Second LLM pass: turn the summary into a single image description.
    llama_q = f"""

    I'll give you music description, then i want you to provide an illustrative image description that would fit well with the music.
    Answer with only one image description. Never do lists.

    Here's the music description :

    {summary_result}
    
    """

    result = client.predict(
        llama_q,  # str in 'Message' Textbox component
        api_name="/chat_1",
    )

    print(result)

    # SDXL returns a batch; we only generate (and keep) one image.
    image = pipe(prompt=result).images[0]

    return cap_result, result, image

with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        audio_input = gr.Audio(type="filepath", source="upload")
        infer_btn = gr.Button("Generate")
        lpmc_cap = gr.Textbox(label="Lp Music Caps caption")
        llama_trans_cap = gr.Textbox(label="Llama translation")
        # infer() returns a single PIL image from the SDXL pipeline, so the
        # output component must be gr.Image — gr.Video cannot render it.
        img_result = gr.Image(label="Result")

    infer_btn.click(fn=infer, inputs=[audio_input], outputs=[lpmc_cap, llama_trans_cap, img_result])

demo.queue().launch()