BucketOfFish commited on
Commit
4f25dda
1 Parent(s): c07c430

Fixed inference script bug and made deterministic

Browse files
Files changed (1) hide show
  1. streaming_inference.py +2 -1
streaming_inference.py CHANGED
@@ -39,11 +39,12 @@ if __name__ == "__main__":
39
  key = key.replace("lm_head.linear.", "lm_head_linear.")
40
  model_state_dict[key] = value
41
  model.load_state_dict(model_state_dict)
 
42
 
43
  thread = Thread(
44
  target=model.generate,
45
  kwargs=dict(
46
- inputs=tokenizer( # returns a torch dictionary
47
  "Here is an essay on sea monkeys: ",
48
  return_tensors="pt",
49
  return_attention_mask=False,
 
39
  key = key.replace("lm_head.linear.", "lm_head_linear.")
40
  model_state_dict[key] = value
41
  model.load_state_dict(model_state_dict)
42
+ model.eval()
43
 
44
  thread = Thread(
45
  target=model.generate,
46
  kwargs=dict(
47
+ tokenizer( # returns a torch dictionary
48
  "Here is an essay on sea monkeys: ",
49
  return_tensors="pt",
50
  return_attention_mask=False,