doberst committed on
Commit
0a98de0
1 Parent(s): cffbcd1

Upload generation_test_hf_script.py

Browse files
Files changed (1) hide show
  1. generation_test_hf_script.py +9 -6
generation_test_hf_script.py CHANGED
@@ -27,15 +27,19 @@ def load_rag_benchmark_tester_ds():
27
 
28
  def run_test(model_name, test_ds):
29
 
30
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",torch_dtype="auto",trust_remote_code=True)
31
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
32
-
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
 
 
 
 
 
 
 
 
35
  for i, entries in enumerate(test_ds):
36
 
37
  # prepare prompt packaging used in fine-tuning process
38
- new_prompt = "<human>: " + entries["context"] + "\n" + entries["query"] + "\n" + "<bot>:" + "\n"
39
 
40
  inputs = tokenizer(new_prompt, return_tensors="pt")
41
  start_of_output = len(inputs.input_ids[0])
@@ -63,7 +67,7 @@ def run_test(model_name, test_ds):
63
  bot = output_only.find("<bot>:")
64
  if bot > -1:
65
  output_only = output_only[bot+len("<bot>:"):]
66
-
67
  # end - post-processing
68
 
69
  print("\n")
@@ -78,7 +82,6 @@ if __name__ == "__main__":
78
  test_ds = load_rag_benchmark_tester_ds()
79
 
80
  model_name = "llmware/dragon-yi-6b-v0"
81
-
82
  output = run_test(model_name,test_ds)
83
 
84
 
 
27
 
28
  def run_test(model_name, test_ds):
29
 
 
 
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+ print("update: model will be loaded on device - ", device)
33
+
34
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
35
+ model.to(device)
36
+
37
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
38
+
39
  for i, entries in enumerate(test_ds):
40
 
41
  # prepare prompt packaging used in fine-tuning process
42
+ new_prompt = "<human>: " + entries["context"] + "\n" + entries["query"] + "\n" + "<bot>:"
43
 
44
  inputs = tokenizer(new_prompt, return_tensors="pt")
45
  start_of_output = len(inputs.input_ids[0])
 
67
  bot = output_only.find("<bot>:")
68
  if bot > -1:
69
  output_only = output_only[bot+len("<bot>:"):]
70
+
71
  # end - post-processing
72
 
73
  print("\n")
 
82
  test_ds = load_rag_benchmark_tester_ds()
83
 
84
  model_name = "llmware/dragon-yi-6b-v0"
 
85
  output = run_test(model_name,test_ds)
86
 
87