pseudotensor commited on
Commit
7a7ff47
·
1 Parent(s): 8f3dc34

Update with h2oGPT hash 4c76cca4305a462e443618a90e5abf6ee5db6bf8

Browse files
Files changed (2) hide show
  1. generate.py +5 -3
  2. prompter.py +1 -1
generate.py CHANGED
@@ -824,8 +824,8 @@ def evaluate(
824
  decoder = decoder_raw
825
  else:
826
  print("WARNING: Special characters in prompt", flush=True)
 
827
  if stream_output:
828
- #skip_prompt = prompt_type != 'plain'
829
  skip_prompt = False
830
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
831
  gen_kwargs.update(dict(streamer=streamer))
@@ -839,13 +839,15 @@ def evaluate(
839
  outputs += new_text
840
  yield prompter.get_response(outputs, prompt=inputs_decoded,
841
  sanitize_bot_response=sanitize_bot_response)
 
842
  else:
843
  outputs = model.generate(**gen_kwargs)
844
  outputs = [decoder(s) for s in outputs.sequences]
845
  yield prompter.get_response(outputs, prompt=inputs_decoded,
846
  sanitize_bot_response=sanitize_bot_response)
847
- if save_dir and outputs and len(outputs) >= 1:
848
- decoded_output = prompt + outputs[0]
 
849
  save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
850
 
851
 
 
824
  decoder = decoder_raw
825
  else:
826
  print("WARNING: Special characters in prompt", flush=True)
827
+ decoded_output = None
828
  if stream_output:
 
829
  skip_prompt = False
830
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
831
  gen_kwargs.update(dict(streamer=streamer))
 
839
  outputs += new_text
840
  yield prompter.get_response(outputs, prompt=inputs_decoded,
841
  sanitize_bot_response=sanitize_bot_response)
842
+ decoded_output = outputs
843
  else:
844
  outputs = model.generate(**gen_kwargs)
845
  outputs = [decoder(s) for s in outputs.sequences]
846
  yield prompter.get_response(outputs, prompt=inputs_decoded,
847
  sanitize_bot_response=sanitize_bot_response)
848
+ if outputs and len(outputs) >= 1:
849
+ decoded_output = prompt + outputs[0]
850
+ if save_dir and decoded_output:
851
  save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
852
 
853
 
prompter.py CHANGED
@@ -71,7 +71,7 @@ class Prompter(object):
71
  output = output.split(self.pre_response)[1]
72
  allow_terminate = True
73
  else:
74
- print("Failure of parsing: %s" % output, flush=True)
75
  allow_terminate = False
76
  else:
77
  allow_terminate = True
 
71
  output = output.split(self.pre_response)[1]
72
  allow_terminate = True
73
  else:
74
+ print("Failure of parsing or not enough output yet: %s" % output, flush=True)
75
  allow_terminate = False
76
  else:
77
  allow_terminate = True