# Gaepago model V1 (CPU Test)

# import packages
import json

import gradio as gr
import torch
from datasets import Dataset, Audio
from transformers import AutoConfig, AutoFeatureExtractor

from utils.postprocess import text_mapping
# Set model & dataset names
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"  # kept for reference; not used below
TEXT_LABEL = "text_label.json"

# Load config, quantized TorchScript model & feature extractor
# (the full AutoModelForAudioClassification checkpoint is replaced here by an
# int8-quantized TorchScript export for CPU inference)
config = AutoConfig.from_pretrained(MODEL_NAME)
model = torch.jit.load("./model/gaepago-20-lite/model_quant_int8.pt")
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
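
# NOTE: model_quant_int8.pt is a pre-exported artifact checked into the repo.
# A minimal sketch of how such a file could be produced with dynamic int8
# quantization; this is an assumption, since the actual export script is not
# part of this app:
#
#   base = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
#   quantized = torch.quantization.quantize_dynamic(
#       base, {torch.nn.Linear}, dtype=torch.qint8)
#   traced = torch.jit.trace(quantized, (dummy_input_values,), strict=False)
#   torch.jit.save(traced, "./model/gaepago-20-lite/model_quant_int8.pt")
#
# where dummy_input_values is a sample feature tensor from the feature extractor.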

# ๋ชจ๋ธ cpu๋กœ ๋ณ€๊ฒฝํ•˜์—ฌ ์ง„ํ–‰
model.to("cpu")
# TEXT LABEL ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
with open(TEXT_LABEL,"r",encoding='utf-8')  as f:
    text_label = json.load(f)
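
# text_label is assumed to map each raw model label (e.g. "bark", "howl") to a
# human-readable sentence; the exact schema is defined by text_label.json and
# utils/postprocess.text_mapping, neither of which is shown here.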

# Gaepago inference function: takes a recorded audio file path, returns the
# mapped "dog language" text
def gaepago_fn(tmp_audio_dir):
    # Wrap the file in a datasets.Dataset so the audio is decoded and
    # resampled to the 16 kHz rate the model expects
    audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column(
        "audio", Audio(sampling_rate=16000))
    inputs = feature_extractor(audio_dataset[0]["audio"]["array"],
                               sampling_rate=audio_dataset[0]["audio"]["sampling_rate"],
                               return_tensors="pt")

    # The TorchScript model returns a dict with a "logits" entry
    with torch.no_grad():
        logits = model(**inputs)["logits"]

    # Pick the highest-scoring class and look its name up in the HF config
    # (the quantized ScriptModule carries no .config of its own)
    predicted_class_ids = torch.argmax(logits).item()
    predicted_label = config.id2label[predicted_class_ids]

    # Postprocessing: map the raw label to its display text
    output = text_mapping(predicted_label, text_label)
    return output
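
# Quick local smoke test (kept commented out; assumes the bundled sample exists):
# print(gaepago_fn("./sample/bark_sample.wav"))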

# Main
example_list = [
    "./sample/bark_sample.wav",
    "./sample/growling_sample.wav",
    "./sample/howl_sample.wav",
    "./sample/panting_sample.wav",
    "./sample/whimper_sample.wav",
]

with gr.Blocks() as demo:
    gr.Markdown("## 8J Gaepago Demo (with CPU)")
    with gr.Row():
        audio = gr.Audio(source="microphone", type="filepath",
                         label="Press the record button and let us hear what Choco is saying")
        transcription = gr.Textbox(label="What Choco is saying now is...")
    b1 = gr.Button("Translate dog language!")

    b1.click(gaepago_fn, inputs=audio, outputs=transcription, api_name="predict")
    examples = gr.Examples(examples=example_list, inputs=[audio])

demo.launch(show_error=True)
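
# Running `python app.py` serves the demo on Gradio's default local address
# (http://127.0.0.1:7860); show_error=True surfaces Python exceptions in the UI.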