IliaLarchenko committed on
Commit
e12b285
1 Parent(s): 81629c5

Huge refactoring

Files changed (10)
  1. api/audio.py +107 -37
  2. api/llm.py +59 -80
  3. app.py +40 -30
  4. tests/analysis.py +93 -60
  5. tests/candidate.py +25 -6
  6. tests/grader.py +82 -30
  7. tests/test_e2e.py +14 -4
  8. tests/test_models.py +20 -7
  9. ui/coding.py +1 -0
  10. utils/config.py +20 -10
api/audio.py CHANGED
@@ -7,6 +7,7 @@ import requests
 from openai import OpenAI
 
 from utils.errors import APIError, AudioConversionError
+from typing import List, Dict, Optional, Generator, Tuple
 
 
 class STTManager:
@@ -20,7 +21,13 @@ class STTManager:
         self.status = self.test_stt()
         self.streaming = self.test_streaming()
 
-    def numpy_audio_to_bytes(self, audio_data):
+    def numpy_audio_to_bytes(self, audio_data: np.ndarray) -> bytes:
+        """
+        Convert a numpy array of audio data to bytes.
+
+        :param audio_data: Numpy array containing audio data.
+        :return: Bytes representation of the audio data.
+        """
         num_channels = 1
         sampwidth = 2
 
@@ -35,20 +42,34 @@
             raise AudioConversionError(f"Error converting numpy array to audio bytes: {e}")
         return buffer.getvalue()
 
-    def process_audio_chunk(self, audio, audio_buffer, transcript):
-        """Process streamed audio data to accumulate and transcribe with overlapping segments."""
+    def process_audio_chunk(
+        self, audio: Tuple[int, np.ndarray], audio_buffer: np.ndarray, transcript: Dict
+    ) -> Tuple[Dict, np.ndarray, str]:
+        """
+        Process streamed audio data to accumulate and transcribe with overlapping segments.
+
+        :param audio: Tuple containing the sample rate and audio data as numpy array.
+        :param audio_buffer: Current audio buffer as numpy array.
+        :param transcript: Current transcript dictionary.
+        :return: Updated transcript, updated audio buffer, and transcript text.
+        """
         audio_buffer = np.concatenate((audio_buffer, audio[1]))
 
         if len(audio_buffer) >= self.SAMPLE_RATE * self.CHUNK_LENGTH or len(audio_buffer) % (self.SAMPLE_RATE // 2) != 0:
             audio_bytes = self.numpy_audio_to_bytes(audio_buffer[: self.SAMPLE_RATE * self.CHUNK_LENGTH])
             audio_buffer = audio_buffer[self.SAMPLE_RATE * self.STEP_LENGTH :]
-
             new_transcript = self.speech_to_text_stream(audio_bytes)
             transcript = self.merge_transcript(transcript, new_transcript)
 
         return transcript, audio_buffer, transcript["text"]
 
-    def speech_to_text_stream(self, audio):
+    def speech_to_text_stream(self, audio: bytes) -> List[Dict[str, str]]:
+        """
+        Convert speech to text from a byte stream using streaming.
+
+        :param audio: Bytes representation of audio data.
+        :return: List of dictionaries containing transcribed words and their timestamps.
+        """
         if self.config.stt.type == "HF_API":
             raise APIError("STT Error: Streaming not supported for this STT type")
         try:
@@ -57,18 +78,24 @@
             transcription = client.audio.transcriptions.create(
                 model=self.config.stt.name, file=data, response_format="verbose_json", timestamp_granularities=["word"]
             )
-        except APIError as e:
+        except APIError:
             raise
         except Exception as e:
             raise APIError(f"STT Error: Unexpected error: {e}")
         return transcription.words
 
-    def merge_transcript(self, transcript, new_transcript):
+    def merge_transcript(self, transcript: Dict, new_transcript: List[Dict[str, str]]) -> Dict:
+        """
+        Merge new transcript data with the existing transcript.
+
+        :param transcript: Existing transcript dictionary.
+        :param new_transcript: New transcript data to merge.
+        :return: Updated transcript dictionary.
+        """
         cut_off = transcript["last_cutoff"]
         transcript["last_cutoff"] = self.MAX_RELIABILITY_CUTOFF - self.STEP_LENGTH
 
         transcript["words"] = transcript["words"][: len(transcript["words"]) - transcript["not_confirmed"]]
-
         transcript["not_confirmed"] = 0
         first_word = True
 
@@ -85,40 +112,55 @@
                 transcript["last_cutoff"] = max(1.0, word_dict["end"] - self.STEP_LENGTH)
 
         transcript["text"] = " ".join(transcript["words"])
-
         return transcript
 
-    def speech_to_text_full(self, audio):
-        audio = self.numpy_audio_to_bytes(audio[1])
+    def speech_to_text_full(self, audio: Tuple[int, np.ndarray]) -> str:
+        """
+        Convert speech to text from a full audio segment.
+
+        :param audio: Tuple containing the sample rate and audio data as numpy array.
+        :return: Transcribed text.
+        """
+        audio_bytes = self.numpy_audio_to_bytes(audio[1])
         try:
             if self.config.stt.type == "OPENAI_API":
-                data = ("temp.wav", audio, "audio/wav")
+                data = ("temp.wav", audio_bytes, "audio/wav")
                 client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
                 transcription = client.audio.transcriptions.create(model=self.config.stt.name, file=data, response_format="text")
             elif self.config.stt.type == "HF_API":
                 headers = {"Authorization": "Bearer " + self.config.stt.key}
-                response = requests.post(self.config.stt.url, headers=headers, data=audio)
+                response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
                 if response.status_code != 200:
                     error_details = response.json().get("error", "No error message provided")
                     raise APIError("STT Error: HF API error", status_code=response.status_code, details=error_details)
                 transcription = response.json().get("text", None)
                 if transcription is None:
                     raise APIError("STT Error: No transcription returned by HF API")
-        except APIError as e:
+        except APIError:
             raise
         except Exception as e:
             raise APIError(f"STT Error: Unexpected error: {e}")
 
         return transcription
 
-    def test_stt(self):
+    def test_stt(self) -> bool:
+        """
+        Test if the STT service is working correctly.
+
+        :return: True if the STT service is working, False otherwise.
+        """
         try:
             self.speech_to_text_full((48000, np.zeros(10000)))
             return True
         except:
             return False
 
-    def test_streaming(self):
+    def test_streaming(self) -> bool:
+        """
+        Test if the STT streaming service is working correctly.
+
+        :return: True if the STT streaming service is working, False otherwise.
+        """
         try:
             self.speech_to_text_stream(self.numpy_audio_to_bytes(np.zeros(10000)))
             return True
@@ -127,14 +169,30 @@ class STTManager:
 
 
 class TTSManager:
-    def test_tts(self):
+    def __init__(self, config):
+        self.config = config
+        self.status = self.test_tts()
+        self.streaming = self.test_tts_stream() if self.status else False
+        self.read_last_message = self.rlm_stream if self.streaming else self.rlm
+
+    def test_tts(self) -> bool:
+        """
+        Test if the TTS service is working correctly.
+
+        :return: True if the TTS service is working, False otherwise.
+        """
         try:
             self.read_text("Handshake")
             return True
         except:
             return False
 
-    def test_tts_stream(self):
+    def test_tts_stream(self) -> bool:
+        """
+        Test if the TTS streaming service is working correctly.
+
+        :return: True if the TTS streaming service is working, False otherwise.
+        """
         try:
             for _ in self.read_text_stream("Handshake"):
                 pass
@@ -142,19 +200,13 @@ class TTSManager:
         except:
             return False
 
-    def __init__(self, config):
-        self.config = config
-        self.status = self.test_tts()
-        if self.status:
-            self.streaming = self.test_tts_stream()
-        else:
-            self.streaming = False
-        if self.streaming:
-            self.read_last_message = self.rlm_stream
-        else:
-            self.read_last_message = self.rlm
-
-    def read_text(self, text):
+    def read_text(self, text: str) -> bytes:
+        """
+        Convert text to speech and return the audio bytes.
+
+        :param text: Text to convert to speech.
+        :return: Bytes representation of the audio.
+        """
         headers = {"Authorization": "Bearer " + self.config.tts.key}
         try:
             if self.config.tts.type == "OPENAI_API":
@@ -165,15 +217,21 @@ class TTSManager:
             if response.status_code != 200:
                 error_details = response.json().get("error", "No error message provided")
                 raise APIError(f"TTS Error: {self.config.tts.type} error", status_code=response.status_code, details=error_details)
-        except APIError as e:
+        except APIError:
            raise
         except Exception as e:
             raise APIError(f"TTS Error: Unexpected error: {e}")
 
         return response.content
 
-    def read_text_stream(self, text):
-        if self.config.tts.type not in ["OPENAI_API"]:
+    def read_text_stream(self, text: str) -> Generator[bytes, None, None]:
+        """
+        Convert text to speech using streaming and return the audio bytes.
+
+        :param text: Text to convert to speech.
+        :return: Generator yielding chunks of audio bytes.
+        """
+        if self.config.tts.type != "OPENAI_API":
             raise APIError("TTS Error: Streaming not supported for this TTS type")
         headers = {"Authorization": "Bearer " + self.config.tts.key}
         data = {"model": self.config.tts.name, "input": text, "voice": "alloy", "response_format": "opus"}
@@ -187,15 +245,27 @@ class TTSManager:
             yield from response.iter_content(chunk_size=1024)
         except StopIteration:
             pass
-        except APIError as e:
+        except APIError:
             raise
         except Exception as e:
             raise APIError(f"TTS Error: Unexpected error: {e}")
 
-    def rlm(self, chat_history):
+    def rlm(self, chat_history: List[List[Optional[str]]]) -> bytes:
+        """
+        Read the last message in the chat history and convert it to speech.
+
+        :param chat_history: List of chat messages.
+        :return: Bytes representation of the audio.
+        """
         if len(chat_history) > 0 and chat_history[-1][1]:
             return self.read_text(chat_history[-1][1])
 
-    def rlm_stream(self, chat_history):
+    def rlm_stream(self, chat_history: List[List[Optional[str]]]) -> Generator[bytes, None, None]:
+        """
+        Read the last message in the chat history and convert it to speech using streaming.
+
+        :param chat_history: List of chat messages.
+        :return: Generator yielding chunks of audio bytes.
+        """
         if len(chat_history) > 0 and chat_history[-1][1]:
            yield from self.read_text_stream(chat_history[-1][1])
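For orientation, a minimal usage sketch of the refactored managers (illustrative only, not part of this commit; it assumes a Config with working stt/tts credentials as referenced above):

from utils.config import Config
from api.audio import STTManager, TTSManager

config = Config()
stt = STTManager(config)  # __init__ probes the backend via test_stt() / test_streaming()
tts = TTSManager(config)  # read_last_message is bound to rlm_stream or rlm depending on streaming support

chat_history = [[None, "Hello, candidate!"]]
if tts.streaming:
    for chunk in tts.read_last_message(chat_history):
        pass  # forward opus chunks to the client as they arrive
else:
    audio_bytes = tts.read_last_message(chat_history)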
api/llm.py CHANGED
@@ -1,40 +1,38 @@
 import os
-
 from openai import OpenAI
-
 from utils.errors import APIError
+from typing import List, Dict, Generator, Optional, Tuple
 
 
 class PromptManager:
-    def __init__(self, prompts):
+    def __init__(self, prompts: Dict[str, str]):
         self.prompts = prompts
         self.limit = os.getenv("DEMO_WORD_LIMIT")
 
-    def add_limit(self, prompt):
+    def add_limit(self, prompt: str) -> str:
         if self.limit:
             prompt += f" Keep your responses very short and simple, no more than {self.limit} words."
         return prompt
 
-    def get_system_prompt(self, key):
+    def get_system_prompt(self, key: str) -> str:
         prompt = self.prompts[key]
         return self.add_limit(prompt)
 
-    def get_problem_requirements_prompt(self, type, difficulty=None, topic=None, requirements=None):
-        prompt = f"Create a {type} problem. Difficulty: {difficulty}. Topic: {topic} " f"Additional requirements: {requirements}. "
+    def get_problem_requirements_prompt(
+        self, type: str, difficulty: Optional[str] = None, topic: Optional[str] = None, requirements: Optional[str] = None
+    ) -> str:
+        prompt = f"Create a {type} problem. Difficulty: {difficulty}. Topic: {topic} Additional requirements: {requirements}. "
         return self.add_limit(prompt)
 
 
 class LLMManager:
-    def __init__(self, config, prompts):
+    def __init__(self, config, prompts: Dict[str, str]):
         self.config = config
         self.client = OpenAI(base_url=config.llm.url, api_key=config.llm.key)
         self.prompt_manager = PromptManager(prompts)
 
         self.status = self.test_llm()
-        if self.status:
-            self.streaming = self.test_llm_stream()
-        else:
-            self.streaming = False
+        self.streaming = self.test_llm_stream() if self.status else False
 
         if self.streaming:
             self.end_interview = self.end_interview_stream
@@ -45,19 +43,7 @@ class LLMManager:
             self.get_problem = self.get_problem_full
             self.send_request = self.send_request_full
 
-    def text_processor(self):
-        def ans_full(response):
-            return response
-
-        def ans_stream(response):
-            yield from response
-
-        if self.streaming:
-            return ans_full
-        else:
-            return ans_stream
-
-    def get_text(self, messages):
+    def get_text(self, messages: List[Dict[str, str]]) -> str:
         try:
             response = self.client.chat.completions.create(model=self.config.llm.name, messages=messages, temperature=1, max_tokens=2000)
             if not response.choices:
@@ -66,14 +52,10 @@ class LLMManager:
         except Exception as e:
             raise APIError(f"LLM Get Text Error: Unexpected error: {e}")
 
-    def get_text_stream(self, messages):
+    def get_text_stream(self, messages: List[Dict[str, str]]) -> Generator[str, None, None]:
         try:
             response = self.client.chat.completions.create(
-                model=self.config.llm.name,
-                messages=messages,
-                temperature=1,
-                stream=True,
-                max_tokens=2000,
+                model=self.config.llm.name, messages=messages, temperature=1, stream=True, max_tokens=2000
             )
         except Exception as e:
             raise APIError(f"LLM End Interview Error: Unexpected error: {e}")
@@ -83,110 +65,107 @@ class LLMManager:
             text += chunk.choices[0].delta.content
             yield text
 
-    test_messages = [
-        {"role": "system", "content": "You just help me test the connection."},
-        {"role": "user", "content": "Hi!"},
-        {"role": "user", "content": "Ping!"},
-    ]
-
-    def test_llm(self):
+    def test_llm(self) -> bool:
         try:
-            self.get_text(self.test_messages)
+            self.get_text(
+                [
+                    {"role": "system", "content": "You just help me test the connection."},
+                    {"role": "user", "content": "Hi!"},
+                    {"role": "user", "content": "Ping!"},
+                ]
+            )
             return True
         except:
             return False
 
-    def test_llm_stream(self):
+    def test_llm_stream(self) -> bool:
         try:
-            for _ in self.get_text_stream(self.test_messages):
+            for _ in self.get_text_stream(
+                [
+                    {"role": "system", "content": "You just help me test the connection."},
+                    {"role": "user", "content": "Hi!"},
+                    {"role": "user", "content": "Ping!"},
+                ]
+            ):
                 pass
             return True
         except:
             return False
 
-    def init_bot(self, problem, interview_type="coding"):
+    def init_bot(self, problem: str, interview_type: str = "coding") -> List[Dict[str, str]]:
         system_prompt = self.prompt_manager.get_system_prompt(f"{interview_type}_interviewer_prompt")
-
-        return [
-            {"role": "system", "content": system_prompt + f"\nThe candidate is solving the following problem:\n {problem}"},
-        ]
-
-    def get_problem_prepare_messages(self, requirements, difficulty, topic, interview_type):
+        return [{"role": "system", "content": f"{system_prompt}\nThe candidate is solving the following problem:\n {problem}"}]
 
+    def get_problem_prepare_messages(self, requirements: str, difficulty: str, topic: str, interview_type: str) -> List[Dict[str, str]]:
         system_prompt = self.prompt_manager.get_system_prompt(f"{interview_type}_problem_generation_prompt")
         full_prompt = self.prompt_manager.get_problem_requirements_prompt(interview_type, difficulty, topic, requirements)
-
-        messages = [
+        return [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": full_prompt},
         ]
 
-        return messages
-
-    def get_problem_full(self, requirements, difficulty, topic, interview_type="coding"):
+    def get_problem_full(self, requirements: str, difficulty: str, topic: str, interview_type: str = "coding") -> str:
         messages = self.get_problem_prepare_messages(requirements, difficulty, topic, interview_type)
         return self.get_text(messages)
 
-    def get_problem_stream(self, requirements, difficulty, topic, interview_type="coding"):
+    def get_problem_stream(
+        self, requirements: str, difficulty: str, topic: str, interview_type: str = "coding"
+    ) -> Generator[str, None, None]:
         messages = self.get_problem_prepare_messages(requirements, difficulty, topic, interview_type)
         yield from self.get_text_stream(messages)
 
-    def update_chat_history(self, code, previous_code, chat_history, chat_display):
+    def update_chat_history(
+        self, code: str, previous_code: str, chat_history: List[Dict[str, str]], chat_display: List[List[Optional[str]]]
+    ) -> List[Dict[str, str]]:
         message = chat_display[-1][0]
-
         if code != previous_code:
-            message += "\nMY NOTES AND CODE:\n"
-            message += code
-
+            message += "\nMY NOTES AND CODE:\n" + code
         chat_history.append({"role": "user", "content": message})
-
         return chat_history
 
-    def send_request_full(self, code, previous_code, chat_history, chat_display):
+    def send_request_full(
+        self, code: str, previous_code: str, chat_history: List[Dict[str, str]], chat_display: List[List[Optional[str]]]
+    ) -> Tuple[List[Dict[str, str]], List[List[Optional[str]]], str]:
         chat_history = self.update_chat_history(code, previous_code, chat_history, chat_display)
-
         reply = self.get_text(chat_history)
         chat_display.append([None, reply.split("#NOTES#")[0].strip()])
         chat_history.append({"role": "assistant", "content": reply})
-
         return chat_history, chat_display, code
 
-    def send_request_stream(self, code, previous_code, chat_history, chat_display):
+    def send_request_stream(
+        self, code: str, previous_code: str, chat_history: List[Dict[str, str]], chat_display: List[List[Optional[str]]]
+    ) -> Generator[Tuple[List[Dict[str, str]], List[List[Optional[str]]], str], None, None]:
         chat_history = self.update_chat_history(code, previous_code, chat_history, chat_display)
-
         chat_display.append([None, ""])
         chat_history.append({"role": "assistant", "content": ""})
-
         reply = self.get_text_stream(chat_history)
         for message in reply:
             chat_display[-1][1] = message.split("#NOTES#")[0].strip()
             chat_history[-1]["content"] = message
-
             yield chat_history, chat_display, code
 
-    def end_interview_prepare_messages(self, problem_description, chat_history, interview_type):
+    def end_interview_prepare_messages(
+        self, problem_description: str, chat_history: List[Dict[str, str]], interview_type: str
+    ) -> List[Dict[str, str]]:
         transcript = [f"{message['role'].capitalize()}: {message['content']}" for message in chat_history[1:]]
-
         system_prompt = self.prompt_manager.get_system_prompt(f"{interview_type}_grading_feedback_prompt")
-
-        messages = [
+        return [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": f"The original problem to solve: {problem_description}"},
             {"role": "user", "content": "\n\n".join(transcript)},
             {"role": "user", "content": "Grade the interview based on the transcript provided and give feedback."},
         ]
 
-        return messages
-
-    def end_interview_full(self, problem_description, chat_history, interview_type="coding"):
+    def end_interview_full(self, problem_description: str, chat_history: List[Dict[str, str]], interview_type: str = "coding") -> str:
         if len(chat_history) <= 2:
             return "No interview history available"
-        else:
-            messages = self.end_interview_prepare_messages(problem_description, chat_history, interview_type)
-            return self.get_text(messages)
+        messages = self.end_interview_prepare_messages(problem_description, chat_history, interview_type)
+        return self.get_text(messages)
 
-    def end_interview_stream(self, problem_description, chat_history, interview_type="coding"):
+    def end_interview_stream(
+        self, problem_description: str, chat_history: List[Dict[str, str]], interview_type: str = "coding"
+    ) -> Generator[str, None, None]:
         if len(chat_history) <= 2:
             yield "No interview history available"
-        else:
-            messages = self.end_interview_prepare_messages(problem_description, chat_history, interview_type)
-            yield from self.get_text_stream(messages)
+        messages = self.end_interview_prepare_messages(problem_description, chat_history, interview_type)
+        yield from self.get_text_stream(messages)
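A short usage sketch of the refactored LLMManager dispatch (illustrative only, not part of this commit; assumes valid LLM credentials in Config):

from utils.config import Config
from resources.prompts import prompts
from api.llm import LLMManager

llm = LLMManager(Config(), prompts)
# get_problem / send_request / end_interview point at the *_stream or *_full variants chosen in __init__
problem = llm.get_problem("", "easy", "arrays", interview_type="coding")
if llm.streaming:
    *_, problem = problem  # the stream yields progressively longer text; keep the final version
messages_interviewer = llm.init_bot(problem, interview_type="coding")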
app.py CHANGED
@@ -1,5 +1,4 @@
 import os
-
 import gradio as gr
 
 from api.audio import STTManager, TTSManager
@@ -10,33 +9,44 @@ from ui.coding import get_problem_solving_ui
 from ui.instructions import get_instructions_ui
 from utils.params import default_audio_params
 
-config = Config()
-llm = LLMManager(config, prompts)
-tts = TTSManager(config)
-stt = STTManager(config)
-
-default_audio_params["streaming"] = stt.streaming
-
-if os.getenv("SILENT", False):
-    tts.read_last_message = lambda x: None
-
-# Interface
-
-with gr.Blocks(title="AI Interviewer") as demo:
-    audio_output = gr.Audio(label="Play audio", autoplay=True, visible=os.environ.get("DEBUG", False), streaming=tts.streaming)
-    tabs = [
-        get_instructions_ui(llm, tts, stt, default_audio_params),
-        get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output, name="Coding", interview_type="coding"),
-        get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output, name="ML Design (Beta)", interview_type="ml_design"),
-        get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output, name="ML Theory (Beta)", interview_type="ml_theory"),
-        get_problem_solving_ui(
-            llm, tts, stt, default_audio_params, audio_output, name="System Design (Beta)", interview_type="system_design"
-        ),
-        get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output, name="Math (Beta)", interview_type="math"),
-        get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output, name="SQL (Beta)", interview_type="sql"),
-    ]
-
-    for tab in tabs:
-        tab.render()
 
-demo.launch(show_api=False)
+
+def initialize_services():
+    """Initialize configuration, LLM, TTS, and STT services."""
+    config = Config()
+    llm = LLMManager(config, prompts)
+    tts = TTSManager(config)
+    stt = STTManager(config)
+    default_audio_params["streaming"] = stt.streaming
+    if os.getenv("SILENT", False):
+        tts.read_last_message = lambda x: None
+    return config, llm, tts, stt
+
+
+def create_interface(llm, tts, stt, audio_params):
+    """Create and configure the Gradio interface."""
+    with gr.Blocks(title="AI Interviewer") as demo:
+        audio_output = gr.Audio(label="Play audio", autoplay=True, visible=os.environ.get("DEBUG", False), streaming=tts.streaming)
+        tabs = [
+            get_instructions_ui(llm, tts, stt, audio_params),
+            get_problem_solving_ui(llm, tts, stt, audio_params, audio_output, name="Coding", interview_type="coding"),
+            get_problem_solving_ui(llm, tts, stt, audio_params, audio_output, name="ML Design (Beta)", interview_type="ml_design"),
+            get_problem_solving_ui(llm, tts, stt, audio_params, audio_output, name="ML Theory (Beta)", interview_type="ml_theory"),
+            get_problem_solving_ui(llm, tts, stt, audio_params, audio_output, name="System Design (Beta)", interview_type="system_design"),
+            get_problem_solving_ui(llm, tts, stt, audio_params, audio_output, name="Math (Beta)", interview_type="math"),
+            get_problem_solving_ui(llm, tts, stt, audio_params, audio_output, name="SQL (Beta)", interview_type="sql"),
+        ]
+
+        for tab in tabs:
+            tab.render()
+    return demo
+
+
+def main():
+    """Main function to initialize services and launch the Gradio interface."""
+    config, llm, tts, stt = initialize_services()
+    demo = create_interface(llm, tts, stt, default_audio_params)
+    demo.launch(show_api=False)
+
+
+if __name__ == "__main__":
+    main()
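With the new __main__ guard, importing app no longer launches the demo as a side effect; a hypothetical reuse sketch (not part of this commit):

import app  # nothing runs at import time anymore

config, llm, tts, stt = app.initialize_services()
demo = app.create_interface(llm, tts, stt, app.default_audio_params)
demo.launch(show_api=False)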
tests/analysis.py CHANGED
@@ -15,6 +15,7 @@ from IPython.display import Markdown, display
 from openai import OpenAI
 from tests.testing_prompts import feedback_analyzer
 from resources.prompts import prompts, base_prompts
+from typing import List, Dict, Any, Tuple, Optional
 
 criteria_list = {
     "problem_statement",
@@ -65,7 +66,15 @@
 }
 
 
-def grade_attempt(file_path, grader_model, attempt_index):
+def grade_attempt(file_path: str, grader_model: str, attempt_index: int) -> Optional[Dict[str, Any]]:
+    """
+    Grade an interview attempt using the specified grader model.
+
+    :param file_path: Path to the JSON file containing interview data.
+    :param grader_model: Grader model to use for grading.
+    :param attempt_index: Index of the grading attempt.
+    :return: Feedback dictionary or None if grading fails.
+    """
     for retry in range(3):  # Retry mechanism
         try:
             feedback = grade(file_path, grader_model, str(attempt_index))
@@ -76,26 +85,36 @@ def grade_attempt(file_path, grader_model, attempt_index):
     return None
 
 
-def complete_and_grade(interview_params, exp_name, grader_models, candidate_model):
-    interview_type, attempt_num, llm_config = interview_params
-
+def complete_and_grade(
+    interview_params: Tuple[str, int, Any], exp_name: str, grader_models: List[str], candidate_model: str
+) -> List[Dict[str, Any]]:
+    """
+    Complete an interview and grade it using specified grader models.
+
+    :param interview_params: Tuple containing interview type, attempt number, and LLM config.
+    :param exp_name: Experiment name.
+    :param grader_models: List of grader models.
+    :param candidate_model: Candidate model name.
+    :return: List of feedback dictionaries.
+    """
+    interview_type, attempt_num, llm_config = interview_params
     feedback_list = []
-    attempt_successful = False
-    for attempt in range(3):  # Retry up to 3 times
+
+    # Attempt interview completion with retries
+    for attempt in range(3):
         try:
             file_path, _ = complete_interview(interview_type, exp_name, llm_config, model=candidate_model, pause=attempt * 5)
             print(
                 f"Attempt {attempt_num + 1}, retry {attempt + 1} interview simulation of {interview_type} by {llm_config.name} completed successfully"
             )
-            attempt_successful = True
             break
         except Exception as e:
             print(f"Retry {attempt + 1} for attempt {attempt_num + 1} of {interview_type} by {llm_config.name} failed with error: {e}")
-
-    if not attempt_successful:
+    else:
         print(f"All retries failed for attempt {attempt_num + 1} of {interview_type} by {llm_config.name}")
         return feedback_list
 
+    # Grade the interview
     try:
         for i, grader_model in enumerate(grader_models):
             feedback = grade_attempt(file_path, grader_model, i)
@@ -103,19 +122,36 @@ def complete_and_grade(interview_params, exp_name, grader_models, candidate_model):
                 feedback_list.append(feedback)
                 print(f"Attempt {attempt_num + 1} of {interview_type} by {llm_config.name} graded by {grader_model} successfully")
                 print(f"Overall score: {feedback['overall_score']}")
-
     except Exception as e:
         print(f"Grading for attempt {attempt_num + 1} of {interview_type} by {llm_config.name} failed with error: {e}")
 
-    if len(feedback_list) == 0:
+    if not feedback_list:
         print(f"Attempt {attempt_num + 1} of {interview_type} by {llm_config.name} returned an empty list")
 
     return feedback_list
 
 
 def run_evaluation(
-    exp_name, num_attempts=5, interview_types=None, grader_models=None, llm_configs=None, candidate_model="gpt-3.5-turbo", num_workers=3
-):
+    exp_name: str,
+    num_attempts: int = 5,
+    interview_types: Optional[List[str]] = None,
+    grader_models: Optional[List[str]] = None,
+    llm_configs: Optional[List[Any]] = None,
+    candidate_model: str = "gpt-3.5-turbo",
+    num_workers: int = 3,
+) -> str:
+    """
+    Run the evaluation by completing and grading interviews.
+
+    :param exp_name: Experiment name.
+    :param num_attempts: Number of attempts per interview type.
+    :param interview_types: List of interview types.
+    :param grader_models: List of grader models.
+    :param llm_configs: List of LLM configurations.
+    :param candidate_model: Candidate model name.
+    :param num_workers: Number of workers for concurrent execution.
+    :return: Experiment name.
+    """
     if interview_types is None:
         interview_types = ["ml_design", "math", "ml_theory", "system_design", "sql", "coding"]
     if grader_models is None:
@@ -143,12 +179,25 @@ def run_evaluation(
     return exp_name
 
 
-def highlight_color(val):
-    color = "red" if val < 0.7 else "orange" if val < 0.9 else "lightgreen" if val < 0.95 else "green"
+def highlight_color(val: float) -> str:
+    """
+    Highlight the cell color based on the value.
+
+    :param val: The value to determine the color.
+    :return: The color style string.
+    """
+    color_map = {val < 0.7: "red", 0.7 <= val < 0.9: "orange", 0.9 <= val < 0.95: "lightgreen", val >= 0.95: "green"}
+    color = next(color for condition, color in color_map.items() if condition)
     return f"color: {color}"
 
 
-def generate_and_display_tables(df):
+def generate_and_display_tables(df: pd.DataFrame) -> Dict[str, Any]:
+    """
+    Generate and display various tables for analysis.
+
+    :param df: DataFrame containing the data.
+    :return: Dictionary of styled tables.
+    """
     # Grouping by prefix
     prefixes = ["problem", "interviewer", "feedback"]
     prefix_columns = [col for col in df.columns if any(col.startswith(prefix) for prefix in prefixes)]
@@ -235,15 +284,19 @@ def generate_and_display_tables(df):
     return tables_dict
 
 
-def filter_df(df, prefixes=["problem", "interviewer", "feedback"]):
-    # Identify all columns starting with any of the prefixes
+def filter_df(df: pd.DataFrame, prefixes: List[str] = ["problem", "interviewer", "feedback"]) -> pd.DataFrame:
+    """
+    Filter the DataFrame to keep only rows with valid values in specified columns.
+
+    :param df: DataFrame to filter.
+    :param prefixes: List of prefixes to identify columns to check.
+    :return: Filtered DataFrame.
+    """
     columns_to_check = [col for col in df.columns if any(col.startswith(prefix) for prefix in prefixes)]
 
-    # Function to check if a value is a boolean, None, or string representations of boolean types
     def is_valid_value(val):
         return isinstance(val, bool) or val is None or val is np.nan or val in {"True", "False", "None", "NaN"}
 
-    # Function to convert string representations to actual booleans
     def to_bool(val):
         if val == "True":
             return True
@@ -253,25 +306,17 @@ def filter_df(df, prefixes=["problem", "interviewer", "feedback"]):
             return None
         return val
 
-    # Check if all values in the specified columns are valid
     def all_values_valid(row):
         return all(is_valid_value(row[col]) for col in columns_to_check)
 
-    # Apply filtering to keep only rows with valid values
     valid_df = df[df.apply(all_values_valid, axis=1)].copy()
-
-    # Convert string representations to booleans
     for col in columns_to_check:
         valid_df[col] = valid_df[col].apply(to_bool)
 
-    # Identify removed rows
     removed_rows = df[~df.index.isin(valid_df.index)]
-
-    # Print the number of rows removed
     num_removed = len(removed_rows)
     print(f"Number of rows removed: {num_removed}")
 
-    # Print the value from the "file_name" column for each removed row, or `None` if not present
     if "file_name" in removed_rows.columns:
         for value in removed_rows["file_name"].tolist():
             print(f"Removed row file_name: {value}")
@@ -281,26 +326,30 @@ def filter_df(df, prefixes=["problem", "interviewer", "feedback"]):
     return valid_df
 
 
-def generate_analysis_report(df, folder, focus=None, model="gpt-4o"):
+def generate_analysis_report(df: pd.DataFrame, folder: Optional[str], focus: Optional[str] = None, model: str = "gpt-4o") -> str:
+    """
+    Generate an analysis report based on the feedback data.
+
+    :param df: DataFrame containing the feedback data.
+    :param folder: Folder to save the analysis report.
+    :param focus: Specific focus area for the analysis.
+    :param model: Model used for generating the analysis.
+    :return: Analysis report content.
+    """
     client = OpenAI(base_url="https://api.openai.com/v1")
 
     all_comments = "\n\n".join([f"Interview type: {t}. Feedback: {str(f)}" for t, f in zip(df["type"].values, df["comments"].values)])
 
-    messages = [
-        {"role": "system", "content": feedback_analyzer},
-        {"role": "user", "content": f"Interview feedback: {all_comments}"},
-    ]
+    messages = [{"role": "system", "content": feedback_analyzer}, {"role": "user", "content": f"Interview feedback: {all_comments}"}]
 
     if focus:
         messages.append({"role": "user", "content": f"Focus only on comments about {focus} part of the interview"})
 
     response = client.chat.completions.create(model=model, messages=messages, temperature=1)
-
     comments_analysis = response.choices[0].message.content
     display(Markdown(comments_analysis))
 
-    if folder is not None:
+    if folder:
         with open(os.path.join(folder, "analysis.md"), "w") as f:
             f.write(comments_analysis)
             f.write("\n\n")
@@ -308,16 +357,18 @@ def generate_analysis_report(df, folder, focus=None, model="gpt-4o"):
                 f.write(f"Type: {t}\n")
                 f.write(df[[c for c in df.columns if c != "comments"]][df["type"] == t].T.to_markdown())
                 f.write("\n\n")
-            f.write(f"Type: all\n")
-            f.write("\n\n")
-            f.write("Feedback:\n")
-            f.write(all_comments)
+            f.write(f"Type: all\n\nFeedback:\n{all_comments}")
 
     return comments_analysis
 
 
-def analyze_and_improve_segment(df, segment_to_improve=None):
+def analyze_and_improve_segment(df: pd.DataFrame, segment_to_improve: Optional[str] = None) -> None:
+    """
+    Analyze and improve a specific segment of the interview process.
+
+    :param df: DataFrame containing the data.
+    :param segment_to_improve: Segment to focus on for improvement.
+    """
     sorted_stages = df[["problem", "interviewer", "feedback"]].mean().sort_values()
     if not segment_to_improve:
         segment_to_improve = sorted_stages.index[0]
@@ -326,43 +377,25 @@ def analyze_and_improve_segment(df, segment_to_improve=None):
     print(f"Let's try to improve {segment_to_improve}")
     print(f"Quality threshold {th_score}")
 
-    # Identifying types that need improvement
     type_stage_scores = df.groupby("type")[segment_to_improve].mean()
-    types_to_improve = []
-    for t, s in type_stage_scores.items():
-        if s < th_score:
-            types_to_improve.append(t)
-
+    types_to_improve = [t for t, s in type_stage_scores.items() if s < th_score]
     print(f"We will focus on {types_to_improve}")
 
-    # Filtering DataFrame based on identified types and scoring criteria
     filtered_df = df[df["type"].apply(lambda x: x in types_to_improve)]
     prefix_columns = [col for col in df.columns if col.startswith(segment_to_improve)]
     filtered_df = filtered_df[filtered_df[prefix_columns].mean(axis=1) < th_score]
 
-    # Generating an analysis report
    comments_analysis = generate_analysis_report(filtered_df, None, focus=segment_to_improve, model="gpt-4o")
 
-    # Constructing improvement prompt
-    improvement_prompt = """You want to improve the prompts for LLM interviewer.
-    Below you will see some of the prompts that are used right now.
-    As well as a summary of mistakes that interviewer make.
-    You can add 1-3 lines to each of prompts if needed, but you can't change or remove anything.
-    """
+    improvement_prompt = "You want to improve the prompts for LLM interviewer. Below you will see some of the prompts that are used right now. As well as a summary of mistakes that interviewer make. You can add 1-3 lines to each of prompts if needed, but you can't change or remove anything."
 
-    # Selecting the base prompt for the segment to improve
     base_prompt = base_prompts.get(f"base_{segment_to_improve}", "Base prompt not found for the segment")
 
-    # Constructing the current prompts display
-    current_prompts = "The current prompts are below. \n"
-    current_prompts += "BASE PROMPT (applied to all interview types): \n"
-    current_prompts += base_prompt + "\n"
-
+    current_prompts = f"The current prompts are below. \nBASE PROMPT (applied to all interview types): \n{base_prompt}\n"
     for k, v in prompts.items():
         if segment_to_improve in k:
             current_prompts += f"{k}: {v[len(base_prompt):]} \n\n"
 
-    # Making API call to OpenAI
     client = OpenAI(base_url="https://api.openai.com/v1")
     model = "gpt-4o"
     messages = [
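Note on the refactored retry logic: complete_and_grade now uses Python's for/else instead of an attempt_successful flag; the else branch runs only when the loop finishes without hitting break. A standalone illustration of the pattern (not project code; run_with_retries and task are hypothetical names):

def run_with_retries(task, retries: int = 3):
    result = None
    for attempt in range(retries):
        try:
            result = task()
            break  # success: the else branch below is skipped
        except Exception as e:
            print(f"Retry {attempt + 1} failed: {e}")
    else:
        print("All retries failed")  # reached only when no attempt succeeded
    return result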
tests/candidate.py CHANGED
@@ -3,11 +3,10 @@ import os
 import random
 import string
 import time
-
 from collections import defaultdict
+from typing import Dict, Optional, Tuple
 
 from openai import OpenAI
-
 from api.llm import LLMManager
 from utils.config import Config
 from resources.data import fixed_messages, topic_lists
@@ -15,14 +14,36 @@ from resources.prompts import prompts
 from tests.testing_prompts import candidate_prompt
 
 
-def complete_interview(interview_type, exp_name, llm_config=None, requirements="", difficulty="", topic="", model="gpt-3.5-turbo", pause=0):
+def complete_interview(
+    interview_type: str,
+    exp_name: str,
+    llm_config: Optional[Config] = None,
+    requirements: str = "",
+    difficulty: str = "",
+    topic: str = "",
+    model: str = "gpt-3.5-turbo",
+    pause: int = 0,
+) -> Tuple[str, Dict]:
+    """
+    Complete an interview and record the results.
+
+    :param interview_type: Type of interview to complete.
+    :param exp_name: Experiment name for file saving.
+    :param llm_config: Optional LLM configuration.
+    :param requirements: Additional requirements for the interview.
+    :param difficulty: Difficulty level for the interview.
+    :param topic: Topic for the interview.
+    :param model: Model to use for the candidate.
+    :param pause: Pause duration between requests to prevent rate limits.
+    :return: Tuple containing the file path and interview data.
+    """
     client = OpenAI(base_url="https://api.openai.com/v1")
     config = Config()
     if llm_config:
         config.llm = llm_config
     llm = LLMManager(config, prompts)
     llm_name = config.llm.name
-    print(f"Starting evaluation interviewer LLM: {llm_name}, candidate_LLM: {model} interview_type: {interview_type}")
+    print(f"Starting evaluation interviewer LLM: {llm_name}, candidate LLM: {model}, interview type: {interview_type}")
     # Select a random topic or difficulty if not provided
     topic = topic or random.choice(topic_lists[interview_type])
     difficulty = difficulty or random.choice(["easy", "medium", "hard"])
@@ -46,7 +67,6 @@ def complete_interview(...):
             "average_response_time_seconds": 0,
         },
     )
-
     # Initialize interviewer and candidate messages
     messages_interviewer = llm.init_bot(problem_statement_text, interview_type)
     chat_display = [[None, fixed_messages["start"]]]
@@ -82,7 +102,6 @@ def complete_interview(...):
 
         chat_display.append([candidate_message, None])
 
-        # Check if the interview should finish
         if response_json.get("finished") and not response_json.get("question"):
             break
 
tests/grader.py CHANGED
@@ -1,27 +1,23 @@
 import json
-
+from typing import Dict, Any, List
 from openai import OpenAI
-
 from tests.testing_prompts import grader_prompt
 
 
-def grade(json_file_path, model="gpt-4o", suffix=""):
-    client = OpenAI(base_url="https://api.openai.com/v1")
+def grade(json_file_path: str, model: str = "gpt-4o", suffix: str = "") -> Dict[str, Any]:
+    """
+    Grade the interview data and provide feedback.
+
+    :param json_file_path: Path to the JSON file containing interview data.
+    :param model: Model to use for grading.
+    :param suffix: Suffix to add to the feedback file name.
+    :return: Feedback dictionary.
+    """
+    client = OpenAI(base_url="https://api.openai.com/v1")
     with open(json_file_path) as file:
         interview_data = json.load(file)
 
-    interview_summary_list = []
-    interview_summary_list.append(f"Interview type: {interview_data['inputs']['interview_type']}")
-    interview_summary_list.append(f"Interview difficulty: {interview_data['inputs']['difficulty']}")
-    interview_summary_list.append(f"Interview topic: {interview_data['inputs']['topic']}")
-    if interview_data["inputs"]["requirements"] != "":
-        interview_summary_list.append(f"Interview requirements: {interview_data['inputs']['requirements']}")
-    interview_summary_list.append(f"Problem statement proposed by interviewer: {interview_data['problem_statement']}")
-    interview_summary_list.append(f"\nTranscript of the whole interview below:")
-    interview_summary_list += interview_data["transcript"]
-    interview_summary_list.append(f"\nTHE MAIN PART OF THE INTERVIEW ENDED HERE.")
-    interview_summary_list.append(f"Feedback provided by interviewer: {interview_data['feedback']}")
+    interview_summary_list = generate_interview_summary(interview_data)
 
     messages = [
         {"role": "system", "content": grader_prompt},
@@ -31,25 +27,81 @@ def grade(json_file_path, model="gpt-4o", suffix=""):
     response = client.chat.completions.create(model=model, messages=messages, temperature=0, response_format={"type": "json_object"})
     feedback = json.loads(response.choices[0].message.content)
 
-    feedback["file_name"] = json_file_path
-    feedback["agent_llm"] = interview_data["interviewer_llm"]
-    feedback["candidate_llm"] = interview_data["candidate_llm"]
-    feedback["grader_model"] = model
-    feedback["type"] = interview_data["inputs"]["interview_type"]
-    feedback["difficulty"] = interview_data["inputs"]["difficulty"]
-    feedback["topic"] = interview_data["inputs"]["topic"]
-    feedback["average_response_time_seconds"] = interview_data["average_response_time_seconds"]
-    feedback["number_of_messages"] = len(interview_data["transcript"])
+    populate_feedback_metadata(feedback, json_file_path, interview_data, model)
+    calculate_overall_score(feedback)
+
+    save_feedback(json_file_path, feedback, suffix)
+
+    return feedback
+
+
+def generate_interview_summary(interview_data: Dict[str, Any]) -> List[str]:
+    """
+    Generate a summary of the interview data.
+
+    :param interview_data: Dictionary containing interview data.
+    :return: List of summary strings.
+    """
+    summary = [
+        f"Interview type: {interview_data['inputs']['interview_type']}",
+        f"Interview difficulty: {interview_data['inputs']['difficulty']}",
+        f"Interview topic: {interview_data['inputs']['topic']}",
+    ]
+    if interview_data["inputs"]["requirements"]:
+        summary.append(f"Interview requirements: {interview_data['inputs']['requirements']}")
+    summary.append(f"Problem statement proposed by interviewer: {interview_data['problem_statement']}")
+    summary.append(f"\nTranscript of the whole interview below:")
+    summary += interview_data["transcript"]
+    summary.append(f"\nTHE MAIN PART OF THE INTERVIEW ENDED HERE.")
+    summary.append(f"Feedback provided by interviewer: {interview_data['feedback']}")
+    return summary
+
+
+def populate_feedback_metadata(feedback: Dict[str, Any], json_file_path: str, interview_data: Dict[str, Any], model: str) -> None:
+    """
+    Populate feedback metadata with interview details.
+
+    :param feedback: Feedback dictionary to populate.
+    :param json_file_path: Path to the JSON file containing interview data.
+    :param interview_data: Dictionary containing interview data.
+    :param model: Model used for grading.
+    """
+    feedback.update(
+        {
+            "file_name": json_file_path,
+            "agent_llm": interview_data["interviewer_llm"],
+            "candidate_llm": interview_data["candidate_llm"],
+            "grader_model": model,
+            "type": interview_data["inputs"]["interview_type"],
+            "difficulty": interview_data["inputs"]["difficulty"],
+            "topic": interview_data["inputs"]["topic"],
+            "average_response_time_seconds": interview_data["average_response_time_seconds"],
+            "number_of_messages": len(interview_data["transcript"]),
+        }
+    )
+
 
+def calculate_overall_score(feedback: Dict[str, Any]) -> None:
+    """
+    Calculate the overall score from the feedback.
+
+    :param feedback: Feedback dictionary containing scores.
+    """
     scores = [
-        feedback[x]
-        for x in feedback
-        if (x.startswith("interviewer_") or x.startswith("feedback_") or x.startswith("problem_")) and feedback[x] is not None
+        feedback[key]
+        for key in feedback
+        if (key.startswith("interviewer_") or key.startswith("feedback_") or key.startswith("problem_")) and feedback[key] is not None
     ]
     feedback["overall_score"] = sum(scores) / len(scores)
 
-    # save results to json file in the same folder as the interview data
+
+def save_feedback(json_file_path: str, feedback: Dict[str, Any], suffix: str) -> None:
+    """
+    Save the feedback to a JSON file.
+
+    :param json_file_path: Path to the original JSON file.
+    :param feedback: Feedback dictionary to save.
+    :param suffix: Suffix to add to the feedback file name.
+    """
     with open(json_file_path.replace(".json", f"_feedback_{suffix}.json"), "w") as file:
         json.dump(feedback, file, indent=4)
-
-    return feedback
 
 
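A minimal usage sketch for the refactored grader, assuming an interview record produced by complete_interview already exists on disk; the path and suffix below are placeholders:

from tests.grader import grade

# grade() reads the record, asks the grading model for JSON feedback,
# writes "<record>_feedback_<suffix>.json" next to the input file, and returns the feedback dict.
feedback = grade("records/interview_example.json", model="gpt-4o", suffix="v1")
print(feedback["overall_score"])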
tests/test_e2e.py CHANGED
@@ -2,21 +2,31 @@ from tests.candidate import complete_interview
 from tests.grader import grade
 from concurrent.futures import ThreadPoolExecutor
 
+from typing import List
 
-def complete_and_grade_interview(interview_type):
+
+def complete_and_grade_interview(interview_type: str) -> float:
+    """
+    Complete an interview and return the overall score.
+
+    :param interview_type: Type of the interview.
+    :return: Overall score of the interview.
+    """
     file_path, _ = complete_interview(interview_type, "test", model="gpt-3.5-turbo")
     feedback = grade(file_path, model="gpt-4o")
     assert feedback["overall_score"] > 0.4
     return feedback["overall_score"]
 
 
-def test_complete_interview():
+def test_complete_interview() -> None:
+    """
+    Test the complete interview process for various interview types.
+    """
     interview_types = ["ml_design", "math", "ml_theory", "system_design", "sql", "coding"]
-    scores = []
+    scores: List[float] = []
 
     with ThreadPoolExecutor(max_workers=3) as executor:
         futures = [executor.submit(complete_and_grade_interview, it) for it in interview_types]
-
         for future in futures:
             score = future.result()
             scores.append(score)
tests/test_models.py CHANGED
@@ -13,15 +13,23 @@ def app_config():
     return Config()
 
 
-def test_llm_connection(app_config):
+def test_llm_connection(app_config: Config):
+    """
+    Test the connection and streaming capability of the LLM.
+
+    :param app_config: Configuration object.
+    """
     llm = LLMManager(app_config, {})
-    status = llm.status
-    streaming = llm.streaming
-    assert status, "LLM connection failed - status check failed"
-    assert streaming, "LLM streaming failed - streaming check failed"
+    assert llm.status, "LLM connection failed - status check failed"
+    assert llm.streaming, "LLM streaming failed - streaming check failed"
+
 
 
-def test_stt_connection(app_config):
+def test_stt_connection(app_config: Config):
+    """
+    Test the connection and streaming capability of the STT.
+
+    :param app_config: Configuration object.
+    """
     stt = STTManager(app_config)
     status = stt.status
     streaming = stt.streaming
@@ -29,7 +37,12 @@ def test_stt_connection(app_config):
     assert streaming, "STT streaming failed - streaming check failed"
 
 
-def test_tts_connection(app_config):
+def test_tts_connection(app_config: Config):
+    """
+    Test the connection and streaming capability of the TTS.
+
+    :param app_config: Configuration object.
+    """
     tts = TTSManager(app_config)
     status = tts.status
     streaming = tts.streaming
ui/coding.py CHANGED
@@ -93,6 +93,7 @@ def get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output, na
     with gr.Accordion("Feedback", open=True) as feedback_acc:
         feedback = gr.Markdown(elem_id=f"{interview_type}_feedback")
 
+    # Start button click action chain
     start_btn.click(fn=add_interviewer_message(fixed_messages["start"]), inputs=[chat], outputs=[chat]).success(
         fn=lambda: True, outputs=[started_coding]
     ).success(fn=tts.read_last_message, inputs=[chat], outputs=[audio_output]).success(
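The new comment documents a Gradio event chain; a stripped-down sketch of the same pattern with purely illustrative component names (each .success() step runs only if the previous one finished without an error):

import gradio as gr

with gr.Blocks() as demo:
    started = gr.State(False)
    chat = gr.Markdown()
    start_btn = gr.Button("Start")
    # Mark the session as started, then render the first message only if that succeeded.
    start_btn.click(fn=lambda: True, outputs=[started]).success(fn=lambda: "Interview started", outputs=[chat])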
utils/config.py CHANGED
@@ -1,19 +1,29 @@
-import os
-
 from dotenv import load_dotenv
+import os
+from typing import Optional
 
 
 class ServiceConfig:
-    def __init__(self, url_var, type_var, name_var):
-        self.url = os.getenv(url_var)
-        self.type = os.getenv(type_var)
-        self.name = os.getenv(name_var)
-        self.key = os.getenv(f"{self.type}_KEY")
+    def __init__(self, url_var: str, type_var: str, name_var: str):
+        """
+        Initialize the ServiceConfig with environment variables.
+
+        :param url_var: Environment variable for the service URL.
+        :param type_var: Environment variable for the service type.
+        :param name_var: Environment variable for the service name.
+        """
+        self.url: Optional[str] = os.getenv(url_var)
+        self.type: Optional[str] = os.getenv(type_var)
+        self.name: Optional[str] = os.getenv(name_var)
+        self.key: Optional[str] = os.getenv(f"{self.type}_KEY")
 
 
 class Config:
     def __init__(self):
+        """
+        Load environment variables and initialize service configurations.
+        """
         load_dotenv(override=True)
-        self.llm = ServiceConfig("LLM_URL", "LLM_TYPE", "LLM_NAME")
-        self.stt = ServiceConfig("STT_URL", "STT_TYPE", "STT_NAME")
-        self.tts = ServiceConfig("TTS_URL", "TTS_TYPE", "TTS_NAME")
+        self.llm: ServiceConfig = ServiceConfig("LLM_URL", "LLM_TYPE", "LLM_NAME")
+        self.stt: ServiceConfig = ServiceConfig("STT_URL", "STT_TYPE", "STT_NAME")
+        self.tts: ServiceConfig = ServiceConfig("TTS_URL", "TTS_TYPE", "TTS_NAME")
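For context, a small sketch of how the refactored Config resolves its values; the environment values below are placeholders, and the key variable name is derived from the service type exactly as in the code above (so LLM_TYPE=OPENAI_API would make it read OPENAI_API_KEY):

import os
from utils.config import Config

# Placeholder values for illustration only; in the app these normally come from a .env file.
os.environ["LLM_URL"] = "https://api.openai.com/v1"
os.environ["LLM_TYPE"] = "OPENAI_API"
os.environ["LLM_NAME"] = "gpt-3.5-turbo"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

config = Config()
# config.llm.key is read from f"{LLM_TYPE}_KEY", i.e. OPENAI_API_KEY here (assuming no .env file overrides it).
print(config.llm.name, bool(config.llm.key))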