hash-map committed on
Commit 44938e7 · verified · 1 Parent(s): a32630e

Update infer.py

Files changed (1)
  1. infer.py +128 -152
infer.py CHANGED
@@ -1,153 +1,129 @@
- import gradio as gr
- import torch
- import json
- import numpy as np
- import os
- from datetime import datetime
- from model import Image2Phoneme
- from utils import ctc_post_process, audio_to_mel, mel_to_image, text_to_phonemes
- import soundfile as sf
- import shutil
- import pronouncing
- import time
-
- # Configuration
- DEVICE = torch.device("cpu")
- PHMAP = "phoneme_to_id.json"
- AUDIO_DIR = "audio_inputs"
-
- # Ensure audio directory exists
- os.makedirs(AUDIO_DIR, exist_ok=True)
-
- # Load phoneme vocabulary
- try:
-     vocab = json.load(open(PHMAP, "r"))
-     id_to_ph = {v: k for k, v in vocab.items()}
- except FileNotFoundError:
-     raise FileNotFoundError(f"Phoneme mapping file not found at {PHMAP}")
-
- # Build model
- vocab_size = max(vocab.values()) + 1
- model = Image2Phoneme(vocab_size=vocab_size).to(DEVICE)
- try:
-     ckpt = torch.load("last_checkpoint.pt", map_location=DEVICE, weights_only=True)
-     model.load_state_dict(ckpt["model_state_dict"])
-     model.eval()
- except FileNotFoundError:
-     raise FileNotFoundError(f"Checkpoint file not found at last_checkpoint.pt")
-
- def process_audio(audio_input):
-     """Process audio to predict phonemes and display mel spectrogram."""
-     try:
-         print(f"Received audio_input before processing: {audio_input}")
-         # Generate unique filename based on timestamp
-         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-         audio_path = os.path.join(AUDIO_DIR, f"input_{timestamp}.wav")
-
-         # Handle audio input
-         if audio_input is None:
-             print("Audio input is None after stopping recording")
-             return {"error": "No audio input provided"}, None, None, None
-
-         if isinstance(audio_input, str):
-             # File upload: Copy the uploaded file to audio_inputs/
-             print(f"Processing uploaded file: {audio_input}")
-             if not os.path.exists(audio_input):
-                 return {"error": f"Uploaded file not found: {audio_input}"}, None, None, None
-             if audio_input.endswith(".mp3"):
-                 print("Converting .mp3 to .wav")
-                 from pydub import AudioSegment
-                 audio = AudioSegment.from_mp3(audio_input)
-                 audio_path = audio_path.replace(".wav", "_converted.wav")
-                 audio.export(audio_path, format="wav")
-                 print(f"Converted file saved to: {audio_path}")
-             else:
-                 shutil.copy(audio_input, audio_path)
-                 print(f"Copied file to: {audio_path}")
-         else:
-             # Microphone input: (sample_rate, audio_data)
-             print("Processing microphone input")
-             sample_rate, audio_data = audio_input
-             print(f"Sample rate: {sample_rate}, Audio data shape: {audio_data.shape if hasattr(audio_data, 'shape') else 'None'}")
-             if audio_data is None or len(audio_data) == 0:
-                 print("Microphone audio data is empty or invalid")
-                 return {"error": "Microphone input data is empty or invalid"}, None, None, None
-             # Add a small delay to ensure audio data is fully captured
-             time.sleep(1)
-             sf.write(audio_path, audio_data, sample_rate)
-             print(f"Saved microphone audio to: {audio_path}")
-             # Verify the file exists
-             if not os.path.exists(audio_path):
-                 print(f"Failed to save audio file at: {audio_path}")
-                 return {"error": "Failed to save recorded audio file"}, None, None, None
-
-         # Process audio to mel spectrogram
-         mel_path = audio_to_mel(audio_path)
-         print(f"Generated mel spectrogram: {mel_path}")
-         if not os.path.exists(mel_path):
-             return {"error": f"Mel spectrogram file not found: {mel_path}"}, None, None, None
-
-         mel_image_path = mel_to_image(mel_path)
-         print(f"Generated mel spectrogram image: {mel_image_path}")
-         if not os.path.exists(mel_image_path):
-             return {"error": f"Mel spectrogram image not found: {mel_image_path}"}, None, None, None
-
-         # Load mel spectrogram
-         mel = np.load(mel_path)  # shape (n_mels, T)
-         print(f"Loaded mel spectrogram shape: {mel.shape}")
-         mel_tensor = torch.tensor(mel).unsqueeze(0).to(DEVICE)  # add batch dim
-         mel_lens = torch.tensor([mel.shape[1]]).to(DEVICE)
-
-         # Predict phonemes
-         with torch.no_grad():
-             ph_pred = model(mel_tensor)  # shape (B, seq_len, vocab_size)
-             ph_ids = ph_pred.argmax(-1)[0].cpu().numpy()  # pick first batch
-         print(f"Predicted phoneme IDs: {ph_ids}")
-
-         # Convert IDs to phonemes
-         ph_seq = [id_to_ph[i] for i in ph_ids if i > 0]
-         print(f"Raw phonemes: {ph_seq}")
-
-         # Post-process phonemes
-         post_processed = ctc_post_process(ph_seq)
-         print(f"Post-processed phonemes: {post_processed}")
-
-         # Return results
-         return {
-             "audio_path": audio_path,
-             "phonemes": " ".join(ph_seq),
-             "post_processed_phonemes": " ".join(post_processed)
-         }, mel_image_path, " ".join(ph_seq), " ".join(post_processed)
-     except Exception as e:
-         print(f"Error in process_audio: {str(e)}")
-         return {"error": f"Processing failed: {str(e)}"}, None, None, None
-
- # Gradio interface
- with gr.Blocks() as iface:
-     gr.Markdown("# Speech to Phonemes Converter")
-     gr.Markdown("Record or upload audio to predict phonemes and display mel spectrogram. Paste input text if available.")
-
-     audio_input = gr.Audio(sources=[ "upload"], type="filepath", label="Upload Audio (.wav or .mp3)", interactive=True)
-     text_input = gr.Textbox(label="Enter Text", placeholder="Type a sentence to convert to phonemes")
-     process_button = gr.Button("Process")
-
-     audio_output = gr.JSON(label="Audio Processing Results (Audio Path, Phonemes, Post-Processed Phonemes)")
-     mel_image = gr.Image(label="Mel Spectrogram", type="filepath")
-     raw_phonemes = gr.Textbox(label="Raw Phonemes")
-     post_processed_phonemes = gr.Textbox(label="Post-Processed Phonemes")
-     text_output = gr.JSON(label="Text-to-Phoneme Results")
-
-     def process(audio_input, text_input):
-         print(f"Processing inputs - Audio: {audio_input}, Text: {text_input}")
-         audio_result, mel_image_path, raw_ph, post_ph = process_audio(audio_input) if audio_input else ({}, None, None, None)
-         text_result = text_to_phonemes(text_input) if text_input else {}
-         return audio_result, mel_image_path, raw_ph, post_ph, text_result
-
-     process_button.click(
-         fn=process,
-         inputs=[audio_input, text_input],
-         outputs=[audio_output, mel_image, raw_phonemes, post_processed_phonemes, text_output]
-     )
-
- if __name__ == "__main__":
+ import gradio as gr
+ import torch
+ import json
+ import numpy as np
+ import os
+ from datetime import datetime
+ from model import Image2Phoneme
+ from utils import ctc_post_process, audio_to_mel, mel_to_image, text_to_phonemes
+ import soundfile as sf
+ import shutil
+ import time
+
+ # Configuration
+ DEVICE = torch.device("cpu")
+ PHMAP = "phoneme_to_id.json"
+ AUDIO_DIR = "audio_inputs"
+
+ # Ensure audio directory exists
+ os.makedirs(AUDIO_DIR, exist_ok=True)
+
+ # Load phoneme vocabulary
+ try:
+     vocab = json.load(open(PHMAP, "r"))
+     id_to_ph = {v: k for k, v in vocab.items()}
+ except FileNotFoundError:
+     raise FileNotFoundError(f"Phoneme mapping file not found at {PHMAP}")
+
+ # Build model
+ vocab_size = max(vocab.values()) + 1
+ model = Image2Phoneme(vocab_size=vocab_size).to(DEVICE)
+ try:
+     ckpt = torch.load("last_checkpoint.pt", map_location=DEVICE, weights_only=True)
+     model.load_state_dict(ckpt["model_state_dict"])
+     model.eval()
+ except FileNotFoundError:
+     raise FileNotFoundError(f"Checkpoint file not found at last_checkpoint.pt")
+
+ def process_audio(audio_input):
+     """Process audio to predict phonemes and display mel spectrogram."""
+     try:
+         print(f"Received audio_input before processing: {audio_input}")
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         audio_path = os.path.join(AUDIO_DIR, f"input_{timestamp}.wav")
+
+         if audio_input is None:
+             print("Audio input is None")
+             return {"error": "No audio input provided"}, None, None, None
+
+         if isinstance(audio_input, str):
+             print(f"Processing uploaded file: {audio_input}")
+             if not os.path.exists(audio_input):
+                 return {"error": f"Uploaded file not found: {audio_input}"}, None, None, None
+             if audio_input.endswith(".mp3"):
+                 print("Converting .mp3 to .wav")
+                 from pydub import AudioSegment
+                 audio = AudioSegment.from_mp3(audio_input)
+                 audio_path = audio_path.replace(".wav", "_converted.wav")
+                 audio.export(audio_path, format="wav")
+                 print(f"Converted file saved to: {audio_path}")
+             else:
+                 shutil.copy(audio_input, audio_path)
+                 print(f"Copied file to: {audio_path}")
+         else:
+             raise ValueError("Microphone input not supported in this configuration")
+
+         mel_path = audio_to_mel(audio_path)
+         print(f"Generated mel spectrogram: {mel_path}")
+         if not os.path.exists(mel_path):
+             return {"error": f"Mel spectrogram file not found: {mel_path}"}, None, None, None
+
+         mel_image_path = mel_to_image(mel_path)
+         print(f"Generated mel spectrogram image: {mel_image_path}")
+         if not os.path.exists(mel_image_path):
+             return {"error": f"Mel spectrogram image not found: {mel_image_path}"}, None, None, None
+
+         mel = np.load(mel_path)
+         print(f"Loaded mel spectrogram shape: {mel.shape}")
+         mel_tensor = torch.tensor(mel).unsqueeze(0).to(DEVICE)
+         mel_lens = torch.tensor([mel.shape[1]]).to(DEVICE)
+
+         with torch.no_grad():
+             ph_pred = model(mel_tensor)
+             ph_ids = ph_pred.argmax(-1)[0].cpu().numpy()
+         print(f"Predicted phoneme IDs: {ph_ids}")
+
+         ph_seq = [id_to_ph[i] for i in ph_ids if i > 0]
+         print(f"Raw phonemes: {ph_seq}")
+
+         post_processed = ctc_post_process(ph_seq)
+         print(f"Post-processed phonemes: {post_processed}")
+
+         return {
+             "audio_path": audio_path,
+             "phonemes": " ".join(ph_seq),
+             "post_processed_phonemes": " ".join(post_processed)
+         }, mel_image_path, " ".join(ph_seq), " ".join(post_processed)
+     except Exception as e:
+         print(f"Error in process_audio: {str(e)}")
+         return {"error": f"Processing failed: {str(e)}"}, None, None, None
+
+ # Gradio interface
+ with gr.Blocks() as iface:
+     gr.Markdown("# Speech to Phonemes Converter")
+     gr.Markdown("Upload audio to predict phonemes and display mel spectrogram. Enter text to convert to phonemes.")
+
+     audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio (.wav or .mp3)", interactive=True)
+     text_input = gr.Textbox(label="Enter Text", placeholder="Type a sentence to convert to phonemes")
+     process_button = gr.Button("Process")
+
+     audio_output = gr.JSON(label="Audio Processing Results (Audio Path, Phonemes, Post-Processed Phonemes)")
+     mel_image = gr.Image(label="Mel Spectrogram", type="filepath")
+     raw_phonemes = gr.Textbox(label="Raw Phonemes")
+     post_processed_phonemes = gr.Textbox(label="Post-Processed Phonemes")
+     text_output = gr.JSON(label="Text-to-Phoneme Results")
+
+     def process(audio_input, text_input):
+         print(f"Processing inputs - Audio: {audio_input}, Text: {text_input}")
+         audio_result, mel_image_path, raw_ph, post_ph = process_audio(audio_input) if audio_input else ({}, None, None, None)
+         text_result = text_to_phonemes(text_input) if text_input and text_input.strip() else {}
+         return audio_result, mel_image_path, raw_ph, post_ph, text_result
+
+     process_button.click(
+         fn=process,
+         inputs=[audio_input, text_input],
+         outputs=[audio_output, mel_image, raw_phonemes, post_processed_phonemes, text_output]
+     )
+
+ if __name__ == "__main__":
      iface.launch(debug=True)