johnowhitaker committed
Commit c89c483
1 Parent(s): 2d16aef

Update README.md

Files changed (1)
  1. README.md +71 -0
README.md CHANGED
@@ -18,3 +18,74 @@ The following `bitsandbytes` quantization config was used during training:
 
 
 - PEFT 0.4.0
+
+
+notebook (training and inference): https://colab.research.google.com/drive/1GxbUYZiLidteVX4qu5iSox6oxxEOHk5O?usp=sharing
+
+
+Usage:
+```python
+import requests
+
+# Get a random Wikipedia article summary using the Wikipedia REST API
+def random_extract():
+    URL = "https://en.wikipedia.org/api/rest_v1/page/random/summary"
+    r = requests.get(url=URL)
+    data = r.json()
+    return data['extract']
+
+# Format the extract as a prompt that the model should complete with a question
+def random_prompt():
+    e = random_extract()
+    return f"""### CONTEXT: {e} ### QUESTION:"""
+
+import torch
+from peft import AutoPeftModelForCausalLM
+from transformers import AutoTokenizer
+
+output_dir = "mcqgen_test"
+
+# Load the base LLM (in 4-bit) with the trained adapter, plus the tokenizer
+model = AutoPeftModelForCausalLM.from_pretrained(
+    output_dir,
+    low_cpu_mem_usage=True,
+    torch_dtype=torch.float16,
+    load_in_4bit=True,
+)
+tokenizer = AutoTokenizer.from_pretrained(output_dir)
+
+# Feed in a random context prompt and see what question the model comes up with
+prompt = random_prompt()
+
+input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
+with torch.inference_mode():
+    outputs = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.9)
+
+print(f"Prompt:\n{prompt}\n")
+print(f"Generated MCQ:\n### QUESTION:{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}")
+
+# Parse the generated text into its parts; return None if the format is not as expected
+def process_outputs(outputs):
+    s = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]
+    split = s.split("### ")[1:][:7]
+    if len(split) != 7:
+        return None
+    # Check that each section starts with the expected label
+    expected_starts = ['CONTEXT', 'QUESTION', 'A', 'B', 'C', 'D', 'CORRECT']
+    for i, part in enumerate(split):
+        if not part.startswith(expected_starts[i]):
+            return None
+    return {
+        "context": split[0].replace("CONTEXT: ", ""),
+        "question": split[1].replace("QUESTION: ", ""),
+        "a": split[2].replace("A: ", ""),
+        "b": split[3].replace("B: ", ""),
+        "c": split[4].replace("C: ", ""),
+        "d": split[5].replace("D: ", ""),
+        "correct": split[6].replace("CORRECT: ", "")
+    }
+
+
+process_outputs(outputs)  # A dict with context, question, options and answer, or None if parsing failed
+
+```
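
For reference, here is a minimal sketch of how the pieces above might be combined into a single helper that keeps sampling until a generation parses cleanly. It assumes the `model`, `tokenizer`, `random_prompt` and `process_outputs` defined in the snippet are already in scope; `generate_mcq` and `max_attempts` are illustrative names and not part of the commit.

```python
import torch

# Hypothetical wrapper: sample random-context prompts until process_outputs
# returns a well-formed MCQ dict (all seven "### " sections present).
def generate_mcq(max_attempts=5):
    for _ in range(max_attempts):
        prompt = random_prompt()
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
        with torch.inference_mode():
            outputs = model.generate(
                input_ids=input_ids, max_new_tokens=100,
                do_sample=True, top_p=0.9, temperature=0.9,
            )
        mcq = process_outputs(outputs)
        if mcq is not None:  # parsing succeeded
            return mcq
    return None  # no well-formed MCQ within max_attempts tries

mcq = generate_mcq()
if mcq:
    print(mcq["question"])
    for key in ["a", "b", "c", "d"]:
        print(f"{key.upper()}: {mcq[key]}")
    print("Correct answer:", mcq["correct"])
```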