File size: 4,305 Bytes
544e017
 
3a18b3b
 
 
 
 
 
 
 
7a0f405
1dfec92
544e017
 
3a18b3b
499b2c1
3a18b3b
 
1dfec92
 
3a18b3b
499b2c1
7bc4048
 
 
 
 
c492cbb
3a18b3b
 
1dfec92
 
3a18b3b
c492cbb
70da837
 
1dfec92
 
70da837
3a18b3b
544e017
1dfec92
 
 
 
 
ef107e3
1dfec92
 
 
 
 
544e017
1dfec92
6502e85
7a0f405
 
 
 
339c131
7a0f405
 
6502e85
71494c3
3da96bb
499b2c1
 
 
 
71494c3
3a18b3b
 
6502e85
3a18b3b
 
544e017
3da96bb
 
544e017
6502e85
544e017
 
6502e85
544e017
 
 
 
 
 
 
 
3a18b3b
 
 
 
 
 
 
544e017
1dfec92
 
 
 
 
 
81e83c9
1dfec92
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torchaudio
import torch
from transformers import (
    WhisperProcessor,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoModelForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import numpy as np
import util

# Load processor and model
# Registry of selectable ASR back-ends, keyed by the display name passed to
# ``transcribe`` as ``model_id``. Every processor/model pair is loaded eagerly
# here, at import time. Per entry:
#   processor     -- feature extractor + tokenizer for the checkpoint
#   model         -- the loaded checkpoint
#   ctc_model     -- True: decoded from per-frame logits; False: via generate()
#   arabic_script -- True if raw output is Arabic script, else Latin script
models_info = {}

# Stock multilingual Whisper. NOTE(review): "uzbek" looks like a stand-in
# language token for Uyghur — confirm this is intentional.
models_info["OpenAI-Whisper"] = dict(
    processor=WhisperProcessor.from_pretrained("openai/whisper-small", language="uzbek", task="transcribe"),
    model=AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
    ctc_model=False,
    arabic_script=False,
)

# Stock MMS with the Arabic-script Uyghur adapter selected. NOTE(review):
# ignore_mismatched_sizes is presumably required because the adapter's output
# head differs in size from the checkpoint default — confirm.
models_info["Meta-MMS"] = dict(
    processor=AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic'),
    model=AutoModelForCTC.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic', ignore_mismatched_sizes=True),
    ctc_model=True,
    arabic_script=True,
)

# Whisper-small fine-tuned on Uyghur Common Voice; emits Latin script.
models_info["Ixxan-FineTuned-Whisper"] = dict(
    processor=AutoProcessor.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
    model=AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
    ctc_model=False,
    arabic_script=False,
)

# MMS fine-tuned for Latin-script Uyghur; CTC decoding.
models_info["Ixxan-FineTuned-MMS"] = dict(
    processor=Wav2Vec2Processor.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
    model=Wav2Vec2ForCTC.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
    ctc_model=True,
    arabic_script=False,
)

# def transcribe(audio_data, model_id) -> str:
#     if model_id == "Compare All Models":
#         return transcribe_all_models(audio_data)
#     else:
#         return transcribe_with_model(audio_data, model_id)

# def transcribe_all_models(audio_data) -> dict:
#     transcriptions = {}
#     for model_id in models_info.keys():
#         transcriptions[model_id] = transcribe_with_model(audio_data, model_id)
#     return transcriptions

def _load_audio(audio_data):
    """Decode user audio into a mono float32 waveform tensor.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple (microphone
            input) or a ``str`` path to an audio file readable by
            ``torchaudio.load``. Microphone samples are assumed to be a numpy
            array shaped ``(num_samples,)`` or ``(num_samples, channels)`` —
            NOTE(review): confirm against the UI component that produces them.

    Returns:
        ``(waveform, sampling_rate)`` where ``waveform`` is a 1-D float32
        ``torch.Tensor``, or ``None`` when ``audio_data`` is neither a tuple
        nor a string.
    """
    if isinstance(audio_data, tuple):
        # microphone
        sampling_rate, samples = audio_data
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Scale integer PCM into [-1, 1). For int16 this is exactly the
            # division by 32768 used previously; float input is left as-is
            # instead of being crushed towards silence.
            samples = samples / float(np.iinfo(samples.dtype).max + 1)
        # Convert to a tensor up front: torchaudio's Resample only accepts
        # tensors, so raw numpy input crashed whenever resampling was needed.
        waveform = torch.from_numpy(np.ascontiguousarray(samples, dtype=np.float32))
        if waveform.ndim > 1:
            # (num_samples, channels) -> mono by averaging the channels.
            waveform = waveform.mean(dim=-1)
        return waveform, sampling_rate
    if isinstance(audio_data, str):
        # file upload: torchaudio.load returns (channels, num_samples);
        # averaging over channels yields mono (a no-op for 1-channel files).
        waveform, sampling_rate = torchaudio.load(audio_data)
        return waveform.mean(dim=0), sampling_rate
    return None


def transcribe(audio_data, model_id) -> "tuple[str, str | None]":
    """Transcribe user audio with one of the models in ``models_info``.

    Args:
        audio_data: Microphone ``(sampling_rate, samples)`` tuple or an audio
            file path (see ``_load_audio``).
        model_id: Key of ``models_info`` selecting the processor and model.

    Returns:
        ``(transcription_arabic, transcription_latin)`` on success: the
        model's native-script output plus its conversion to the other script
        via ``util``. For an unsupported ``audio_data`` type, returns
        ``(error_message, None)``.

    Raises:
        KeyError: if ``model_id`` is not a key of ``models_info``.
    """
    loaded = _load_audio(audio_data)
    if loaded is None:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data)), None
    waveform, sampling_rate = loaded

    model_info = models_info[model_id]
    model = model_info["model"]
    processor = model_info["processor"]
    target_sr = processor.feature_extractor.sampling_rate

    # Resample to the rate the feature extractor expects.
    if sampling_rate != target_sr:
        waveform = torchaudio.transforms.Resample(sampling_rate, target_sr)(waveform)
        sampling_rate = target_sr

    # waveform is 1-D mono, so the processor sees a single (unbatched) clip.
    inputs = processor(waveform.numpy(), sampling_rate=sampling_rate, return_tensors="pt")

    # Move model and inputs to the GPU when available. model.to() moves the
    # shared model object stored in models_info in place.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        if model_info["ctc_model"]:
            # CTC: pick the highest-scoring token per frame, then decode.
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            # Seq2seq (Whisper-style): autoregressive generation.
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if model_info["arabic_script"]:
        transcription_arabic = transcription
        transcription_latin = util.ug_arab_to_latn(transcription)
    else:  # Latin script output
        transcription_arabic = util.ug_latn_to_arab(transcription)
        transcription_latin = transcription
    print(model_id, transcription_arabic, transcription_latin)
    return transcription_arabic, transcription_latin