Ken Lin
committed on
Commit
•
cf8e8a6
1
Parent(s):
3013126
Update
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ transform = get_transform(image_size=image_size)
|
|
18 |
tag2text_model = tag2text(pretrained="tag2text_swin_14m.pth", image_size=image_size, vit='swin_b').eval().to(device)
|
19 |
|
20 |
|
21 |
-
def generate_music(raw_image):
|
22 |
raw_image = Image.fromarray(raw_image)
|
23 |
image = transform(raw_image).unsqueeze(0).to(device)
|
24 |
res = inference_tag2text(image, tag2text_model)
|
@@ -37,8 +37,11 @@ def generate_music(raw_image):
|
|
37 |
return_tensors="pt",
|
38 |
)
|
39 |
|
40 |
-
audio_values = model.generate(**inputs, max_new_tokens=256)
|
41 |
sampling_rate = model.audio_encoder.config.sampling_rate
|
|
|
|
|
|
|
|
|
42 |
target_dtype = np.int16
|
43 |
max_range = np.iinfo(target_dtype).max
|
44 |
audio_values = audio_values[0, 0].numpy()
|
@@ -49,7 +52,10 @@ iface = gr.Interface(
|
|
49 |
fn=generate_music,
|
50 |
title=title,
|
51 |
description=description,
|
52 |
-
inputs=
|
|
|
|
|
|
|
53 |
outputs=gr.Audio(label='Generated Music'))
|
54 |
|
55 |
iface.launch()
|
|
|
18 |
tag2text_model = tag2text(pretrained="tag2text_swin_14m.pth", image_size=image_size, vit='swin_b').eval().to(device)
|
19 |
|
20 |
|
21 |
+
def generate_music(raw_image, audio_length):
|
22 |
raw_image = Image.fromarray(raw_image)
|
23 |
image = transform(raw_image).unsqueeze(0).to(device)
|
24 |
res = inference_tag2text(image, tag2text_model)
|
|
|
37 |
return_tensors="pt",
|
38 |
)
|
39 |
|
|
|
40 |
sampling_rate = model.audio_encoder.config.sampling_rate
|
41 |
+
frame_rate = model.audio_encoder.config.frame_rate
|
42 |
+
max_new_tokens = int(frame_rate * audio_length)
|
43 |
+
audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)
|
44 |
+
|
45 |
target_dtype = np.int16
|
46 |
max_range = np.iinfo(target_dtype).max
|
47 |
audio_values = audio_values[0, 0].numpy()
|
|
|
52 |
fn=generate_music,
|
53 |
title=title,
|
54 |
description=description,
|
55 |
+
inputs=[
|
56 |
+
gr.Image(label="Painting"),
|
57 |
+
gr.Slider(5, 30, value=15, step=1, label="Audio length(sec)")
|
58 |
+
],
|
59 |
outputs=gr.Audio(label='Generated Music'))
|
60 |
|
61 |
iface.launch()
|