File size: 4,686 Bytes
bfb646b
 
 
 
 
 
 
 
 
 
285d9a7
 
 
 
 
 
 
 
 
 
 
22458a0
4ad4f3c
22458a0
f65d62d
 
628cba2
 
 
 
 
 
e8470c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628cba2
4ad4f3c
628cba2
 
 
 
 
 
 
 
 
29cbef3
 
 
 
 
 
 
 
4ad4f3c
 
285d9a7
 
 
 
 
59fd959
44f49c7
 
 
285d9a7
 
 
 
 
 
 
 
 
 
 
bfb646b
 
 
 
 
 
 
 
 
 
 
 
4ad4f3c
 
bfb646b
 
 
 
 
 
 
 
 
 
e8470c7
 
 
 
 
 
27065d5
bfb646b
af57b93
bfb646b
 
 
 
 
 
 
27065d5
bfb646b
 
 
d533cff
 
bfb646b
285d9a7
d533cff
f432d2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import torch
import time
import librosa
import soundfile
import nemo.collections.asr as nemo_asr
import tempfile
import os
import uuid

from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
import torch

# PersistDataset -----
import os
import csv
import gradio as gr
from gradio import inputs, outputs
import huggingface_hub
from huggingface_hub import Repository, hf_hub_download, upload_file
from datetime import datetime

# ---------------------------------------------
# Dataset and token configuration. Change "awacke1" to your own Hugging Face
# user id, and add an HF_TOKEN secret to your repo so the app has write
# permission. This lets you save results to your own Dataset hosted on HF.

DATASET_REPO_URL = "https://huggingface.co/datasets/awacke1/ASRLive.csv"
DATASET_REPO_ID = "awacke1/ASRLive.csv"
DATA_FILENAME = "ASRLive.csv"
# Local directory used both as the download cache and the repo clone target.
DATA_DIRNAME = "data"
DATA_FILE = os.path.join(DATA_DIRNAME, DATA_FILENAME)
HF_TOKEN = os.environ.get("HF_TOKEN")

PersistToDataset = False
#PersistToDataset = True  # uncomment to save inference output to ASRLive.csv dataset

if PersistToDataset:
    try:
        hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=DATA_FILENAME,
            cache_dir=DATA_DIRNAME,
            force_filename=DATA_FILENAME
        )
    except Exception as err:
        # Best effort: a fresh dataset repo may not contain the CSV yet.
        print("file not found:", err)
    repo = Repository(
        local_dir=DATA_DIRNAME, clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
    )
           
def store_message(name: str, message: str) -> str:
    """Append a (name, message, time) row to DATA_FILE, push it, and return the log.

    Relies on the module-level ``DATA_FILE`` and ``repo`` (the latter only
    exists when ``PersistToDataset`` is enabled).

    Args:
        name: Value for the ``name`` column; stripped before writing.
        message: Value for the ``message`` column; stripped before writing.

    Returns:
        Every row currently in DATA_FILE rendered as ``name,message,time``
        lines terminated by ``"\\r\\n"``, or ``""`` when either argument is
        empty (nothing is written or pushed in that case).
    """
    fieldnames = ["name", "message", "time"]
    ret = ""
    if name and message:
        with open(DATA_FILE, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow(
                {"name": name.strip(), "message": message.strip(), "time": str(datetime.now())}
            )
        # Push the updated data directory to the Hub dataset repo.
        repo.push_to_hub()
        with open(DATA_FILE, "r", newline="", encoding="utf-8") as csvfile:
            # No header row is ever written, so name the columns explicitly;
            # otherwise DictReader would swallow the first record as a header.
            reader = csv.DictReader(csvfile, fieldnames=fieldnames, restval="")
            ret = "".join(
                ",".join(row[field] for field in fieldnames) + "\r\n" for row in reader
            )
    return ret

# main -------------------------
# BlenderBot chat model and tokenizer. The model is bound to its own name so
# the NeMo ASR `model` loaded further down does not silently shadow it.
mname = "facebook/blenderbot-400M-distill"
chat_model = BlenderbotForConditionalGeneration.from_pretrained(mname)
tokenizer = BlenderbotTokenizer.from_pretrained(mname)

def take_last_tokens(inputs, note_history, history):
    """Trim the tokenized conversation to its most recent 128 tokens.

    When ``inputs['input_ids']`` has more than 128 columns, the first row of
    ``input_ids`` and ``attention_mask`` is cut to its last 128 entries
    (``inputs`` is updated in place), the first two '</s> <s>'-separated
    segments of ``note_history[0]`` are dropped, and the oldest ``history``
    entry is removed. Otherwise everything is returned unchanged.

    Returns:
        The (inputs, note_history, history) triple after trimming.
    """
    max_tokens = 128  # keep only the trailing 128 tokens
    separator = '</s> <s>'
    if inputs['input_ids'].shape[1] <= max_tokens:
        return inputs, note_history, history

    for key in ('input_ids', 'attention_mask'):
        tail = inputs[key][0][-max_tokens:].tolist()
        inputs[key] = torch.tensor([tail])

    kept_segments = note_history[0].split(separator)[2:]
    note_history = [separator.join(kept_segments)]
    history = history[1:]
    return inputs, note_history, history

def add_note_to_history(note, note_history):
    """Append ``note`` to ``note_history`` (in place) and return the joined history.

    Returns:
        A one-element list holding every note joined with '</s> <s>'.
    """
    note_history.append(note)
    return ['</s> <s>'.join(note_history)]



# Audio is resampled to this rate and written at it before transcription.
SAMPLE_RATE = 16000
ASR_MODEL_NAME = "nvidia/stt_en_conformer_transducer_xlarge"
model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(ASR_MODEL_NAME)
# NOTE(review): called with None, presumably to keep/restore the model's
# default decoding configuration — confirm against the NeMo API.
model.change_decoding_strategy(None)
model.eval()  # inference mode

def process_audio_file(file):
    """Load an audio file as a mono waveform sampled at SAMPLE_RATE Hz.

    Args:
        file: Path (or file-like object) accepted by ``librosa.load``.

    Returns:
        1-D numpy array of float samples at SAMPLE_RATE.
    """
    # Without `sr`, librosa.load resamples to its 22050 Hz default, so the
    # previous load-then-resample path resampled twice. Requesting
    # SAMPLE_RATE directly downmixes to mono and resamples exactly once.
    data, _ = librosa.load(file, sr=SAMPLE_RATE, mono=True)
    return data


def transcribe(audio, state = ""):
    """Transcribe one streamed audio chunk and append it to the running transcript.

    Args:
        audio: Path to the recorded chunk (the Gradio input uses
            ``type='filepath'``), or None when no audio was delivered.
        state: Transcript accumulated so far (Gradio session state); None is
            treated as "".

    Returns:
        ``(state, state)``: the updated transcript, once for the textbox
        output and once to carry forward as session state.
    """
    if state is None:
        state = ""
    # A live interface may be invoked without audio; nothing to transcribe.
    if audio is None:
        return state, state

    audio_data = process_audio_file(audio)
    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, f'audio_{uuid.uuid4()}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)
        transcriptions = model.transcribe([audio_path])
    # Some NeMo versions return a 2-tuple for RNNT models (presumably
    # best and all hypotheses); keep the first element.
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]
    transcription = transcriptions[0]

    if PersistToDataset:
        # Save to the dataset (requires HF_TOKEN with write access).
        ret = store_message(transcription, state)
        state = state + transcription + " " + ret
    else:
        # Separate chunks with a space so consecutive words are not glued together.
        state = state + transcription + " "
    return state, state

# Build the live streaming-microphone UI around `transcribe`, then start it.
demo = gr.Interface(
    fn=transcribe,
    inputs=[gr.Audio(source="microphone", type='filepath', streaming=True), "state"],
    outputs=["textbox", "state"],
    layout="horizontal",
    theme="huggingface",
    title="🗣️ASR-Gradio-Live🧠💾",
    description="Live Automatic Speech Recognition (ASR).",
    allow_flagging='never',
    live=True,
    article=f"Result💾 Dataset: [{DATASET_REPO_URL}]({DATASET_REPO_URL})",
)
demo.launch(debug=True)