winglian committed on
Commit
c4e4f81
1 Parent(s): 0124825

pass a prompt in from stdin for inference

Browse files
Files changed (2) hide show
  1. README.md +5 -0
  2. scripts/finetune.py +18 -5
README.md CHANGED
@@ -495,6 +495,11 @@ Pass the appropriate flag to the train command:
495
  ```bash
496
  --inference --base_model ./completed-model
497
  ```
 
 
 
 
 
498
 
499
  ### Merge LORA to base
500
 
 
495
  ```bash
496
  --inference --base_model ./completed-model
497
  ```
498
+ - Full weights finetune w/ a prompt from a text file:
499
+ ```bash
500
+ cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
501
+ --base_model ./completed-model --inference --prompter=None --load_in_8bit=True
502
+ ```
503
 
504
  ### Merge LORA to base
505
 
scripts/finetune.py CHANGED
@@ -71,7 +71,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
71
  if not (cfg.special_tokens and token in cfg.special_tokens):
72
  tokenizer.add_special_tokens({token: symbol})
73
 
74
- prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
 
 
 
 
75
 
76
  while True:
77
  print("=" * 80)
@@ -79,9 +83,12 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
79
  instruction = get_multi_line_input()
80
  if not instruction:
81
  return
82
- prompt: str = next(
83
- prompter_module().build_prompt(instruction=instruction.strip("\n"))
84
- )
 
 
 
85
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
86
  print("=" * 40)
87
  model.eval()
@@ -242,7 +249,13 @@ def train(
242
 
243
  if "inference" in kwargs:
244
  logging.info("calling do_inference function")
245
- do_inference(cfg, model, tokenizer)
 
 
 
 
 
 
246
  return
247
 
248
  if "shard" in kwargs:
 
71
  if not (cfg.special_tokens and token in cfg.special_tokens):
72
  tokenizer.add_special_tokens({token: symbol})
73
 
74
+ prompter_module = None
75
+ if prompter:
76
+ prompter_module = getattr(
77
+ importlib.import_module("axolotl.prompters"), prompter
78
+ )
79
 
80
  while True:
81
  print("=" * 80)
 
83
  instruction = get_multi_line_input()
84
  if not instruction:
85
  return
86
+ if prompter_module:
87
+ prompt: str = next(
88
+ prompter_module().build_prompt(instruction=instruction.strip("\n"))
89
+ )
90
+ else:
91
+ prompt = instruction.strip()
92
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
93
  print("=" * 40)
94
  model.eval()
 
249
 
250
  if "inference" in kwargs:
251
  logging.info("calling do_inference function")
252
+ inf_kwargs: Dict[str, Any] = {}
253
+ if "prompter" in kwargs:
254
+ if kwargs["prompter"] == "None":
255
+ inf_kwargs["prompter"] = None
256
+ else:
257
+ inf_kwargs["prompter"] = kwargs["prompter"]
258
+ do_inference(cfg, model, tokenizer, **inf_kwargs)
259
  return
260
 
261
  if "shard" in kwargs: