KeenWoo commited on
Commit
f5a8bba
·
verified ·
1 Parent(s): 4146de9

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +186 -0
utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# utils.py
# Contains shared utility functions for text processing, audio transcription,
# date/time handling, and image analysis that can be used by any assessment module.

import os
import re
import tempfile
import time
from datetime import datetime

import cv2
import nltk
import numpy as np
import pytz
import whisper
from scipy.io.wavfile import write as write_wav
from shapely.geometry import Polygon

18
# --- NLTK Setup ---
# Keep NLTK data next to this module so lookups do not depend on a
# writable per-user NLTK directory on the host.
LOCAL_NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), 'nltk_data')
if LOCAL_NLTK_DATA_PATH not in nltk.data.path:
    nltk.data.path.append(LOCAL_NLTK_DATA_PATH)
22
+
23
def download_nltk_data_if_needed(resource_name, download_name):
    """Ensure an NLTK resource is available, downloading it on first use.

    Args:
        resource_name: path-style name for nltk.data.find, e.g. 'tokenizers/punkt'.
        download_name: package name for nltk.download, e.g. 'punkt'.
    """
    try:
        nltk.data.find(resource_name)
    except LookupError:
        print(f"Downloading NLTK resource '{download_name}'...")
        # exist_ok avoids the check-then-create race of the original
        # os.path.exists / os.makedirs pair (two processes starting at
        # once could both pass the exists() check).
        os.makedirs(LOCAL_NLTK_DATA_PATH, exist_ok=True)
        nltk.download(download_name, download_dir=LOCAL_NLTK_DATA_PATH)
        print("Download complete.")
33
+
34
# Download necessary NLTK packages (tokenizer + POS tagger) at import time
# so scoring functions never hit a missing-resource LookupError later.
for _resource, _package in (
    ('tokenizers/punkt', 'punkt'),
    ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
):
    download_nltk_data_if_needed(_resource, _package)
37
+
38
+
39
# --- Whisper Model Loading ---
# Loaded once at import time; every transcribe() call reuses this instance.
print("Loading Whisper transcription model...")
model = whisper.load_model("small")
print("Whisper model loaded.")
43
+
44
def transcribe(audio):
    """Transcribe audio with the module-level Whisper model.

    Args:
        audio: a (sample_rate, samples) tuple, or None.

    Returns:
        The transcribed English text, or "" when no audio was provided.
    """
    if audio is None:
        return ""
    sample_rate, samples = audio
    # Use a unique temporary file instead of the original fixed
    # "/tmp/temp_audio.wav": a fixed path is not portable to Windows and
    # two concurrent transcriptions would clobber each other's audio.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        tmp.close()
        write_wav(tmp.name, sample_rate, samples)
        result = model.transcribe(tmp.name, language="en")
        return result["text"]
    finally:
        # Always clean up the temp file, even if transcription raises.
        os.remove(tmp.name)
53
+
54
+
55
# --- Date & Time Utilities ---
TARGET_TIMEZONE = pytz.timezone("America/New_York")
# NOTE(review): `now` is a snapshot taken once at import time, not the
# current time at call sites — confirm downstream consumers expect that.
now_utc = datetime.now(tz=pytz.utc)
now = now_utc.astimezone(TARGET_TIMEZONE)
59
+
60
def get_season(month):
    """Return the Northern Hemisphere season name for a month number (1-12).

    Months 3-5 are "spring", 6-8 "summer", 9-11 "fall"; anything else
    (including 12, 1, 2) falls through to "winter".
    """
    for low, high, season in ((3, 5, "spring"), (6, 8, "summer"), (9, 11, "fall")):
        if low <= month <= high:
            return season
    return "winter"
66
+
67
+
68
# --- Text Normalization and Cleaning Dictionaries & Functions ---
# Spoken cardinal numbers -> digit strings, consumed by
# normalize_numeric_words when cleaning transcribed answers.
# Key order is preserved from the original table because callers iterate it.
WORD_TO_DIGIT = {
    # 0-10
    'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5',
    'six': '6', 'seven': '7', 'eight': '8', 'nine': '9', 'ten': '10',
    # 11-20
    'eleven': '11', 'twelve': '12', 'thirteen': '13', 'fourteen': '14',
    'fifteen': '15', 'sixteen': '16', 'seventeen': '17', 'eighteen': '18',
    'nineteen': '19', 'twenty': '20',
    # selected larger values used by the assessments
    'thirty': '30', 'thirty one': '31',
    'ninety three': '93', 'eighty six': '86', 'seventy nine': '79',
    'seventy two': '72', 'sixty five': '65',
}
78
+
79
# Ordinal words (hyphenated, spaced, and numeric-suffix forms) -> digit
# strings; consumed by normalize_date_answer.
# Key order is preserved from the original table because callers iterate it.
ORDINAL_TO_DIGIT = {
    # simple ordinals
    'first': '1', 'second': '2', 'third': '3', 'fourth': '4', 'fifth': '5',
    'sixth': '6', 'seventh': '7', 'eighth': '8', 'ninth': '9', 'tenth': '10',
    'eleventh': '11', 'twelfth': '12', 'thirteenth': '13', 'fourteenth': '14',
    'fifteenth': '15', 'sixteenth': '16', 'seventeenth': '17', 'eighteenth': '18',
    'nineteenth': '19', 'twentieth': '20', 'thirtieth': '30',
    # hyphenated compounds
    'twenty-first': '21', 'twenty-second': '22', 'twenty-third': '23',
    'twenty-fourth': '24', 'twenty-fifth': '25', 'twenty-sixth': '26',
    'twenty-seventh': '27', 'twenty-eighth': '28', 'twenty-ninth': '29',
    'thirty-first': '31',
    # space-separated compounds (how Whisper often transcribes them)
    'twenty first': '21', 'twenty second': '22', 'twenty third': '23',
    'twenty fourth': '24', 'twenty fifth': '25', 'twenty sixth': '26',
    'twenty seventh': '27', 'twenty eighth': '28', 'twenty ninth': '29',
    'thirty first': '31',
    # numeric suffix forms
    '1st': '1', '2nd': '2', '3rd': '3', '4th': '4',
    '5th': '5', '6th': '6', '7th': '7', '8th': '8', '9th': '9', '10th': '10',
    '11th': '11', '12th': '12', '13th': '13', '14th': '14', '15th': '15',
    '16th': '16', '17th': '17', '18th': '18', '19th': '19', '20th': '20',
    '21st': '21', '22nd': '22', '23rd': '23', '24th': '24', '25th': '25',
    '26th': '26', '27th': '27', '28th': '28', '29th': '29', '30th': '30', '31st': '31',
}
98
+
99
def clean_text_answer(text: str) -> str:
    """Normalize free-text input for scoring.

    Lowercases, strips punctuation (anything outside word chars and
    whitespace), and collapses runs of whitespace to single spaces.
    Falsy input yields "".
    """
    if not text:
        return ""
    lowered = text.lower()
    depunctuated = re.sub(r'[^\w\s]', '', lowered)
    return " ".join(depunctuated.split())
106
+
107
def normalize_date_answer(text: str) -> str:
    """Convert a spoken/typed date answer into a bare digit string.

    Strips a leading "the", rewrites the first ordinal phrase found
    ("twenty-first", "21st", ...) into digits, then removes every
    remaining non-digit character. Falsy input yields "".
    """
    if not text:
        return ""
    clean_text = text.lower().strip()
    if clean_text.startswith("the "):
        clean_text = clean_text[4:]
    # Try longer phrases first: the original dict-order scan matched the
    # substring 'first' inside 'twenty-first' and returned '1' instead
    # of '21'. Longest-key-first guarantees the compound wins.
    for word in sorted(ORDINAL_TO_DIGIT, key=len, reverse=True):
        if word in clean_text:
            clean_text = clean_text.replace(word, ORDINAL_TO_DIGIT[word])
            break
    return re.sub(r'\D', '', clean_text)
118
+
119
def clean_numeric_answer(text: str) -> str:
    """Return only the digit characters of *text* ("" for falsy input)."""
    if not text:
        return ""
    return re.sub(r'\D', '', text)
122
+
123
def normalize_numeric_words(text: str) -> str:
    """Replace spoken number words in *text* with digits ("thirty one" -> "31").

    Matching is whole-word (regex \\b boundaries) and case-insensitive via
    an initial lower(). Falsy input yields "".
    """
    if not text:
        return ""
    text = text.lower().strip()
    # Longest phrases first: the original dict-order pass rewrote
    # "thirty one" word by word into "30 1" before the two-word entry
    # 'thirty one' could ever match.
    for word in sorted(WORD_TO_DIGIT, key=len, reverse=True):
        text = re.sub(r'\b' + re.escape(word) + r'\b', WORD_TO_DIGIT[word], text)
    return text
130
+
131
+
132
# --- Generic Scoring Utilities ---
def score_keyword_match(expected, user_input):
    """Return 1 if any '|'-separated expected answer appears in the user's answer.

    Both sides are normalized with clean_text_answer before the substring
    check; missing/empty arguments score 0.
    """
    if not expected or not user_input:
        return 0
    haystack = clean_text_answer(user_input)
    candidates = (clean_text_answer(option) for option in expected.split('|'))
    return 1 if any(candidate in haystack for candidate in candidates) else 0
144
+
145
def score_sentence_structure(raw_user_input):
    """Award 1 point when the raw (un-cleaned) text contains both a noun and a verb.

    Tokenizes and POS-tags with NLTK; needs at least two tokens. Any NLTK
    failure is logged and scored as 0 rather than raised.
    """
    try:
        tokens = nltk.word_tokenize(raw_user_input or "")
        if len(tokens) < 2:
            return 0
        tags = [tag for _, tag in nltk.pos_tag(tokens)]
        noun_found = any(tag.startswith('NN') for tag in tags)
        verb_found = any(tag.startswith('VB') for tag in tags)
        return 1 if noun_found and verb_found else 0
    except Exception as exc:
        print(f"[NLTK ERROR] Failed to parse sentence: {exc}")
        return 0
157
+
158
def score_drawing(image_path, expected_sides):
    """Score a drawing by polygon-approximating its smallest significant contour.

    Thresholds the image (near-white background inverted to foreground),
    finds contours with area > 500, and counts the sides of the
    smallest one via approxPolyDP (4% of perimeter epsilon).

    Returns:
        (score, sides): score is 1 when the detected side count equals
        expected_sides, else 0; both are 0 when the path is missing,
        fewer than 3 significant contours exist, or OpenCV fails.
    """
    if not image_path or not os.path.exists(image_path):
        return 0, 0
    try:
        image = cv2.imread(image_path)
        grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(grayscale, 240, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        big_enough = [c for c in contours if cv2.contourArea(c) > 500]
        if len(big_enough) < 3:
            return 0, 0  # Not enough shapes to form a valid intersection

        # min() keeps the first contour with the smallest area — same
        # tie-breaking as the original strict "<" scan.
        smallest = min(big_enough, key=cv2.contourArea)
        perimeter = cv2.arcLength(smallest, True)
        approx = cv2.approxPolyDP(smallest, 0.04 * perimeter, True)
        detected_sides = len(approx)

        return (1 if detected_sides == expected_sides else 0), detected_sides
    except Exception as exc:
        print(f"[OpenCV ERROR] Failed to process image: {exc}")
        return 0, 0