alfraser committed on
Commit
e35ef72
1 Parent(s): f89cac3

Made updates to support automatic reload of the TestGroups after a test run

Browse files
Files changed (2) hide show
  1. src/architectures.py +18 -7
  2. src/testing.py +15 -3
src/architectures.py CHANGED
@@ -17,7 +17,7 @@ from huggingface_hub import Repository
17
  from queue import Queue
18
  from threading import Thread, Timer
19
  from time import time
20
- from typing import List, Optional, Dict
21
  from better_profanity import profanity
22
 
23
  from src.common import config_dir, data_dir, hf_api_token, escape_dollars
@@ -197,16 +197,22 @@ class LogWorker(Thread):
197
  trace_file_name = "trace.json"
198
  trace_file = os.path.join(trace_dir, trace_file_name)
199
  queue = Queue()
200
- commit_time = 10 # Number of seconds after which to commit with no activity
201
- commit_after = 10 # Number of records after which to commit irrespective of time
202
  commit_count = 0 # Current uncommitted records
203
  commit_timer = None # The actual commit timer - we will schedule the commit on this
 
204
 
205
  def run(self):
206
  while True:
207
  arch_name, request, trace, trace_tags, trace_comment = LogWorker.queue.get()
208
  if request is None:
209
- LogWorker.commit_repo()
 
 
 
 
 
210
  else:
211
  if LogWorker.commit_timer is not None and LogWorker.commit_timer.is_alive():
212
  LogWorker.commit_timer.cancel()
@@ -249,12 +255,16 @@ class LogWorker(Thread):
249
 
250
  @classmethod
251
  def commit_repo(cls):
252
- print(f"LogWorker committing {LogWorker.commit_count} open records")
253
- cls.save_repo.push_to_hub()
254
- LogWorker.commit_count = 0
 
255
 
256
  @classmethod
257
  def signal_commit(cls):
 
 
 
258
  print("LogWorker signalling commit based on time elapsed")
259
  cls.queue.put((None, None, None, None, None))
260
 
@@ -271,6 +281,7 @@ if LogWorker.instance is None:
271
  LogWorker.instance = LogWorker()
272
  LogWorker.daemon = True
273
  LogWorker.instance.start()
 
274
 
275
 
276
  class Architecture:
 
17
  from queue import Queue
18
  from threading import Thread, Timer
19
  from time import time
20
+ from typing import List, Optional, Dict, Callable
21
  from better_profanity import profanity
22
 
23
  from src.common import config_dir, data_dir, hf_api_token, escape_dollars
 
197
  trace_file_name = "trace.json"
198
  trace_file = os.path.join(trace_dir, trace_file_name)
199
  queue = Queue()
200
+ commit_time = 5 # Number of seconds after which to commit with no activity
201
+ commit_after = 20 # Number of records after which to commit irrespective of time
202
  commit_count = 0 # Current uncommitted records
203
  commit_timer = None # The actual commit timer - we will schedule the commit on this
204
+ timeout_functions: List[Callable[[], None]] = [] # Callbacks which will be fired on timeout
205
 
206
  def run(self):
207
  while True:
208
  arch_name, request, trace, trace_tags, trace_comment = LogWorker.queue.get()
209
  if request is None:
210
+ for func in LogWorker.timeout_functions:
211
+ print(f"LogWorker commit running {func.__name__}")
212
+ try:
213
+ func()
214
+ except Exception as e:
215
+ print(f"Timeout func {func.__name__} had error {e}")
216
  else:
217
  if LogWorker.commit_timer is not None and LogWorker.commit_timer.is_alive():
218
  LogWorker.commit_timer.cancel()
 
255
 
256
  @classmethod
257
  def commit_repo(cls):
258
+ if cls.commit_count > 0:
259
+ print(f"LogWorker committing {LogWorker.commit_count} open records")
260
+ cls.save_repo.push_to_hub()
261
+ LogWorker.commit_count = 0
262
 
263
  @classmethod
264
  def signal_commit(cls):
265
+ # Signalling this back via the queue and not doing the work here as it would
266
+ # be executed on the Timer thread and may conflict with resources if the main
267
+ # LogWorker starts doing work concurrently.
268
  print("LogWorker signalling commit based on time elapsed")
269
  cls.queue.put((None, None, None, None, None))
270
 
 
281
  LogWorker.instance = LogWorker()
282
  LogWorker.daemon = True
283
  LogWorker.instance.start()
284
+ LogWorker.timeout_functions.append(LogWorker.commit_repo)
285
 
286
 
287
  class Architecture:
src/testing.py CHANGED
@@ -9,11 +9,11 @@ import sys
9
  from huggingface_hub import Repository
10
  from queue import Queue
11
  from random import sample
12
- from threading import Thread, Timer
13
  from typing import Dict, List, Optional, Tuple
14
 
15
- from src.architectures import Architecture, ArchitectureRequest, ArchitectureTrace
16
- from src.common import data_dir, hf_api_token
17
 
18
 
19
  class ArchitectureTestWorker(Thread):
@@ -323,12 +323,24 @@ class TestGroup:
323
  test_groups.append(tg)
324
  return test_groups
325
 
 
 
 
 
 
 
 
 
326
  @classmethod
327
  def load_all(cls, reload: bool = False):
328
  """
329
  Load all the available TestGroups, from both the json file and the DB
330
  into the class variable - for efficiency do not reload unless requested
331
  """
 
 
 
 
332
  if cls.all is None or reload:
333
  working_test_groups = {}
334
 
 
9
  from huggingface_hub import Repository
10
  from queue import Queue
11
  from random import sample
12
+ from threading import Thread
13
  from typing import Dict, List, Optional, Tuple
14
 
15
+ from src.architectures import Architecture, ArchitectureRequest, LogWorker
16
+ from src.common import data_dir
17
 
18
 
19
  class ArchitectureTestWorker(Thread):
 
323
  test_groups.append(tg)
324
  return test_groups
325
 
326
+ @classmethod
327
+ def force_load_all(cls):
328
+ """
329
+ Convenience wrapper to allow a no parameter call to force the reload of the
330
+ TestGroups, for use as the LogWorker callback
331
+ """
332
+ cls.load_all(True)
333
+
334
  @classmethod
335
  def load_all(cls, reload: bool = False):
336
  """
337
  Load all the available TestGroups, from both the json file and the DB
338
  into the class variable - for efficiency do not reload unless requested
339
  """
340
+ if cls.force_load_all not in LogWorker.timeout_functions:
341
+ print("TestGroup adding forced refresh to LogWorker timeout")
342
+ LogWorker.timeout_functions.append(TestGroup.force_load_all)
343
+
344
  if cls.all is None or reload:
345
  working_test_groups = {}
346