dar-tau committed
Commit 9a230a0
1 Parent(s): 7673f3b

Update configs.py

Files changed (1):
  1. configs.py +25 -17
configs.py CHANGED
@@ -1,5 +1,8 @@
 import os
 
+llama_layers_format = 'model.layers.{k}'
+gpt_layers_format = 'transformer.h.{k}'
+
 
 dataset_info = [
     {'name': 'Common Sense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
@@ -12,29 +15,34 @@ dataset_info = [
 
 model_info = {
     'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', token=os.environ['hf_token'],
-                      original_prompt_template='<s>{prompt}',
-                      interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
-                      ),  # , load_in_8bit=True
-
-    # 'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
-    #                  original_prompt_template='<bos>{prompt}',
-    #                  interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
-    #                  ),
+                      original_prompt_template='<s>{prompt}',
+                      interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
+                      layers_format=llama_layers_format),  # , load_in_8bit=True
     'GPT-2 Small': dict(model_path='gpt2', original_prompt_template='{prompt}',
-                        interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}'),
+                        interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
+                        layers_format=gpt_layers_format),
     'GPT-2 Medium': dict(model_path='gpt2-medium', original_prompt_template='{prompt}',
-                         interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}'),
+                         interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
+                         layers_format=gpt_layers_format),
     'GPT-2 Large': dict(model_path='gpt2-large', original_prompt_template='{prompt}',
-                        interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}'),
+                        interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
+                        layers_format=gpt_layers_format),
     'GPT-2 XL': dict(model_path='gpt2-xl', original_prompt_template='{prompt}',
-                     interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}'),
+                     interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
+                     layers_format=gpt_layers_format),
     'GPT-J 6B': dict(model_path='EleutherAI/gpt-j-6b', original_prompt_template='{prompt}',
-                     interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}'),
+                     interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
+                     layers_format=gpt_layers_format),
     'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
-                                original_prompt_template='<s>{prompt}',
-                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
-                                ),
-
+                                original_prompt_template='<s>{prompt}',
+                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
+                                layers_format=llama_layers_format),
+
+    # 'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
+    #                  original_prompt_template='<bos>{prompt}',
+    #                  interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
+    #                  ),
+
     # 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
     #                                                tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
     #                                                model_type='llama', hf=True, ctransformers=True,
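
The main change in this commit is threading a layers_format string through each model config: 'model.layers.{k}' for LLaMA-family models (Mistral included) and 'transformer.h.{k}' for GPT-2/GPT-J, matching how transformers names the k-th decoder block in each architecture. A minimal sketch of resolving such a path to a module with PyTorch's nn.Module.get_submodule (the surrounding code is illustrative, not from this repo):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('gpt2')
    gpt_layers_format = 'transformer.h.{k}'

    # get_submodule accepts a dotted attribute path, so formatting in the
    # layer index yields the k-th transformer block directly.
    block = model.get_submodule(gpt_layers_format.format(k=3))
    print(type(block).__name__)  # GPT2Block

Keeping the per-architecture path in the config lets downstream hook or patching code stay architecture-agnostic.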
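In the interpretation templates, [X] marks the splice point for the material being interpreted, while {prompt} is filled later via str.format. A hedged sketch of separating the two kinds of placeholder (variable names are ours, not the repo's):

    template = '<s>[INST] [X] [/INST] {prompt}'

    # Split around the [X] marker so the prefix and suffix can be handled
    # (e.g. tokenized) separately from the formatted prompt text.
    prefix, suffix = template.split('[X]')
    assert prefix == '<s>[INST] ' and suffix == ' [/INST] {prompt}'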
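For the dataset entries, hf_repo and text_col are enough to pull example texts with the datasets library; the loading code below is an assumed usage pattern, not taken from this repo:

    from datasets import load_dataset

    entry = {'name': 'Common Sense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'}

    # Load the repo named in the entry and read its configured text column.
    ds = load_dataset(entry['hf_repo'], split='validation')
    print(ds[entry['text_col']][0])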