doberst commited on
Commit
8b3064c
1 Parent(s): faef3db

Upload generation_test_llmware_script.py

Browse files
Files changed (1) hide show
  1. generation_test_llmware_script.py +64 -0
generation_test_llmware_script.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from llmware.prompts import Prompt
3
+
4
+
5
+ def load_rag_benchmark_tester_ds():
6
+
7
+ # pull 200 question rag benchmark test dataset from LLMWare HuggingFace repo
8
+ from datasets import load_dataset
9
+
10
+ ds_name = "llmware/rag_instruct_benchmark_tester"
11
+
12
+ dataset = load_dataset(ds_name)
13
+
14
+ print("update: loading test dataset - ", dataset)
15
+
16
+ test_set = []
17
+ for i, samples in enumerate(dataset["train"]):
18
+ test_set.append(samples)
19
+
20
+ # to view test set samples
21
+ # print("rag benchmark dataset test samples: ", i, samples)
22
+
23
+ return test_set
24
+
25
+
26
+ def run_test(model_name, prompt_list):
27
+
28
+ print("\nupdate: Starting RAG Benchmark Inference Test")
29
+
30
+ prompter = Prompt().load_model(model_name,from_hf=True)
31
+
32
+ for i, entries in enumerate(prompt_list):
33
+
34
+ prompt = entries["query"]
35
+ context = entries["context"]
36
+
37
+ response = prompter.prompt_main(prompt,context=context,prompt_name="default_with_context", temperature=0.3)
38
+
39
+ fc = prompter.evidence_check_numbers(response)
40
+ sc = prompter.evidence_comparison_stats(response)
41
+ sr = prompter.evidence_check_sources(response)
42
+
43
+ print("\nupdate: model inference output - ", i, response["llm_response"])
44
+ print("update: gold_answer - ", i, entries["answer"])
45
+
46
+ for entries in fc:
47
+ print("update: fact check - ", entries["fact_check"])
48
+
49
+ for entries in sc:
50
+ print("update: comparison stats - ", entries["comparison_stats"])
51
+
52
+ for entries in sr:
53
+ print("update: sources - ", entries["source_review"])
54
+
55
+ return 0
56
+
57
+
58
+ if __name__ == "__main__":
59
+
60
+ core_test_set = load_rag_benchmark_tester_ds()
61
+
62
+ model_name = "llmware/bling-1b-0.1"
63
+
64
+ output = run_test(model_name, core_test_set)