Spaces:
Build error
Build error
import re | |
import requests | |
import pyarrow as pa | |
import librosa | |
import torch | |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer | |
from fastapi import FastAPI, File, UploadFile | |
import warnings | |
from starlette.formparsers import MultiPartParser | |
import io | |
import random | |
# Override the class-level ``max_file_size`` on Starlette's multipart parser
# with 200 MiB.
# NOTE(review): in Starlette this attribute looks like the in-memory spool
# threshold rather than a hard upload cap - confirm the intended effect.
_MULTIPART_MAX_FILE_SIZE = 200 * 1024 * 1024
MultiPartParser.max_file_size = _MULTIPART_MAX_FILE_SIZE

# The ASGI application object.
app = FastAPI()

# Wav2Vec2 tokenizer and CTC model, both loaded once at import time from
# directories relative to the working directory.
tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
model = Wav2Vec2ForCTC.from_pretrained("./models/model")
def download_word_list():
    """Download the dwyl ``words_alpha.txt`` list and return it as a set.

    The response body is split on whitespace, so each whitespace-separated
    token becomes one entry of the returned set.

    Returns:
        set[str]: The words contained in the downloaded file.

    Raises:
        requests.RequestException: If the request times out, the connection
            fails, or the server answers with a 4xx/5xx status.
    """
    print("Downloading English word list...")
    url = "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt"
    # Without a timeout a stalled connection would block startup forever.
    response = requests.get(url, timeout=30)
    # Fail loudly on an HTTP error instead of silently treating the error
    # page's text as the dictionary.
    response.raise_for_status()
    words = set(response.text.split())
    print("Word list downloaded.")
    return words


# The word list is fetched once, when the module is imported.
english_words = download_word_list()
def count_spelled_words(text, word_list):
    """Tally the words of *text* that are missing from / present in *word_list*.

    The text is lower-cased and split into runs of word characters; each run
    is looked up in *word_list* with the ``in`` operator.

    Args:
        text: The string to check.
        word_list: A container of known (lower-case) words.

    Returns:
        A ``(incorrect, correct)`` tuple of counts.
    """
    print("Counting spelled words...")
    tokens = re.findall(r'\b\w+\b', text.lower())
    known = 0
    for token in tokens:
        if token in word_list:
            known += 1
    unknown = len(tokens) - known
    print("Spelling check complete.")
    return unknown, known
def apply_spell_check(item, word_list):
    """Attach spelling counts to a single record or to a batch of records.

    Args:
        item: Either a ``dict`` with a ``'transcription'`` string, or a
            batch object that supports ``item['transcription']`` and
            ``append_column`` (presumably a ``pyarrow.Table`` - confirm
            against the caller).
        word_list: A container of known words, passed through to
            ``count_spelled_words``.

    Returns:
        For a dict, the same dict, mutated in place with ``'incorrect_words'``
        and ``'correct_words'`` keys. For a batch, the object returned by
        ``append_column`` with those two columns added.
    """
    print("Applying spell check...")
    if isinstance(item, dict):
        # Single record: annotate it in place.
        incorrect, correct = count_spelled_words(item['transcription'], word_list)
        item['incorrect_words'] = incorrect
        item['correct_words'] = correct
        print("Spell check applied to single item.")
        return item

    # Batch: collect one pair of counts per row.
    incorrect_counts = []
    correct_counts = []
    for text in item['transcription']:
        # Iterating a pyarrow column yields Scalar objects, which have no
        # ``.lower()``; unwrap them to plain Python values first.
        if hasattr(text, 'as_py'):
            text = text.as_py()
        # A null row is counted as containing no words.
        incorrect, correct = count_spelled_words(text or "", word_list)
        incorrect_counts.append(incorrect)
        correct_counts.append(correct)

    # An explicit type keeps the columns integer-typed even for an empty
    # batch, where type inference would otherwise produce a null-typed array.
    item = item.append_column('incorrect_words', pa.array(incorrect_counts, type=pa.int64()))
    item = item.append_column('correct_words', pa.array(correct_counts, type=pa.int64()))
    print("Spell check applied to batch of items.")
    return item
# FastAPI routes | |
async def root():
    """Return the API's greeting string.

    NOTE(review): no route decorator is visible on this handler here -
    confirm how it is registered with the app.
    """
    greeting = "Welcome to the pronunciation scoring API!"
    return greeting
async def rnc(number):
    """Echo *number* back to the caller in a one-entry mapping.

    Args:
        number: The value to echo.

    Returns:
        dict: ``{"your value": number}``.
    """
    # Build a dict, not a set literal: a set raises TypeError for unhashable
    # input and has no key/value structure or stable order when serialized.
    return {"your value": number}
async def get_rnc():
    """Return a random integer between 0 and 10, both ends included."""
    lowest, highest = 0, 10
    return random.randint(lowest, highest)
# Sample rate the audio is resampled to before it reaches the model.
# NOTE(review): Wav2Vec2 checkpoints are normally trained on 16 kHz speech -
# confirm this matches the checkpoint stored in ./models/model.
_WAV2VEC2_SAMPLE_RATE = 16000


def _pronunciation_score(correct, incorrect):
    """Return the percentage of recognised words, or 0.0 when there are none."""
    total = correct + incorrect
    if total == 0:
        # Nothing was transcribed (e.g. silence); avoid dividing by zero.
        return 0.0
    return round(correct / total * 100, 2)


async def unscripted_root(audio_file: UploadFile):
    """Transcribe an uploaded recording and score it by dictionary hits.

    The upload is decoded with librosa, transcribed with the module-level
    Wav2Vec2 ``tokenizer``/``model``, and the lower-cased transcription is
    checked word by word against ``english_words``. The score is the
    percentage of transcribed words found in that list.

    Args:
        audio_file: The uploaded audio file.

    Returns:
        dict: ``{"transcription": <str>, "pronunciation_score": <float>}``;
        the score is ``0.0`` when no words were transcribed.
    """
    print("Pronunciation Scoring")
    # Read the whole upload into memory.
    contents = await audio_file.read()
    # Log only the size: dumping the raw bytes would flood stdout.
    print("Contents:", len(contents), "bytes")

    # librosa resamples to 22050 Hz unless told otherwise, so request the
    # model's rate explicitly.
    audio, _ = librosa.load(io.BytesIO(contents), sr=_WAV2VEC2_SAMPLE_RATE)

    print("Tokenizing audio...")
    input_values = tokenizer(audio, return_tensors="pt").input_values

    print("Performing inference with Wav2Vec2 model...")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_values).logits

    print("Getting predictions...")
    prediction = torch.argmax(logits, dim=-1)

    print("Decoding predictions...")
    transcription = tokenizer.batch_decode(prediction)[0].lower()
    print("Decoded transcription:", transcription)

    incorrect, correct = count_spelled_words(transcription, english_words)
    print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)

    score = _pronunciation_score(correct, incorrect)
    print("Pronunciation score for", transcription, ":", score)
    print("Pronunciation scoring process complete.")
    return {
        "transcription": transcription,
        "pronunciation_score": score,
    }