dmedhi committed on
Commit
0e6eb34
1 Parent(s): f7b8f19

add app and dependencies

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +41 -0
  3. requirements.txt +107 -0
  4. whisper.py +73 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # audio files
2
+ *.wav
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from whisper import transcribe_audio
4
+
5
+
6
def transcribe(audio_file):
    """Thin UI-layer wrapper: delegate transcription to whisper.transcribe_audio."""
    result = transcribe_audio(audio_file)
    return result
8
+
9
+
10
def main():
    """Streamlit entry point: upload an audio file and display its transcription.

    Layout: two columns — the left holds the uploader/player and a Run button,
    the right shows the transcribed text (or a hint until a file is run).
    """
    st.set_page_config(page_title="Transcriber", page_icon="💬", layout="wide")
    st.markdown(
        # fix: stray ';' inside the tag made the HTML attribute list malformed
        """<h1 align="center">Transcriber</h1>""",
        unsafe_allow_html=True,
    )
    cols = st.columns(2)

    with cols[0]:
        with st.container(border=True, height=300):
            audio_file = st.file_uploader(
                label="Upload your audio",
                type=["wav", "mp3"],
                key="audio_file_uploader",
            )
            if audio_file:
                st.audio(audio_file)
            sub_btn = st.button("Run", key="sub_btn")
    with cols[1]:
        with st.container(border=True, height=400):
            if sub_btn and audio_file:
                st.text_area(
                    label="Transcribed text",
                    # fix: use getvalue() instead of read() — st.audio() above may
                    # have consumed the UploadedFile buffer, leaving read() to
                    # return b"". getvalue() is independent of the file pointer.
                    value=transcribe(audio_file.getvalue())["text"],
                    height=350,
                )
            else:
                st.info("Upload audio file", icon="💡")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.31.0
2
+ aiofiles==23.2.1
3
+ altair==5.3.0
4
+ annotated-types==0.7.0
5
+ anyio==4.4.0
6
+ attrs==23.2.0
7
+ blinker==1.8.2
8
+ cachetools==5.3.3
9
+ certifi==2024.6.2
10
+ charset-normalizer==3.3.2
11
+ click==8.1.7
12
+ contourpy==1.2.1
13
+ cycler==0.12.1
14
+ dnspython==2.6.1
15
+ email_validator==2.1.1
16
+ exceptiongroup==1.2.1
17
+ fastapi==0.111.0
18
+ fastapi-cli==0.0.4
19
+ ffmpy==0.3.2
20
+ filelock==3.15.1
21
+ fonttools==4.53.0
22
+ fsspec==2024.6.0
23
+ gitdb==4.0.11
24
+ GitPython==3.1.43
25
+ h11==0.14.0
26
+ httpcore==1.0.5
27
+ httptools==0.6.1
28
+ httpx==0.27.0
29
+ huggingface-hub==0.23.3
30
+ idna==3.7
31
+ importlib_resources==6.4.0
32
+ Jinja2==3.1.4
33
+ jsonschema==4.22.0
34
+ jsonschema-specifications==2023.12.1
35
+ kiwisolver==1.4.5
36
+ markdown-it-py==3.0.0
37
+ MarkupSafe==2.1.5
38
+ matplotlib==3.9.0
39
+ mdurl==0.1.2
40
+ mpmath==1.3.0
41
+ networkx==3.3
42
+ numpy==1.26.4
43
+ nvidia-cublas-cu12==12.1.3.1
44
+ nvidia-cuda-cupti-cu12==12.1.105
45
+ nvidia-cuda-nvrtc-cu12==12.1.105
46
+ nvidia-cuda-runtime-cu12==12.1.105
47
+ nvidia-cudnn-cu12==8.9.2.26
48
+ nvidia-cufft-cu12==11.0.2.54
49
+ nvidia-curand-cu12==10.3.2.106
50
+ nvidia-cusolver-cu12==11.4.5.107
51
+ nvidia-cusparse-cu12==12.1.0.106
52
+ nvidia-nccl-cu12==2.20.5
53
+ nvidia-nvjitlink-cu12==12.5.40
54
+ nvidia-nvtx-cu12==12.1.105
55
+ orjson==3.10.5
56
+ packaging==24.1
57
+ pandas==2.2.2
58
+ pillow==10.3.0
59
+ protobuf==4.25.3
60
+ psutil==5.9.8
61
+ pyarrow==16.1.0
62
+ pydantic==2.7.4
63
+ pydantic_core==2.18.4
64
+ pydeck==0.9.1
65
+ pydub==0.25.1
66
+ Pygments==2.18.0
67
+ pyparsing==3.1.2
68
+ python-dateutil==2.9.0.post0
69
+ python-dotenv==1.0.1
70
+ python-multipart==0.0.9
71
+ pytz==2024.1
72
+ PyYAML==6.0.1
73
+ referencing==0.35.1
74
+ regex==2024.5.15
75
+ requests==2.32.3
76
+ rich==13.7.1
77
+ rpds-py==0.18.1
78
+ ruff==0.4.8
79
+ safetensors==0.4.3
80
+ semantic-version==2.10.0
81
+ shellingham==1.5.4
82
+ six==1.16.0
83
+ smmap==5.0.1
84
+ sniffio==1.3.1
85
+ starlette==0.37.2
86
+ streamlit==1.35.0
87
+ sympy==1.12.1
88
+ tenacity==8.3.0
89
+ tokenizers==0.19.1
90
+ toml==0.10.2
91
+ tomlkit==0.12.0
92
+ toolz==0.12.1
93
+ torch==2.3.1
94
+ tornado==6.4.1
95
+ tqdm==4.66.4
96
+ transformers==4.41.2
97
+ triton==2.3.1
98
+ typer==0.12.3
99
+ typing_extensions==4.12.2
100
+ tzdata==2024.1
101
+ ujson==5.10.0
102
+ urllib3==2.2.1
103
+ uvicorn==0.30.1
104
+ uvloop==0.19.0
105
+ watchdog==4.0.1
106
+ watchfiles==0.22.0
107
+ websockets==11.0.3
whisper.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
3
+
4
+
5
class Whisper:
    """Whisper - audio transcriber class.

    Wraps a Hugging Face speech-seq2seq checkpoint (default: openai/whisper-base)
    plus its processor, and exposes an automatic-speech-recognition pipeline.
    """

    # Pick GPU + fp16 when CUDA is available, else CPU + fp32 (class-level,
    # evaluated once at import time).
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    def __init__(self, model_id: str = "openai/whisper-base") -> None:
        """Download/load the model and processor for *model_id* onto `device`.

        Args:
            model_id (str): Hugging Face model id or local path.
        """
        self.model_id = model_id
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        self.model.to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_id)

    @property
    def model_name(self) -> str:
        """The model id this instance was constructed from."""
        return self.model_id

    def save(self, save_dir: str):
        """
        Saves the model and processor to the specified directory.

        Args:
            save_dir (str): The directory where the model and processor will be saved.
        """
        self.model.save_pretrained(f"{save_dir}/model")
        self.processor.save_pretrained(f"{save_dir}/processor")

    def load(self, load_dir: str):
        """
        Load the model and processor from the specified directory.

        Args:
            load_dir (str): The directory from which to load the model and processor.
        """
        # fix: pass torch_dtype/low_cpu_mem_usage as in __init__ — previously a
        # reloaded model silently came back in fp32 even on a fp16/GPU instance,
        # which breaks the dtype the pipeline() method advertises.
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            f"{load_dir}/model",
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
        )
        self.processor = AutoProcessor.from_pretrained(f"{load_dir}/processor")

        self.model.to(self.device)

    def pipeline(self):
        """Build and return an ASR pipeline over this model/processor.

        Returns chunked (15 s) batched inference with timestamps enabled; each
        call constructs a fresh pipeline object.
        """
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=15,
            batch_size=16,
            return_timestamps=True,
            torch_dtype=self.torch_dtype,
            device=self.device,
        )

        return pipe
66
+
67
+
68
def transcribe_audio(file):
    """Transcribe an audio input with a lazily-built, cached Whisper pipeline.

    Args:
        file: audio input accepted by the HF ASR pipeline (e.g. raw file bytes
            or a path) — assumed decodable by the pipeline; TODO confirm formats.

    Returns:
        dict: pipeline output containing at least a "text" key (timestamps
        are enabled by Whisper.pipeline()).
    """
    # fix: the original rebuilt Whisper() — a full checkpoint load — plus a new
    # pipeline on EVERY call. Cache the pipeline on the function object so the
    # expensive setup happens once per process; the public interface is unchanged.
    pipe = getattr(transcribe_audio, "_pipe", None)
    if pipe is None:
        pipe = Whisper().pipeline()
        transcribe_audio._pipe = pipe

    result = pipe(file)
    return result