pseudotensor commited on
Commit
8f3dc34
1 Parent(s): 8cb62ff

Update with h2oGPT hash 3e927fb6330dd3d1256b47eb201bd376230dd20a

Browse files
Files changed (2) hide show
  1. generate.py +3 -7
  2. utils.py +0 -50
generate.py CHANGED
@@ -3,8 +3,9 @@ import sys
3
  import os
4
  import traceback
5
  import typing
 
6
 
7
- from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, KThread, wrapped_partial
8
 
9
  SEED = 1236
10
  set_seed(SEED)
@@ -828,15 +829,10 @@ def evaluate(
828
  skip_prompt = False
829
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
830
  gen_kwargs.update(dict(streamer=streamer))
831
- if debug:
832
- KThread.show_threads()
833
  target_func = generate_with_exceptions
834
- if concurrency_count == 1:
835
- # otherwise can't do this
836
- KThread.kill_threads(target_func.__name__, debug=debug)
837
  target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
838
  raise_generate_gpu_exceptions, **gen_kwargs)
839
- thread = KThread(target=target)
840
  thread.start()
841
  outputs = ""
842
  for new_text in streamer:
 
3
  import os
4
  import traceback
5
  import typing
6
+ from threading import Thread
7
 
8
+ from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
9
 
10
  SEED = 1236
11
  set_seed(SEED)
 
829
  skip_prompt = False
830
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
831
  gen_kwargs.update(dict(streamer=streamer))
 
 
832
  target_func = generate_with_exceptions
 
 
 
833
  target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
834
  raise_generate_gpu_exceptions, **gen_kwargs)
835
+ thread = Thread(target=target)
836
  thread.start()
837
  outputs = ""
838
  for new_text in streamer:
utils.py CHANGED
@@ -244,56 +244,6 @@ class NullContext(threading.local):
244
  pass
245
 
246
 
247
- class KThread(threading.Thread):
248
- """Thread with a kill method."""
249
-
250
- def __init__(self, *args, **keywords):
251
- threading.Thread.__init__(self, *args, **keywords)
252
- self.killed = False
253
-
254
- def start(self):
255
- """Start the thread."""
256
- self.__run_backup = self.run
257
- self.run = self.__run # Force the Thread to install our trace.
258
- threading.Thread.start(self)
259
-
260
- def __run(self):
261
- """install trace."""
262
- sys.settrace(self.globaltrace)
263
- self.__run_backup()
264
- self.run = self.__run_backup
265
-
266
- def globaltrace(self, frame, why, arg):
267
- if why == 'call':
268
- return self.localtrace
269
- else:
270
- return None
271
-
272
- def localtrace(self, frame, why, arg):
273
- if self.killed:
274
- if why == 'line':
275
- raise SystemExit()
276
- return self.localtrace
277
-
278
- def kill(self):
279
- self.killed = True
280
-
281
- @staticmethod
282
- def show_threads():
283
- for thread in threading.enumerate():
284
- print(thread.name, flush=True)
285
-
286
- @staticmethod
287
- def kill_threads(name, debug=False):
288
- for thread in threading.enumerate():
289
- if name in thread.name:
290
- if debug:
291
- print("Trying to kill %s %s" % (thread.ident, thread), flush=True)
292
- thread.kill()
293
- if debug:
294
- print(thread, flush=True)
295
-
296
-
297
  def wrapped_partial(func, *args, **kwargs):
298
  """
299
  Give partial properties of normal function, like __name__ attribute etc.
 
244
  pass
245
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def wrapped_partial(func, *args, **kwargs):
248
  """
249
  Give partial properties of normal function, like __name__ attribute etc.