Kqte committed
Commit
1906024
1 Parent(s): a4f7976

Upload 4 files

LLaMIPa3/adapter_config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "/tmpdir/thompson/Meta-Llama-3-8B/",
+   "bias": "none",
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 16,
+   "lora_dropout": 0.1,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 64,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "q_proj",
+     "v_proj"
+   ],
+   "task_type": "CAUSAL_LM"
+ }
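
For reference, a minimal sketch of attaching this adapter to its base model with peft. The base-model path below is a placeholder (the path recorded in the config is cluster-specific), and "LLaMIPa3" refers to the adapter directory uploaded in this commit; peft and transformers are assumed to be installed.

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the Llama-3-8B base model (placeholder path), then wrap it with the
# LoRA adapter from this repo. r=64, lora_alpha=16, targeting q_proj/v_proj,
# per the adapter_config.json above.
base = AutoModelForCausalLM.from_pretrained(
    "/path/to/Meta-Llama-3-8B/",   # placeholder: point at your local base model
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "LLaMIPa3")
model.eval()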
LLaMIPa3/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:956bbad58e3c7d8101baa809d11f3aa025ddf5d36925d0c78eea43d656ad2b37
+ size 109069176
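
This file is a Git LFS pointer, so the actual weights are fetched at checkout. A quick sketch for confirming a downloaded copy matches the oid and size recorded in the pointer above:

import hashlib, os

# Hash the downloaded safetensors file in 1 MiB chunks and compare against
# the pointer's sha256 oid and byte size.
path = "LLaMIPa3/adapter_model.safetensors"
h = hashlib.sha256()
with open(path, "rb") as fh:
    for chunk in iter(lambda: fh.read(1 << 20), b""):
        h.update(chunk)
print(h.hexdigest() == "956bbad58e3c7d8101baa809d11f3aa025ddf5d36925d0c78eea43d656ad2b37")
print(os.path.getsize(path) == 109069176)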
data/parser_test_moves_15.jsonl ADDED
The diff for this file is too large to render. See raw diff
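
Although the test set is not rendered inline, parser_generate.py below shows how it is consumed: each jsonl record is expected to carry a "sample" field. A quick sketch for peeking at the first record without loading the model pipeline:

import json

# Stream just the first line of the test set; the "sample" field name
# matches what parser_generate.py iterates over.
with open("data/parser_test_moves_15.jsonl") as fh:
    first = json.loads(next(fh))
print(first["sample"][:300])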
 
parser_generate.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import json
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ device_map = "auto"
+ model = AutoModelForCausalLM.from_pretrained(
+     "/path/to/meta-llama3-8b/",
+     return_dict=True,
+     torch_dtype=torch.float16,
+     device_map=device_map)
+
+
+ tokenizer = AutoTokenizer.from_pretrained("/path/to/meta-llama3-8b/", add_eos_token=True)
+
+ tokenizer.pad_token_id = tokenizer.eos_token_id + 1  # pad with the id just past eos so padding stays distinct from eos
+ tokenizer.padding_side = "right"
+
+ pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, pad_token_id=tokenizer.pad_token_id, max_new_tokens=100)
+
+ test_dataset = load_dataset("json", data_files={'test': '/path/to/parser_test_moves_15.jsonl'})["test"]
+
+ def is_first_moves(sample):
+     answer = 0
+     slist = sample.split('\n')
+     if slist[0].startswith('Context: 0 <Buil> Mission has started.'):
+         struct = [i for i in slist if i.startswith('Structure:')]
+         rels = struct[0].split(':')[1].strip()
+         if len(rels) == 0:
+             answer = 1
+     return answer
+
+
+ def check_endpoints(struct, head):
+     """
+     takes a struct string and a head int and returns only
+     the struct rels with sources that are >= head
+     """
+     new_rels_list = []
+     new_rels = None
+     if struct:
+         rels = struct.split(' ')
+         for rel in rels:
+             if len(rel) > 0:
+                 source = int(rel.split('(')[1].split(',')[0].strip())
+                 if source >= head:
+                     new_rels_list.append(rel)
+     if len(new_rels_list) > 0:
+         new_rels = ' '.join(new_rels_list)
+     return new_rels
+
+ def add_previous(sample, previous, predictions):
+     new_output = []
+     keep_str = None
+     # get head
+     slist = sample.split('\n')
+     head = int(slist[0].split('Context:')[1].split('<')[0].strip())
+     # check current structure
+     for s in slist:
+         if s.startswith('Structure:'):
+             new_structure = check_endpoints(previous, head)
+             if new_structure:
+                 s = 'Structure: ' + new_structure + ' ' + predictions
+                 keep_str = new_structure + ' ' + predictions
+             else:
+                 s = 'Structure: ' + predictions
+                 keep_str = predictions
+         new_output.append(s)
+     new_output_string = '\n'.join(new_output)
+     return keep_str, new_output_string
+
+ def format_gen(preds):
+     labels = ['COM','CONTR','CORR','QAP','ACK','ELAB','CLARIFQ','COND','CONTIN',
+               'RES','EXPL','QELAB','ALT','NARR','CONFQ','SEQ']
+     split_list = [st.strip() for st in preds.split(' ')]
+     clean_list = []
+     for a in split_list:
+         s_tuple = None
+         rel = None
+         try:
+             s = a.split('(')[1].split(')')[0].split(',')
+             r = a.split('(')[0].strip()
+         except IndexError:
+             print('split error one')
+         else:
+             try:
+                 s_tuple = (int(s[0]), int(s[1]))
+             except IndexError:
+                 print('split error two')
+             except ValueError:
+                 print('value error three')
+             if r in labels:
+                 # make sure the label is well-formed
+                 rel = r
+         if rel is not None and s_tuple is not None:
+             clean_list.append(rel + '(' + str(s_tuple[0]) + ',' + str(s_tuple[1]) + ')')
+     clean_preds = ' '.join(clean_list)
+     return clean_preds
+
+
+ def formatting_prompts_func(example):
+     output_text = '<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n' + example + '\n ### DS:'
+     return output_text
+
+
+ f = open("/path/to/val-output-file.txt", "w")
+
+ new_generations = None
+ previous_generations = None
+ for datum in tqdm(test_dataset['sample']):
+
+     # figure out if it's a first example
+     if is_first_moves(datum):
+         text = formatting_prompts_func(datum)
+         previous_generations = None
+     else:
+         # need to make sure head edu and relations match up
+         update_prev, amended_text = add_previous(datum, previous_generations, new_generations)
+         previous_generations = update_prev
+         text = formatting_prompts_func(amended_text)
+     generated = pipe(text)[0]['generated_text']
+     print(generated, file=f)
+     new_generations = format_gen(generated.split('### DS:')[1])
+
+ f.close()
+
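
For context, the record layout that parser_generate.py expects can be reconstructed from its parsing code: the first line of each "sample" is "Context: <head> <speaker> ...", a "Structure:" line holds zero or more relations of the form LABEL(source,target) such as ELAB(0,1), and is_first_moves treats an empty Structure line under "Context: 0 <Buil> Mission has started." as the start of a game. The toy record below follows that layout; the dialogue text itself is invented for illustration and is not a line from the actual test file.

import json

# Hypothetical record shape, reverse-engineered from is_first_moves and
# add_previous above; an empty "Structure:" marks a game's first moves.
toy = {
    "sample": (
        "Context: 0 <Buil> Mission has started.\n"
        "1 <Arch> place a red block\n"
        "Structure: \n"
        "2 <Buil> ok"
    )
}
print(json.dumps(toy))

On later turns the script rolls its own predictions forward: check_endpoints keeps only previously predicted relations whose source index is still at or beyond the new context head, and add_previous splices those, plus the latest cleaned predictions, back into the Structure line before prompting again.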