Epsilon617 committed
Commit 826be26
1 Parent(s): 5247bff

add model inference codes

Files changed (3)
  1. __pycache__/app.cpython-310.pyc +0 -0
  2. app.py +25 -8
  3. requirements.txt +88 -0
__pycache__/app.cpython-310.pyc CHANGED
Binary files a/__pycache__/app.cpython-310.pyc and b/__pycache__/app.cpython-310.pyc differ
 
app.py CHANGED
@@ -5,9 +5,21 @@ import torch
 from torch import nn
 import torchaudio
 import torchaudio.transforms as T
-
+import logging
 # input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
 
+
+logger = logging.getLogger("whisper-jax-app")
+logger.setLevel(logging.INFO)
+ch = logging.StreamHandler()
+ch.setLevel(logging.INFO)
+formatter = logging.Formatter(
+    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+
+
 inputs = [gr.components.Audio(type="filepath", label="Add music audio file"),
           gr.inputs.Audio(source="microphone",optional=True, type="filepath"),
           ]
@@ -17,8 +29,8 @@ title = "Output the tags of a (music) audio"
 description = "An example of using MERT-95M-public to conduct music tagging."
 article = ""
 audio_examples = [
-    ["input/example-1.wav"],
-    ["input/example-2.wav"],
+    # ["input/example-1.wav"],
+    # ["input/example-2.wav"],
 ]
 
 # Load the model
@@ -26,13 +38,14 @@ model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True
 # loading the corresponding preprocessor config
 processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True)
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+model.to(device)
 
 def convert_audio(inputs, microphone):
     if (microphone is not None):
         inputs = microphone
 
     waveform, sample_rate = torchaudio.load(inputs)
-
 
     resample_rate = processor.sampling_rate
 
@@ -42,15 +55,19 @@ def convert_audio(inputs, microphone):
     resampler = T.Resample(sample_rate, resample_rate)
     waveform = resampler(waveform)
 
-    inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
+    waveform = waveform.view(-1,) # make it (n_sample, )
+    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
+    model_inputs.to(device)
     with torch.no_grad():
-        outputs = model(**inputs, output_hidden_states=True)
+        model_outputs = model(**model_inputs, output_hidden_states=True)
 
     # take a look at the output shape, there are 13 layers of representation
    # each layer performs differently in different downstream tasks, you should choose empirically
-    all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
+    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
     # print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
-    return str(all_layer_hidden_states.shape)
+    # logger.warning(all_layer_hidden_states.shape)
+
+    return device + " :" + str(all_layer_hidden_states.shape)
 
 
 # iface = gr.Interface(fn=convert_audio, inputs="audio", outputs="text")
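For reference: the inference path added above stacks the 13 layers of MERT hidden states (shape [13, time_steps, 768], per the comment in the diff) and, for now, only returns the device name and the shape string to the Gradio output. A common next step is to pool those hidden states into a clip-level embedding, e.g. by averaging over time. The sketch below is only an illustration of that idea, not part of this commit; it reuses the same checkpoint names as app.py, and the function name and mean-pooling strategy are assumptions.

import torch
import torchaudio
import torchaudio.transforms as T
from transformers import AutoModel, Wav2Vec2FeatureExtractor

# Same checkpoint as app.py; trust_remote_code is required for the MERT code on the Hub.
model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()

def extract_clip_embedding(path):
    """Return one 768-dim embedding per layer, shape [13, 768] (illustrative pooling)."""
    waveform, sample_rate = torchaudio.load(path)   # [channels, n_samples]
    waveform = waveform.mean(dim=0)                  # downmix to mono -> [n_samples]
    if sample_rate != processor.sampling_rate:
        waveform = T.Resample(sample_rate, processor.sampling_rate)(waveform)
    inputs = processor(waveform, sampling_rate=processor.sampling_rate, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    # [13, 1, time_steps, 768] -> [13, time_steps, 768] -> mean over time -> [13, 768]
    hidden = torch.stack(outputs.hidden_states).squeeze(1)
    return hidden.mean(dim=1)

# Hypothetical usage with one of the (currently commented-out) example files:
# print(extract_clip_embedding("input/example-1.wav").shape)  # torch.Size([13, 768])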
requirements.txt ADDED
@@ -0,0 +1,88 @@
+aiofiles==23.1.0
+aiohttp==3.8.4
+aiosignal==1.3.1
+altair==5.0.0
+anyio==3.6.2
+async-timeout==4.0.2
+attrs==23.1.0
+certifi==2023.5.7
+charset-normalizer==3.1.0
+click==8.1.3
+cmake==3.26.3
+contourpy==1.0.7
+cycler==0.11.0
+fastapi==0.95.2
+ffmpy==0.3.0
+filelock==3.12.0
+fonttools==4.39.4
+frozenlist==1.3.3
+fsspec==2023.5.0
+gradio==3.31.0
+gradio_client==0.2.5
+h11==0.14.0
+httpcore==0.17.1
+httpx==0.24.0
+huggingface-hub==0.14.1
+idna==3.4
+Jinja2==3.1.2
+jsonschema==4.17.3
+kiwisolver==1.4.4
+linkify-it-py==2.0.2
+lit==16.0.5
+markdown-it-py==2.2.0
+MarkupSafe==2.1.2
+matplotlib==3.7.1
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.0.4
+networkx==3.1
+nnAudio==0.3.2
+numpy==1.24.3
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-cupti-cu11==11.7.101
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
+nvidia-cufft-cu11==10.9.0.58
+nvidia-curand-cu11==10.2.10.91
+nvidia-cusolver-cu11==11.4.0.1
+nvidia-cusparse-cu11==11.7.4.91
+nvidia-nccl-cu11==2.14.3
+nvidia-nvtx-cu11==11.7.91
+orjson==3.8.12
+packaging==23.1
+pandas==2.0.1
+Pillow==9.5.0
+pydantic==1.10.7
+pydub==0.25.1
+Pygments==2.15.1
+pyparsing==3.0.9
+pyrsistent==0.19.3
+python-dateutil==2.8.2
+python-multipart==0.0.6
+pytz==2023.3
+PyYAML==6.0
+regex==2023.5.5
+requests==2.30.0
+scipy==1.10.1
+semantic-version==2.10.0
+six==1.16.0
+sniffio==1.3.0
+starlette==0.27.0
+sympy==1.12
+tokenizers==0.13.3
+toolz==0.12.0
+torch==2.0.1
+torchaudio==2.0.2
+torchvision==0.15.2
+tqdm==4.65.0
+transformers==4.29.2
+triton==2.0.0
+typing_extensions==4.5.0
+tzdata==2023.3
+uc-micro-py==1.0.2
+urllib3==2.0.2
+uvicorn==0.22.0
+websockets==11.0.3
+yarl==1.9.2
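These pins lock the Space to the gradio 3.x input API that app.py uses (gr.inputs.Audio) and to the CUDA 11 wheels pulled in by torch 2.0.1. When recreating the environment, a quick version check such as the following (a convenience sketch, not part of the commit) confirms the key packages resolved as pinned:

import gradio, torch, torchaudio, transformers

# Compare installed versions against the key pins in requirements.txt.
pins = {"gradio": "3.31.0", "torch": "2.0.1", "torchaudio": "2.0.2", "transformers": "4.29.2"}
for name, module in [("gradio", gradio), ("torch", torch), ("torchaudio", torchaudio), ("transformers", transformers)]:
    installed = module.__version__
    ok = installed.startswith(pins[name])  # torch reports e.g. "2.0.1+cu117"
    print(f"{name}: {installed} ({'matches pin' if ok else 'expected ' + pins[name]})")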