wsntxxn committed
Commit: 487e498
Parent: df7102c

Update Clotho model

app.py CHANGED
@@ -23,7 +23,7 @@ def load_model(cfg,
     return model, tokenizer
 
 
-def infer(file, device, model, tokenizer, target_sr):
+def infer(file, runner):
     sr, wav = file
     wav = torch.as_tensor(wav)
     if wav.dtype == torch.short:
@@ -32,9 +32,9 @@ def infer(file, device, model, tokenizer, target_sr):
         wav = wav / 2 ** 31
     if wav.ndim > 1:
         wav = wav.mean(1)
-    wav = resample(wav, sr, target_sr)
+    wav = resample(wav, sr, runner.target_sr)
     wav_len = len(wav)
-    wav = wav.float().unsqueeze(0).to(device)
+    wav = wav.float().unsqueeze(0).to(runner.device)
     input_dict = {
         "mode": "inference",
         "wav": wav,
@@ -44,9 +44,9 @@ def infer(file, device, model, tokenizer, target_sr):
         "beam_size": 3,
     }
     with torch.no_grad():
-        output_dict = model(input_dict)
+        output_dict = runner.model(input_dict)
         seq = output_dict["seq"].cpu().numpy()
-        cap = tokenizer.decode(seq)[0]
+        cap = runner.tokenizer.decode(seq)[0]
     return cap
 
 # def input_toggle(input_type):
@@ -55,43 +55,47 @@ def infer(file, device, model, tokenizer, target_sr):
 #     elif input_type == "mic":
 #         return gr.update(visible=False), gr.update(visible=True)
 
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--share", action="store_true", default=False)
-
-    args = parser.parse_args()
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    exp_dir = Path("./checkpoints/audiocaps")
-    cfg = train_util.load_config(exp_dir / "config.yaml")
-    target_sr = cfg["target_sr"]
-    model, tokenizer = load_model(cfg, exp_dir / "ckpt.pth", device)
-
-    with gr.Blocks() as demo:
-        with gr.Row():
-            with gr.Column():
-                # radio = gr.Radio(
-                #     ["file", "mic"],
-                #     value="file",
-                #     label="Select input type"
-                # )
-                file = gr.Audio(label="Input", visible=True)
-                # mic = gr.Microphone(label="Input", visible=False)
-                # radio.change(fn=input_toggle, inputs=radio, outputs=[file, mic])
-                btn = gr.Button("Run")
-            with gr.Column():
-                output = gr.Textbox(label="Output")
-        btn.click(
-            fn=partial(infer,
-                       device=device,
-                       model=model,
-                       tokenizer=tokenizer,
-                       target_sr=target_sr),
-            inputs=[file,],
-            outputs=output
-        )
-
-    demo.launch(share=args.share)
+class InferRunner:
+
+    def __init__(self, model_name):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        exp_dir = Path(f"./checkpoints/{model_name.lower()}")
+        cfg = train_util.load_config(exp_dir / "config.yaml")
+        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
+        self.target_sr = cfg["target_sr"]
+
+    def change_model(self, model_name):
+        exp_dir = Path(f"./checkpoints/{model_name.lower()}")
+        cfg = train_util.load_config(exp_dir / "config.yaml")
+        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
+        self.target_sr = cfg["target_sr"]
+
+
+def change_model(radio):
+    global infer_runner
+    infer_runner.change_model(radio)
+
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column():
+            radio = gr.Radio(
+                ["AudioCaps", "Clotho"],
+                value="AudioCaps",
+                label="Select model"
+            )
+            infer_runner = InferRunner(radio.value)
+            file = gr.Audio(label="Input", visible=True)
+            radio.change(fn=change_model, inputs=[radio,],)
+            btn = gr.Button("Run")
+        with gr.Column():
+            output = gr.Textbox(label="Output")
+        btn.click(
+            fn=partial(infer,
+                       runner=infer_runner),
+            inputs=[file,],
+            outputs=output
+        )
+
+demo.launch()
 
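Editor's note: the switch from passing device/model/tokenizer/target_sr individually to a single runner argument works because functools.partial captures a reference to the InferRunner instance, and change_model mutates that same object in place rather than rebinding the global. A minimal standalone sketch of this pattern (hypothetical names, not code from this repo):

from functools import partial

class Runner:
    def __init__(self, name):
        self.name = name

    def change_model(self, name):
        # Mutate in place: every partial() holding this object sees the update.
        self.name = name

def infer(x, runner):
    return f"{x} captioned by {runner.name} model"

runner = Runner("AudioCaps")
fn = partial(infer, runner=runner)
runner.change_model("Clotho")
print(fn("audio.wav"))  # -> "audio.wav captioned by Clotho model"

Had change_model instead rebound the global (infer_runner = InferRunner(...)), the partial bound at btn.click time would keep serving the stale model.
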
checkpoints/clotho/ckpt.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:694c9e7139be7ec5aff2153d1af980d6bc305403a76be0d8940481579ea51483
+size 54651005
checkpoints/clotho/config.yaml ADDED
@@ -0,0 +1,30 @@
+tokenizer:
+  type: text_tokenizer.DictTokenizer
+  args:
+    max_length: 20
+
+target_sr: 16000
+
+model:
+  args:
+    shared_dim: 1024
+    tchr_dim: 768
+  model:
+    args: {}
+    decoder:
+      args:
+        attn_emb_dim: 1408
+        dropout: 0.2
+        emb_dim: 256
+        fc_emb_dim: 1408
+        nlayers: 2
+        tie_weights: true
+        vocab_size: 4368
+      type: models.transformer_decoder.TransformerDecoder
+    encoder:
+      args:
+        freeze: false
+        pretrained: true
+      type: models.cnn_encoder.EfficientNetB2
+    type: models.transformer_model.TransformerModel
+  type: models.kd_wrapper.ContraEncoderKdWrapper
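
Editor's note: the nested type/args pairs suggest each config node names a dotted import path plus constructor arguments. A hypothetical helper illustrating that convention (the repo's actual construction code may differ; in particular, the encoder/decoder/model sub-nodes here sit beside args, so they would need explicit wiring):

import importlib

def instantiate(node):
    """Build an object from a {"type": "pkg.mod.Class", "args": {...}} node."""
    module_name, cls_name = node["type"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    # Recurse into any argument that is itself a type/args node.
    args = {
        k: instantiate(v) if isinstance(v, dict) and "type" in v else v
        for k, v in node.get("args", {}).items()
    }
    return cls(**args)
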
models/eff_latent_encoder.py CHANGED
@@ -17,7 +17,7 @@ from einops import rearrange, reduce
 from torch.hub import load_state_dict_from_url
 
 
-model_dir = "./"
+model_dir = os.getcwd()
 
 
 class _EffiNet(nn.Module):
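
Editor's note: model_dir is the cache directory later handed to torch.hub.load_state_dict_from_url. os.getcwd() resolves once, at import time, to an absolute path, so where the pretrained weights are cached no longer depends on the working directory at download time. A sketch of the call it feeds (placeholder URL, not the repo's):

import os
from torch.hub import load_state_dict_from_url

model_dir = os.getcwd()  # absolute path, fixed at import time

state_dict = load_state_dict_from_url(
    "https://example.com/effnet_b2.pth",  # placeholder URL
    model_dir=model_dir,                  # weights are cached here
)
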
utils/model_util.py CHANGED
@@ -11,7 +11,7 @@ def sort_pack_padded_sequence(input, lengths):
     sorted_lengths, indices = torch.sort(lengths, descending=True)
     tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
     inv_ix = indices.clone()
-    inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
+    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
     return tmp, inv_ix
 
 def pad_unsort_packed_sequence(input, inv_ix):
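
Editor's note: the touched line builds the inverse permutation. After sorting by length (as pack_padded_sequence requires), inv_ix[indices] = arange(len(indices)) records where each original batch element landed, so the output can be restored to the original order later. A self-contained round-trip check of the idea (standalone sketch, not the repo's test code):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x = torch.randn(3, 5, 2)             # 3 padded sequences, max length 5
lengths = torch.tensor([2, 5, 3])

sorted_lengths, indices = torch.sort(lengths, descending=True)
packed = pack_padded_sequence(x[indices], sorted_lengths.cpu(), batch_first=True)

inv_ix = indices.clone()
inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)

unpacked, _ = pad_packed_sequence(packed, batch_first=True)
restored = unpacked[inv_ix]          # back in the original batch order

assert torch.equal(restored[1], x[1])           # full-length row intact
assert torch.equal(restored[0, :2], x[0, :2])   # valid steps of short rows too
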
utils/train_util.py CHANGED
@@ -80,7 +80,7 @@ def merge_load_state_dict(state_dict,
             pretrained_dict[key] = value
         else:
             mismatch_keys.append(key)
-    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
+    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
     model_dict.update(pretrained_dict)
     model.load_state_dict(model_dict, strict=True)
     return pretrained_dict.keys()
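
Editor's note: for readers without the full file, merge_load_state_dict evidently keeps only checkpoint entries the target model can accept, reports the rest, and loads the merged dict. An assumed reconstruction of the whole routine (the elided parameters in the real signature are unknown; the shape check and output_fn default are assumptions):

import torch

def merge_load_state_dict(state_dict, model, output_fn=print):
    model_dict = model.state_dict()
    pretrained_dict, mismatch_keys = {}, []
    for key, value in state_dict.items():
        # Keep entries the model knows about and whose shapes agree.
        if key in model_dict and model_dict[key].shape == value.shape:
            pretrained_dict[key] = value
        else:
            mismatch_keys.append(key)
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=True)
    return pretrained_dict.keys()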