Glavin001 committed
Commit fec6bcc
1 Parent(s): 931e606

Add streaming inference & fix stopping at EOS

Files changed (1)
  1. scripts/finetune.py +16 -5
scripts/finetune.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig
+from transformers import GenerationConfig, TextStreamer
 
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.dict import DictDefault
@@ -64,13 +64,21 @@ def get_multi_line_input() -> Optional[str]:
 
 
 def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
-    tokenizer.add_special_tokens({"unk_token": "<unk>"})
-    tokenizer.add_special_tokens({"bos_token": "<s>"})
-    tokenizer.add_special_tokens({"eos_token": "</s>"})
+    default_tokens = {
+        "unk_token": "<unk>",
+        "bos_token": "<s>",
+        "eos_token": "</s>"
+    }
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})
 
     prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
 
     while True:
+        print("=" * 80)
         # support for multiline inputs
         instruction = get_multi_line_input()
         if not instruction:
@@ -79,7 +87,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
             prompter_module().build_prompt(instruction=instruction.strip("\n"))
         )
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
+        print("=" * 40)
         model.eval()
         with torch.no_grad():
             generation_config = GenerationConfig(
@@ -98,10 +106,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
                 output_hidden_states=False,
                 output_scores=False,
             )
+            streamer = TextStreamer(tokenizer)
             generated = model.generate(
                 inputs=batch["input_ids"].to(cfg.device),
                 generation_config=generation_config,
+                streamer=streamer,
             )
+            print("=" * 40)
             print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
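
For context, the streaming half of this commit relies on transformers' TextStreamer, which model.generate accepts via its streamer argument and which prints decoded tokens to stdout as they are produced. A minimal standalone sketch of the same pattern, outside axolotl (the model name, prompt, and generation length here are illustrative placeholders, not taken from this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Placeholder model; any causal LM with a matching tokenizer works the same way.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

batch = tokenizer("The quick brown fox", return_tensors="pt")
# TextStreamer decodes and prints tokens as generate() produces them,
# so output appears incrementally instead of only after generation finishes.
streamer = TextStreamer(tokenizer)

model.eval()
with torch.no_grad():
    model.generate(
        inputs=batch["input_ids"],
        streamer=streamer,
        max_new_tokens=32,
    )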
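On the EOS side, the old code re-registered <unk>, <s>, and </s> unconditionally on every inference run, which can clobber an EOS token already configured for the model and leave the tokenizer out of sync with the EOS id that generate() stops on, so generation never terminates. The new loop only falls back to a default when cfg.special_tokens does not already define that token. A hypothetical sanity check for that invariant (not part of the commit):

# Hypothetical check: the tokenizer's EOS id should match the id
# generate() stops on; a mismatch reproduces the old non-stopping bug.
assert tokenizer.eos_token_id == model.generation_config.eos_token_id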