Kartikeyssj2 commited on
Commit
10dd4bf
1 Parent(s): 2f6faa5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +155 -138
main.py CHANGED
@@ -1,173 +1,190 @@
1
- import re
2
- import requests
3
- import pyarrow as pa
4
- import librosa
5
- import torch
6
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
7
- from fastapi import FastAPI, File, UploadFile
8
- import warnings
9
- from starlette.formparsers import MultiPartParser
10
- import io
11
- import random
12
- import tempfile
13
- import os
14
- import numba
15
- import soundfile as sf
16
- import asyncio
17
-
18
- MultiPartParser.max_file_size = 200 * 1024 * 1024
19
-
20
- # Initialize FastAPI app
21
- app = FastAPI()
22
-
23
- # Load Wav2Vec2 tokenizer and model
24
- tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
25
- model = Wav2Vec2ForCTC.from_pretrained("./models/model")
26
-
27
-
28
- # Function to download English word list
29
- def download_word_list():
30
- print("Downloading English word list...")
31
- url = "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt"
32
- response = requests.get(url)
33
- words = set(response.text.split())
34
- print("Word list downloaded.")
35
- return words
36
-
37
- english_words = download_word_list()
38
-
39
- # Function to count correctly spelled words in text
40
- def count_spelled_words(text, word_list):
41
- print("Counting spelled words...")
42
- # Split the text into words
43
- words = re.findall(r'\b\w+\b', text.lower())
44
-
45
- correct = sum(1 for word in words if word in word_list)
46
- incorrect = len(words) - correct
47
-
48
- print("Spelling check complete.")
49
- return incorrect, correct
50
-
51
- # Function to apply spell check to an item (assuming it's a dictionary)
52
- def apply_spell_check(item, word_list):
53
- print("Applying spell check...")
54
- if isinstance(item, dict):
55
- # This is a single item
56
- text = item['transcription']
57
- incorrect, correct = count_spelled_words(text, word_list)
58
- item['incorrect_words'] = incorrect
59
- item['correct_words'] = correct
60
- print("Spell check applied to single item.")
61
- return item
62
- else:
63
- # This is likely a batch
64
- texts = item['transcription']
65
- results = [count_spelled_words(text, word_list) for text in texts]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- incorrect_counts, correct_counts = zip(*results)
68
 
69
- item = item.append_column('incorrect_words', pa.array(incorrect_counts))
70
- item = item.append_column('correct_words', pa.array(correct_counts))
71
 
72
- print("Spell check applied to batch of items.")
73
- return item
74
 
75
- # FastAPI routes
76
- @app.get('/')
77
- async def root():
78
- return "Welcome to the pronunciation scoring API!"
79
 
80
- @app.post('/check_post')
81
- async def rnc(number):
82
- return {
83
- "your value:" , number
84
- }
85
 
86
- @app.get('/check_get')
87
- async def get_rnc():
88
- return random.randint(0 , 10)
89
 
90
 
91
- @app.post('/fluency_score')
92
- async def fluency_scoring(file: UploadFile = File(...)):
93
- audio_array, sample_rate = librosa.load(file.file, sr=16000)
94
- print(audio_array)
95
- return audio_array[:5]
96
 
97
 
98
- @app.post('/pronunciation_score')
99
- async def pronunciation_scoring(file: UploadFile = File(...)):
100
- print("loading the file")
101
- url = "https://speech-processing-6.onrender.com/process_audio"
102
- files = {'file': await file.read()}
103
 
104
- print("file loaded")
105
 
106
- # print(files)
107
 
108
- print("making a POST request on speech processor")
109
 
110
- # Make the POST request
111
- response = requests.post(url, files=files)
112
 
113
- audio = response.json().get('audio_array')
114
 
115
- print("audio:" , audio[:5])
116
 
117
 
118
 
119
- print("length of the audio array:" , len(audio))
120
 
121
- print("*" * 100)
122
 
123
- # Tokenization
124
- print("Tokenizing audio...")
125
- input_values = tokenizer(
126
- audio,
127
- return_tensors="pt",
128
- padding="max_length",
129
- max_length= 386380,
130
- truncation=True
131
- ).input_values
132
 
133
- print(input_values.shape)
134
 
135
- print("Tokenization complete. Shape of input_values:", input_values.shape)
136
 
137
- return "tokenization successful"
138
 
139
- # Perform inference
140
- print("Performing inference with Wav2Vec2 model...")
141
 
142
- logits = model(input_values).logits
143
 
144
- print("Inference complete. Shape of logits:", logits.shape)
145
 
146
- # Get predictions
147
- print("Getting predictions...")
148
- prediction = torch.argmax(logits, dim=-1)
149
- print("Prediction shape:", prediction.shape)
150
 
151
- # Decode predictions
152
- print("Decoding predictions...")
153
- transcription = tokenizer.batch_decode(prediction)[0]
154
 
155
- # Convert transcription to lowercase
156
- transcription = transcription.lower()
157
 
158
- print("Decoded transcription:", transcription)
159
 
160
- incorrect, correct = count_spelled_words(transcription, english_words)
161
- print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)
162
 
163
- # Calculate pronunciation score
164
- fraction = correct / (incorrect + correct)
165
- score = round(fraction * 100, 2)
166
- print("Pronunciation score for", transcription, ":", score)
167
 
168
- print("Pronunciation scoring process complete.")
169
 
170
- return {
171
- "transcription": transcription,
172
- "pronunciation_score": score
173
- }
 
1
+ import soundfile as sf
2
+ import numpy as np
3
+
4
@app.post('/fluency_score')
async def fluency_scoring(file: UploadFile = File(...)):
    """Decode an uploaded audio file and return its first five samples.

    Reads the upload with soundfile, down-mixes multi-channel audio to
    mono, resamples to 16 kHz when needed, and returns the first five
    float samples as a JSON-serializable list (a smoke-test response,
    not an actual fluency score yet).

    NOTE(review): `app`, `UploadFile`, `File`, and `librosa` are only
    defined in the commented-out section at the bottom of this file —
    those imports (and `app = FastAPI()`) must be restored for this
    module to import at all. Confirm and uncomment.
    """
    # SoundFile.read() yields float32 samples; shape is (frames,) for
    # mono or (frames, channels) for multi-channel input.
    with sf.SoundFile(file.file, 'r') as sound_file:
        audio_array = sound_file.read(dtype="float32")
        sample_rate = sound_file.samplerate

    # Down-mix to mono: librosa.resample would otherwise operate on a
    # 2-D (frames, channels) array along the wrong axis.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    if sample_rate != 16000:
        # librosa >= 0.10 requires orig_sr/target_sr as keyword
        # arguments; the old positional form raises TypeError.
        audio_array = librosa.resample(
            audio_array, orig_sr=sample_rate, target_sr=16000
        )

    print(audio_array)
    # .tolist() converts numpy float32 values into plain floats that
    # FastAPI can serialize to JSON.
    return audio_array[:5].tolist()
16
+
17
+
18
+ # import re
19
+ # import requests
20
+ # import pyarrow as pa
21
+ # import librosa
22
+ # import torch
23
+ # from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
24
+ # from fastapi import FastAPI, File, UploadFile
25
+ # import warnings
26
+ # from starlette.formparsers import MultiPartParser
27
+ # import io
28
+ # import random
29
+ # import tempfile
30
+ # import os
31
+ # import numba
32
+ # import soundfile as sf
33
+ # import asyncio
34
+
35
+ # MultiPartParser.max_file_size = 200 * 1024 * 1024
36
+
37
+ # # Initialize FastAPI app
38
+ # app = FastAPI()
39
+
40
+ # # Load Wav2Vec2 tokenizer and model
41
+ # tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
42
+ # model = Wav2Vec2ForCTC.from_pretrained("./models/model")
43
+
44
+
45
+ # # Function to download English word list
46
+ # def download_word_list():
47
+ # print("Downloading English word list...")
48
+ # url = "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt"
49
+ # response = requests.get(url)
50
+ # words = set(response.text.split())
51
+ # print("Word list downloaded.")
52
+ # return words
53
+
54
+ # english_words = download_word_list()
55
+
56
+ # # Function to count correctly spelled words in text
57
+ # def count_spelled_words(text, word_list):
58
+ # print("Counting spelled words...")
59
+ # # Split the text into words
60
+ # words = re.findall(r'\b\w+\b', text.lower())
61
+
62
+ # correct = sum(1 for word in words if word in word_list)
63
+ # incorrect = len(words) - correct
64
+
65
+ # print("Spelling check complete.")
66
+ # return incorrect, correct
67
+
68
+ # # Function to apply spell check to an item (assuming it's a dictionary)
69
+ # def apply_spell_check(item, word_list):
70
+ # print("Applying spell check...")
71
+ # if isinstance(item, dict):
72
+ # # This is a single item
73
+ # text = item['transcription']
74
+ # incorrect, correct = count_spelled_words(text, word_list)
75
+ # item['incorrect_words'] = incorrect
76
+ # item['correct_words'] = correct
77
+ # print("Spell check applied to single item.")
78
+ # return item
79
+ # else:
80
+ # # This is likely a batch
81
+ # texts = item['transcription']
82
+ # results = [count_spelled_words(text, word_list) for text in texts]
83
 
84
+ # incorrect_counts, correct_counts = zip(*results)
85
 
86
+ # item = item.append_column('incorrect_words', pa.array(incorrect_counts))
87
+ # item = item.append_column('correct_words', pa.array(correct_counts))
88
 
89
+ # print("Spell check applied to batch of items.")
90
+ # return item
91
 
92
+ # # FastAPI routes
93
+ # @app.get('/')
94
+ # async def root():
95
+ # return "Welcome to the pronunciation scoring API!"
96
 
97
+ # @app.post('/check_post')
98
+ # async def rnc(number):
99
+ # return {
100
+ # "your value:" , number
101
+ # }
102
 
103
+ # @app.get('/check_get')
104
+ # async def get_rnc():
105
+ # return random.randint(0 , 10)
106
 
107
 
108
+ # @app.post('/fluency_score')
109
+ # async def fluency_scoring(file: UploadFile = File(...)):
110
+ # audio_array, sample_rate = librosa.load(file.file, sr=16000)
111
+ # print(audio_array)
112
+ # return audio_array[:5]
113
 
114
 
115
+ # @app.post('/pronunciation_score')
116
+ # async def pronunciation_scoring(file: UploadFile = File(...)):
117
+ # print("loading the file")
118
+ # url = "https://speech-processing-6.onrender.com/process_audio"
119
+ # files = {'file': await file.read()}
120
 
121
+ # print("file loaded")
122
 
123
+ # # print(files)
124
 
125
+ # print("making a POST request on speech processor")
126
 
127
+ # # Make the POST request
128
+ # response = requests.post(url, files=files)
129
 
130
+ # audio = response.json().get('audio_array')
131
 
132
+ # print("audio:" , audio[:5])
133
 
134
 
135
 
136
+ # print("length of the audio array:" , len(audio))
137
 
138
+ # print("*" * 100)
139
 
140
+ # # Tokenization
141
+ # print("Tokenizing audio...")
142
+ # input_values = tokenizer(
143
+ # audio,
144
+ # return_tensors="pt",
145
+ # padding="max_length",
146
+ # max_length= 386380,
147
+ # truncation=True
148
+ # ).input_values
149
 
150
+ # print(input_values.shape)
151
 
152
+ # print("Tokenization complete. Shape of input_values:", input_values.shape)
153
 
154
+ # return "tokenization successful"
155
 
156
+ # # Perform inference
157
+ # print("Performing inference with Wav2Vec2 model...")
158
 
159
+ # logits = model(input_values).logits
160
 
161
+ # print("Inference complete. Shape of logits:", logits.shape)
162
 
163
+ # # Get predictions
164
+ # print("Getting predictions...")
165
+ # prediction = torch.argmax(logits, dim=-1)
166
+ # print("Prediction shape:", prediction.shape)
167
 
168
+ # # Decode predictions
169
+ # print("Decoding predictions...")
170
+ # transcription = tokenizer.batch_decode(prediction)[0]
171
 
172
+ # # Convert transcription to lowercase
173
+ # transcription = transcription.lower()
174
 
175
+ # print("Decoded transcription:", transcription)
176
 
177
+ # incorrect, correct = count_spelled_words(transcription, english_words)
178
+ # print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)
179
 
180
+ # # Calculate pronunciation score
181
+ # fraction = correct / (incorrect + correct)
182
+ # score = round(fraction * 100, 2)
183
+ # print("Pronunciation score for", transcription, ":", score)
184
 
185
+ # print("Pronunciation scoring process complete.")
186
 
187
+ # return {
188
+ # "transcription": transcription,
189
+ # "pronunciation_score": score
190
+ # }