doberst committed on
Commit 5bdef98
1 Parent(s): 5000a4b

Upload generation_test_hf_script.py

Files changed (1):
  1. generation_test_hf_script.py +7 -4
generation_test_hf_script.py CHANGED
@@ -27,10 +27,14 @@ def load_rag_benchmark_tester_ds():
 
 def run_test(model_name, test_ds):
 
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    print("update: model will be loaded on device - ", device)
+
     model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    model.to(device)
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
     for i, entries in enumerate(test_ds):
 
@@ -63,7 +67,7 @@ def run_test(model_name, test_ds):
         bot = output_only.find("<bot>:")
         if bot > -1:
             output_only = output_only[bot+len("<bot>:"):]
-
+
         # end - post-processing
 
         print("\n")
@@ -78,7 +82,6 @@ if __name__ == "__main__":
     test_ds = load_rag_benchmark_tester_ds()
 
     model_name = "llmware/bling-1.4b-0.1"
-
     output = run_test(model_name,test_ds)
 
 
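
For context, the change above selects the device before loading, moves the model with model.to(device), and reloads the tokenizer afterwards; with the model on a GPU, the tokenized inputs also need to be placed on the same device before generation. A minimal sketch of that pattern follows. The prompt text, the max_new_tokens value, and the variable names are illustrative assumptions rather than values from the script; only the "<bot>:" post-processing and the model name are taken from the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# pick the GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("update: model will be loaded on device - ", device)

model_name = "llmware/bling-1.4b-0.1"
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)                      # the placement this commit adds
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# illustrative prompt; the <human>/<bot> framing mirrors the "<bot>:" marker the script strips later
prompt = "<human>: What is retrieval augmented generation?\n<bot>:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)   # inputs must sit on the model's device

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=100)   # 100 is an assumed cap, not the script's setting

output_only = tokenizer.decode(output_ids[0], skip_special_tokens=True)
bot = output_only.find("<bot>:")
if bot > -1:
    output_only = output_only[bot + len("<bot>:"):]   # same post-processing the script applies
print(output_only)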