# Model Card for fcyin/llama2_7B_base_lofit_truthfulqa

This is a Llama-2-7b model fine-tuned on TruthfulQA using Localized Fine-tuning on LLM Representations (LoFiT; https://arxiv.org/abs/2406.01563). The checkpoint modifies the attention outputs of 96 attention heads (10% of all attention heads in the model).
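
For intuition, LoFiT's intervention is lightweight: the base weights stay frozen, and the method learns small offset vectors that are added to the outputs of a selected subset of attention heads. The sketch below illustrates that idea only; the function name, tensor layout, and head indices are assumptions for illustration, not the repo's actual implementation.

```python
import torch

def apply_lofit_offsets(head_outputs, selected_heads, offsets):
    """Add learned LoFiT-style offset vectors to selected attention heads.

    head_outputs:   (batch, seq_len, num_heads, head_dim) per-head attention outputs
    selected_heads: indices of the heads chosen for fine-tuning (hypothetical here)
    offsets:        (len(selected_heads), head_dim) learned bias vectors
    """
    patched = head_outputs.clone()
    for i, h in enumerate(selected_heads):
        patched[:, :, h, :] += offsets[i]  # shift this head's output by its learned vector
    return patched

# Toy example: Llama-2-7b has 32 heads of dimension 128 per layer
head_outputs = torch.randn(1, 8, 32, 128)
selected_heads = [3, 17, 29]                     # illustrative head indices
offsets = torch.zeros(len(selected_heads), 128)  # learned during fine-tuning
patched = apply_lofit_offsets(head_outputs, selected_heads, offsets)
```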

### Model Description

- **License:** mit
- **Finetuned from model:** meta-llama/Llama-2-7b-hf

### Model Sources

- **Repository:** https://github.com/fc2869/lo-fit
- **Paper:** https://arxiv.org/abs/2406.01563

## Uses

To evaluate this checkpoint on TruthfulQA, clone the LoFiT GitHub repo (https://github.com/fc2869/lo-fit) and run the following snippet from the repository root:

```python
import torch
from transformers import AutoTokenizer

# Modified Llama implementation and evaluation helpers from the lo-fit repo
from models.modeling_llama import LlamaForCausalLM
from utils.evaluate import evaluate_tqa
from utils.dataloaders import TQA

checkpoint = 'fcyin/llama2_7B_base_lofit_truthfulqa'
model_name = 'llama2_7B'
device = 'cuda'
cache_dir = './'
applied_module = 'attention'  # apply the LoFiT intervention to attention outputs
torch_dtype = torch.float32

# custom_from_pretrained loads the base Llama-2 weights plus the learned LoFiT parameters
model = LlamaForCausalLM.custom_from_pretrained(
    checkpoint,
    device_map=device,
    cache_dir=cache_dir,
    applied_module=applied_module,
    torch_dtype=torch_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load fold 0 of the TruthfulQA data splits
dataloader = TQA(
    iti_split_dir='./dataset/truthfulqa',
    fold_num=0,
    data_gen_seed=42,
)
dataset = dataloader.load_data()

# Compute multiple-choice (MC) metrics on the test split
evaluate_tqa(
    fname='./',
    eval_dataset=dataset['test'],
    model_name=model_name,
    metrics=['mc'],
    tokenizer=tokenizer,
    model=model,
)
```
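
Once loaded, the checkpoint should behave like a standard causal LM (assuming the repo's LlamaForCausalLM keeps the usual generate() interface), so you can also sanity-check it with ordinary generation; the prompt here is just an example:

```python
prompt = "What happens if you crack your knuckles a lot?"
inputs = tokenizer(prompt, return_tensors='pt').to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```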

## Training Details

Please refer to the [paper](https://arxiv.org/abs/2406.01563) for training details.