# Gaepago model V1 (CPU Test)

# import package
import json
import os

import gradio as gr
import torch
from datasets import Dataset, Audio
from transformers import AutoConfig
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification
from transformers import pipeline

from utils.postprocess import text_mapping, text_encoding

# Set model & dataset names
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"
TEXT_LABEL = "text_label.json"

# Path of the TorchScript model that is loaded in place of the Hub checkpoint.
QUANTIZED_MODEL_PATH = "./model/gaepago-20-lite/model_quant_int8.pt"

# Sampling rate the audio column is cast to before feature extraction.
TARGET_SAMPLING_RATE = 16000

# Shown in the UI when the button is pressed without any audio; kept in
# Korean to match the rest of the interface text.
NO_AUDIO_MESSAGE = "녹음된 오디오가 없습니다. 먼저 소리를 녹음하거나 예시를 선택해 주세요."

# Import model & feature extractor.
# The model itself comes from the local TorchScript file, so the label names
# (id2label) are read from the Hub config rather than from `model.config`.
config = AutoConfig.from_pretrained(MODEL_NAME)
model = torch.jit.load(QUANTIZED_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

# Run the model on CPU.
model.to("cpu")

# Load the text labels.
with open(TEXT_LABEL, "r", encoding="utf-8") as f:
    text_label = json.load(f)


# Gaepago inference function
def gaepago_fn(tmp_audio_dir):
    """Classify one audio file and map the predicted label to display text.

    Args:
        tmp_audio_dir: Path of the audio file handed over by the Gradio
            audio component (``type="filepath"``). Gradio passes ``None``
            when nothing has been recorded or selected.

    Returns:
        Whatever ``text_mapping(predicted_label, text_label)`` returns; it is
        shown in the output textbox.

    Raises:
        gr.Error: If no audio file was provided.
    """
    print(tmp_audio_dir)

    # Fail with a readable message instead of an opaque error from the
    # dataset/audio decoding layer when the user has not recorded anything.
    if not tmp_audio_dir:
        raise gr.Error(NO_AUDIO_MESSAGE)

    audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column(
        "audio", Audio(sampling_rate=TARGET_SAMPLING_RATE)
    )

    # Read the audio entry once and reuse it: every access to the audio
    # column decodes (and resamples) the file again.
    sample = audio_dataset[0]["audio"]
    inputs = feature_extractor(
        sample["array"],
        sampling_rate=sample["sampling_rate"],
        return_tensors="pt",
    )

    with torch.no_grad():
        # The TorchScript model output is indexed by key, not by attribute.
        logits = model(**inputs)["logits"]

    # A single clip is passed in, so the arg-max over the whole logits tensor
    # is used directly as the class id.
    predicted_class_ids = torch.argmax(logits).item()
    predicted_label = config.id2label[predicted_class_ids]

    # Post-processing
    # 1. text mapping
    output = text_mapping(predicted_label, text_label)
    return output


# Main
example_list = [
    "./sample/bark_sample.wav",
    "./sample/growling_sample.wav",
    "./sample/howl_sample.wav",
    "./sample/panting_sample.wav",
    "./sample/whimper_sample.wav",
]

main_api = gr.Blocks()

with main_api as demo:
    gr.Markdown("## 8J Gaepago Demo(with CPU)")
    with gr.Row():
        audio = gr.Audio(source="microphone", type="filepath"
                         ,label='녹음버튼을 눌러 초코가 하는 말을 들려주세요')
        transcription = gr.Textbox(label='지금 초코가 하는 말은...')
    b1 = gr.Button("강아지 언어 번역!")

    # `api_name` exposes the handler as the "predict" API endpoint.
    b1.click(gaepago_fn, inputs=audio, outputs=transcription, api_name="predict")
    examples = gr.Examples(examples=example_list, inputs=[audio])

demo.launch(show_error=True)