from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from modelscope import snapshot_download

import json

import torch
import gradio as gr

from config import model_config




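# Use the GPU when available; the model snapshot is downloaded via ModelScope.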
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_dir = snapshot_download(model_config['model_dir'])

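# Build the ASR pipeline once at startup; every request reuses this instance.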
model = AutoModel(
    model=model_dir,
    trust_remote_code=False,  # remote_code below is only read when this is True
    remote_code="./model.py",
    vad_model="fsmn-vad",  # voice activity detection front-end
    punc_model="ct-punc",  # punctuation restoration
    vad_kwargs={"max_single_segment_time": 30000},  # cap VAD segments at 30 s
    ncpu=torch.get_num_threads(),
    batch_size=1,
    hub="hf",  # download models referenced by name from the Hugging Face hub
    device=device,
)

def transcribe_audio(file_path, vad_model="fsmn-vad", punc_model="ct-punc",
                     vad_kwargs='{"max_single_segment_time": 30000}',
                     batch_size=1, language="auto", use_itn=True, batch_size_s=60,
                     merge_vad=True, merge_length_s=15, batch_size_threshold_s=50,
                     hotword=" ", spk_model="cam++", ban_emo_unk=True):
    """Transcribe one audio file for the Gradio UI.

    Note: vad_model, punc_model, vad_kwargs, and batch_size are fixed when
    AutoModel is built at startup; they are shown in the UI for reference,
    but changing them here does not rebuild the model.
    """
    try:
        # Validate the JSON early so a malformed value yields a clear error.
        json.loads(vad_kwargs)

        res = model.generate(
            input=file_path,
            cache={},
            language=language,          # "auto" lets the model detect the language
            use_itn=use_itn,            # inverse text normalization (numbers, dates, ...)
            batch_size_s=batch_size_s,  # dynamic batching budget, in seconds of audio
            merge_vad=merge_vad,        # merge short VAD segments before decoding
            merge_length_s=merge_length_s,
            batch_size_threshold_s=batch_size_threshold_s,
            hotword=hotword,
            spk_model=spk_model,
            ban_emo_unk=ban_emo_unk,    # suppress the emotion-unknown tag in output
        )

        # Strip special tokens and apply rich-transcription post-processing.
        text = rich_transcription_postprocess(res[0]["text"])
        return text

    except Exception as e:
        # Show the error message in the output box instead of crashing the UI.
        return str(e)

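# One input widget per positional parameter of transcribe_audio, in the same order.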
inputs = [
    gr.Audio(type="filepath"),
    gr.Textbox(value="fsmn-vad", label="VAD Model"),
    gr.Textbox(value="ct-punc", label="PUNC Model"),
    gr.Textbox(value='{"max_single_segment_time": 30000}', label="VAD Kwargs"),
    gr.Slider(1, 10, value=1, step=1, label="Batch Size"),
    gr.Textbox(value="auto", label="Language"),
    gr.Checkbox(value=True, label="Use ITN"),
    gr.Slider(30, 120, value=60, step=1, label="Batch Size (seconds)"),
    gr.Checkbox(value=True, label="Merge VAD"),
    gr.Slider(5, 60, value=15, step=1, label="Merge Length (seconds)"),
    gr.Slider(10, 100, value=50, step=1, label="Batch Size Threshold (seconds)"),
    gr.Textbox(value=" ", label="Hotword"),
    gr.Textbox(value="cam++", label="Speaker Model"),
    gr.Checkbox(value=True, label="Ban Emotional Unknown"),
]

outputs = gr.Textbox(label="Transcription")

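# Assemble and serve the web UI; launch() blocks and serves on a local port.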
gr.Interface(
    fn=transcribe_audio, 
    inputs=inputs, 
    outputs=outputs, 
    title="ASR Transcription with FunASR"
).launch()
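
# A minimal sketch of exercising the pipeline without the UI (the file name
# "sample.wav" is illustrative, not part of the original script):
#
#     text = transcribe_audio("sample.wav")
#     print(text)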