	update
examples/sound_classification_by_lstm/step_6_export_onnx_model.py  (CHANGED)

@@ -120,9 +120,9 @@ def main():
             "logits", "new_h", "new_c"
         ]
         logits, new_h, new_c = ort_session.run(output_names, input_feed)
-    print(f"logits: {logits.shape}")
-    print(f"new_h: {new_h.shape}")
-    print(f"new_c: {new_c.shape}")
+    # print(f"logits: {logits.shape}")
+    # print(f"new_h: {new_h.shape}")
+    # print(f"new_c: {new_c.shape}")
     return
 
 
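For context, the smoke test these commented-out prints belong to follows the usual onnxruntime pattern: build an InferenceSession from the exported file, assemble an input_feed dict keyed by the graph's input names, and pass the requested output names to run(). A minimal sketch of that pattern, where the model file name, the input names ("inputs", "h", "c") and the tensor shapes are assumptions, not taken from this diff:

import numpy as np
import onnxruntime as ort

# Hypothetical file name; the real export path comes from step_6_export_onnx_model.py.
ort_session = ort.InferenceSession("lstm_sound_classifier.onnx")

# Assumed shapes for illustration: batch=1, 125 frames of 80-dim features,
# 2-layer LSTM state with hidden size 64.
input_feed = {
    "inputs": np.zeros((1, 125, 80), dtype=np.float32),
    "h": np.zeros((2, 1, 64), dtype=np.float32),
    "c": np.zeros((2, 1, 64), dtype=np.float32),
}
output_names = ["logits", "new_h", "new_c"]

# run() returns the outputs in the order requested.
logits, new_h, new_c = ort_session.run(output_names, input_feed)
print(logits.shape, new_h.shape, new_c.shape)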
    	
examples/sound_classification_by_lstm/step_8_test_onnx_model.py  (CHANGED)

@@ -31,7 +31,8 @@ def get_args():
     )
     parser.add_argument(
         "--wav_file",
-        default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
+        # default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
+        default=(project_path / "data/examples/examples/zh-TW/voicemail/00a1d109-23c2-4b8b-a066-993ac2ae8260_zh-TW_1672210785598.wav").as_posix(),
         type=str
     )
 
@@ -107,10 +108,23 @@ def main():
         "logits", "new_h", "new_c"
     ]
     logits, new_h, new_c = ort_session.run(output_names, input_feed)
-    print(f"logits: {logits.shape}")
-    print(f"new_h: {new_h.shape}")
-    print(f"new_c: {new_c.shape}")
+    # print(f"logits: {logits.shape}")
+    # print(f"new_h: {new_h.shape}")
+    # print(f"new_c: {new_c.shape}")
 
+    logits = torch.tensor(logits, dtype=torch.float32)
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    label_idx = torch.argmax(probs, dim=-1)
+
+    label_idx = label_idx.cpu()
+    probs = probs.cpu()
+
+    label_idx = label_idx.numpy()[0]
+    prob = probs.numpy()[0][label_idx]
+
+    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
+    print(label_str)
+    print(prob)
     return
 
 
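The decoding block added here round-trips the onnxruntime output through torch only to apply softmax and argmax. The same post-processing can be done directly in numpy; a rough equivalent sketch follows, assuming logits has shape (1, num_classes) and that vocabulary is the same label mapping the script already loads:

import numpy as np

def decode_logits(logits: np.ndarray, vocabulary):
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)

    # Pick the most probable class for the single example in the batch.
    label_idx = int(np.argmax(probs, axis=-1)[0])
    prob = float(probs[0][label_idx])

    # Same lookup as in the diff; vocabulary is assumed to expose this method.
    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    return label_str, prob

This yields the same label_str and prob that the new code prints, without converting the ONNX outputs back into torch tensors.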
    	
tabs/cls_tab.py  (CHANGED)

@@ -1,11 +1,14 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
+import json
 from functools import lru_cache
+from os import times
 from pathlib import Path
 import platform
 import shutil
 import tempfile
+import time
 import zipfile
 from typing import Tuple
 
@@ -61,10 +64,12 @@ def when_click_cls_button(audio_t,
     inputs = torch.tensor(inputs, dtype=torch.float32)
     inputs = torch.unsqueeze(inputs, dim=0)
 
+    time_begin = time.time()
     with torch.no_grad():
         logits = model.forward(inputs)
         probs = torch.nn.functional.softmax(logits, dim=-1)
         label_idx = torch.argmax(probs, dim=-1)
+    time_cost = time.time() - time_begin
 
     label_idx = label_idx.cpu()
     probs = probs.cpu()
@@ -74,7 +79,13 @@ def when_click_cls_button(audio_t,
 
     label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
 
-    …
+    result = {
+        "label": label_str,
+        "prob": round(float(prob), 4),
+        "time_cost": round(time_cost, 4),
+    }
+    result = json.dumps(result, ensure_ascii=False, indent=4)
+    return result
 
 
 def get_cls_tab(examples_dir: str, trained_model_dir: str):
@@ -121,13 +132,12 @@ def get_cls_tab(examples_dir: str, trained_model_dir: str):
 
                 cls_button = gr.Button("run", variant="primary")
             with gr.Column(scale=3):
-                …
-                cls_probability = gr.Number(label="probability")
+                cls_outputs = gr.Textbox(label="outputs", lines=1, max_lines=15)
 
         gr.Examples(
             cls_examples,
             inputs=[cls_audio, cls_model_name, cls_ground_true],
-            outputs=[…
+            outputs=[cls_outputs],
             fn=when_click_cls_button,
             examples_per_page=5,
         )
@@ -135,7 +145,7 @@ def get_cls_tab(examples_dir: str, trained_model_dir: str):
         cls_button.click(
             when_click_cls_button,
             inputs=[cls_audio, cls_model_name, cls_ground_true],
-            outputs=[…
+            outputs=[cls_outputs],
         )
 
     return locals()
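The net effect of the tabs/cls_tab.py change is that when_click_cls_button now returns a single JSON string (label, probability, inference time), and both gr.Examples and cls_button.click route it into the new cls_outputs textbox instead of separate label/probability components. A stripped-down sketch of that Gradio wiring follows; the dummy callback stands in for the real model code, and the component types for cls_audio, cls_model_name and cls_ground_true are guesses, since only their names appear in this diff:

import json
import time

import gradio as gr

def when_click_cls_button(audio_t, model_name, ground_true):
    # Placeholder for loading the model and classifying audio_t.
    time_begin = time.time()
    label_str, prob = "voicemail", 0.9876
    time_cost = time.time() - time_begin

    result = {
        "label": label_str,
        "prob": round(float(prob), 4),
        "time_cost": round(time_cost, 4),
    }
    # The JSON string is what the cls_outputs textbox displays.
    return json.dumps(result, ensure_ascii=False, indent=4)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            cls_audio = gr.Audio(label="audio")
            cls_model_name = gr.Dropdown(choices=["lstm"], label="model_name")
            cls_ground_true = gr.Textbox(label="ground_true")
            cls_button = gr.Button("run", variant="primary")
        with gr.Column(scale=3):
            cls_outputs = gr.Textbox(label="outputs", lines=1, max_lines=15)

    # Same inputs/outputs wiring as the commit: three inputs, one textbox output.
    cls_button.click(
        when_click_cls_button,
        inputs=[cls_audio, cls_model_name, cls_ground_true],
        outputs=[cls_outputs],
    )

# demo.launch()

Returning one serialized JSON blob keeps the callback signature stable even if more fields (for example the ground-truth label) are added to the result later, at the cost of losing typed per-field components in the UI.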
