Ken Lin commited on
Commit
cf8e8a6
1 Parent(s): 3013126
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -18,7 +18,7 @@ transform = get_transform(image_size=image_size)
18
  tag2text_model = tag2text(pretrained="tag2text_swin_14m.pth", image_size=image_size, vit='swin_b').eval().to(device)
19
 
20
 
21
- def generate_music(raw_image):
22
  raw_image = Image.fromarray(raw_image)
23
  image = transform(raw_image).unsqueeze(0).to(device)
24
  res = inference_tag2text(image, tag2text_model)
@@ -37,8 +37,11 @@ def generate_music(raw_image):
37
  return_tensors="pt",
38
  )
39
 
40
- audio_values = model.generate(**inputs, max_new_tokens=256)
41
  sampling_rate = model.audio_encoder.config.sampling_rate
 
 
 
 
42
  target_dtype = np.int16
43
  max_range = np.iinfo(target_dtype).max
44
  audio_values = audio_values[0, 0].numpy()
@@ -49,7 +52,10 @@ iface = gr.Interface(
49
  fn=generate_music,
50
  title=title,
51
  description=description,
52
- inputs=gr.Image(label="Painting"),
 
 
 
53
  outputs=gr.Audio(label='Generated Music'))
54
 
55
  iface.launch()
 
18
  tag2text_model = tag2text(pretrained="tag2text_swin_14m.pth", image_size=image_size, vit='swin_b').eval().to(device)
19
 
20
 
21
+ def generate_music(raw_image, audio_length):
22
  raw_image = Image.fromarray(raw_image)
23
  image = transform(raw_image).unsqueeze(0).to(device)
24
  res = inference_tag2text(image, tag2text_model)
 
37
  return_tensors="pt",
38
  )
39
 
 
40
  sampling_rate = model.audio_encoder.config.sampling_rate
41
+ frame_rate = model.audio_encoder.config.frame_rate
42
+ max_new_tokens = int(frame_rate * audio_length)
43
+ audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)
44
+
45
  target_dtype = np.int16
46
  max_range = np.iinfo(target_dtype).max
47
  audio_values = audio_values[0, 0].numpy()
 
52
  fn=generate_music,
53
  title=title,
54
  description=description,
55
+ inputs=[
56
+ gr.Image(label="Painting"),
57
+ gr.Slider(5, 30, value=15, step=1, label="Audio length(sec)")
58
+ ],
59
  outputs=gr.Audio(label='Generated Music'))
60
 
61
  iface.launch()