Trace2333 committed on
Commit
8311920
1 Parent(s): 692a701

add gpt data generation and analysis

Browse files
Files changed (9) hide show
  1. build_openprompt.py +19 -13
  2. central_finetuning.py +0 -0
  3. corenlp_openie.py +104 -0
  4. generation_test.py +101 -0
  5. gpt2_generation.py +32 -7
  6. gpt_api.py +27 -0
  7. monitor.sh +15 -0
  8. sft.py +4 -6
  9. trible.py +56 -0
build_openprompt.py CHANGED
@@ -3,6 +3,9 @@ import pandas as pd
3
  import json
4
  import random
5
 
 
 
 
6
  from tqdm import tqdm
7
 
8
 
@@ -12,31 +15,34 @@ samples = {
12
  }
13
  little = False
14
  all_loaded_sample = 400000
15
-
16
  s_pro = all_loaded_sample / 1e+7
17
  # 读取概率
18
- with open("./data/prompts.csv") as f:
19
  csv_reader = csv.DictReader(f)
20
  process_reader = tqdm(enumerate(csv_reader))
21
  for row_number, row in process_reader:
22
  num_samples = len(samples['x'])
23
  process_reader.set_description(f"got data num: {num_samples}")
24
- if random.uniform(0, 1) > s_pro:
25
- continue
 
 
 
 
 
 
26
  if little:
27
  if len(samples["x"]) > 100:
28
  break
29
- if len(samples["x"]) > all_loaded_sample:
30
- break
31
 
32
  datum = row
33
- prompt = datum['prompt']
34
- modifiers = json.loads(datum['raw_data'])['modifiers']
35
- if len(modifiers) < 4:
36
- continue
37
-
38
- # TODO: 外挂一个entity识别,过滤掉存在entity实体的数据
39
-
40
  label = prompt
41
  x = prompt
42
  # 小文本到大文本,因此x更小,同时x按照6:3:1的比例分配
 
3
  import json
4
  import random
5
 
6
+ from torch.nn.utils.rnn import pad_sequence
7
+
8
+
9
  from tqdm import tqdm
10
 
11
 
 
15
  }
16
  little = False
17
  all_loaded_sample = 400000
18
+ normal = True # 全部读取,非采样方式
19
  s_pro = all_loaded_sample / 1e+7
20
  # 读取概率
21
+ with open("./data/cleaned_oie_prompts.csv") as f:
22
  csv_reader = csv.DictReader(f)
23
  process_reader = tqdm(enumerate(csv_reader))
24
  for row_number, row in process_reader:
25
  num_samples = len(samples['x'])
26
  process_reader.set_description(f"got data num: {num_samples}")
27
+ if not normal:
28
+ if random.uniform(0, 1) > s_pro:
29
+ continue
30
+ if len(samples["x"]) > all_loaded_sample:
31
+ break
32
+ else:
33
+ if row['prompt'] == "":
34
+ continue
35
  if little:
36
  if len(samples["x"]) > 100:
37
  break
 
 
38
 
39
  datum = row
40
+ # prompt = datum['prompt']
41
+ prompt = ",".join(eval(datum['raw_data'])['modifiers'])
42
+ if not normal:
43
+ modifiers = eval(datum['raw_data'])['modifiers']
44
+ if len(modifiers) < 4:
45
+ continue
 
46
  label = prompt
47
  x = prompt
48
  # 小文本到大文本,因此x更小,同时x按照6:3:1的比例分配
central_finetuning.py ADDED
File without changes
corenlp_openie.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import csv
4
+ import json
5
+ import jsonlines
6
+ from tqdm import tqdm
7
+ from stanfordcorenlp import StanfordCoreNLP
8
+
9
+ import concurrent.futures
10
+
11
+
12
# CoreNLP wrapper created at import time from the local
# ./stanford-corenlp-4.5.5 directory; oie_extract() sends every sentence to it.
nlp = StanfordCoreNLP('./stanford-corenlp-4.5.5')

# JSON Lines input whose 'description' fields are turned into OpenIE prompts.
SOURCE_FILE = "./data/raw_oie_source.jsonl"
15
+
16
def oie_extract(sentence):
    """Condense ``sentence`` into a comma-separated OpenIE prompt.

    The text is annotated by the module-level CoreNLP client. For every
    parsed sentence that produced at least one OpenIE triple, the longest
    ``object`` string is kept. The ``subject`` of the first triple is kept
    too, except for the very first sentence, whose subject is dropped.

    Args:
        sentence: Raw text to annotate (may contain several sentences).

    Returns:
        Subjects followed by objects, joined with ``","``. An empty string
        when the annotator output cannot be parsed or lacks the expected
        fields.
    """
    output = nlp.annotate(sentence, properties={
        'annotators': 'tokenize, ssplit, pos, depparse, parse, openie',
        'outputFormat': 'json'
    })
    try:
        data = json.loads(output)
        # Keep only sentences for which OpenIE found at least one triple.
        sentences_ie = [i['openie'] for i in data['sentences'] if len(i['openie']) > 0]
        oie_result = [max((sub["object"] for sub in sen), key=len) for sen in sentences_ie]
        central_result = [sen[0]["subject"] for sen in sentences_ie][1:]
    except (ValueError, KeyError, TypeError, IndexError) as e:
        # Best effort: a malformed response yields an empty prompt, but the
        # reason is reported instead of being silently discarded.
        print(f"An error occurred ({e!r}) output: {output}")
        return ""
    return ",".join(central_result + oie_result)
33
+
34
def process_sentence(sentence):
    """Build one CSV row for ``sentence``.

    Returns:
        A dict with ``'raw_data'`` (the sentence split on ``"."`` under the
        ``'modifiers'`` key) and ``'prompt'`` (the result of
        ``oie_extract(sentence)``).
    """
    return {
        'raw_data': {'modifiers': sentence.split(".")},
        'prompt': oie_extract(sentence),
    }
39
+
40
def get_sentences(path):
    """Yield the ``description`` field of every record in a JSON Lines file.

    Args:
        path: Location of the ``.jsonl`` file to read.

    Yields:
        The value stored under ``'description'`` for each record.

    Raises:
        FileNotFoundError: If ``path`` does not exist. Because this is a
            generator, the error surfaces on the first iteration.
    """
    # Validate the file that is actually about to be opened; the check used
    # to look at the module-level SOURCE_FILE regardless of ``path``.
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not found.")

    with jsonlines.open(path) as reader:
        for obj in reader:
            yield obj['description']
47
+
48
def main():
    """Extract an OpenIE prompt for every source sentence and save them as CSV.

    Sentences are read from ``SOURCE_FILE``, processed concurrently with
    ``process_sentence`` on a thread pool, and written to
    ``./data/oie_prompts.csv`` with the columns ``prompt`` and ``raw_data``.
    The output file is only opened once all rows have been computed, so a
    failure during extraction does not leave a truncated file behind.
    """
    file_name = "./data/oie_prompts.csv"
    fieldnames = ['prompt', 'raw_data']

    # Read the source once; the list supplies both the work items and the
    # progress-bar total (previously the file was parsed twice).
    sentences = list(get_sentences(SOURCE_FILE))

    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(process_sentence, sentences),
                            total=len(sentences),
                            desc="extracting oie prompts"))

    # The context manager guarantees the CSV is flushed and closed.
    with open(file_name, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
68
+
69
def remove_chinese(text):
    """Return ``text`` with every character in U+4E00..U+9FA5 deleted."""
    return re.sub(r'[\u4e00-\u9fa5]', '', text)
73
+
74
+
75
def remove_special_chars(text):
    """Return ``text`` keeping only word characters, whitespace, ``.`` and ``,``."""
    return re.sub(r'[^\w\s.,]', '', text)
79
+
80
def cleaning_dataset():
    """Clean ``oie_prompts.csv`` and save the result as ``cleaned_oie_prompts.csv``.

    For every input row, the ``prompt`` column and the joined ``modifiers``
    list stored in ``raw_data`` are stripped of Chinese characters and of
    special characters. The cleaned modifiers are split on ``","`` again
    before the row is written out.
    """
    import ast  # local import: only this function needs it

    source_name = "./data/oie_prompts.csv"
    file_name = "./data/cleaned_oie_prompts.csv"
    fieldnames = ['prompt', 'raw_data']

    # Both handles are closed by the context manager; the input is read with
    # the same encoding the extraction step wrote it with.
    with open(source_name, newline='', encoding='utf-8') as f, \
            open(file_name, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for datum in tqdm(csv.DictReader(f)):
            cleaned_prompts = remove_special_chars(remove_chinese(datum['prompt']))
            # raw_data holds the repr of a dict; literal_eval parses it
            # without executing arbitrary code the way eval() would.
            raw_data = ast.literal_eval(datum['raw_data'])
            joined_modifiers = ",".join(raw_data['modifiers'])
            cleaned_modifiers = remove_special_chars(remove_chinese(joined_modifiers))
            row_data = {'raw_data': {'modifiers': cleaned_modifiers.split(",")}, "prompt": cleaned_prompts}
            writer.writerow(row_data)
98
+
99
+
100
# Script entry point: only the cleaning pass runs; the extraction step
# (main) is currently disabled.
if __name__ == '__main__':
    # main()
    cleaning_dataset()
103
+
104
+
generation_test.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import spacy
3
+ from accelerate import PartialState
4
+ from accelerate.utils import set_seed
5
+
6
+ from gpt2_generation import Translator
7
+ from gpt2_generation import generate_prompt, MODEL_CLASSES
8
+
9
# Route HTTP and HTTPS traffic of this process through a proxy on localhost:7890.
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"


# Fine-tuned checkpoint that is loaded for interactive generation.
path_for_model = "./output/gpt2_openprompt/checkpoint-4500"

# Generation settings, passed as a plain dict to generate_prompt() and read by
# load_model_and_components().
args = {
    "model_type": "gpt2",
    "model_name_or_path": path_for_model,
    "length": 80,
    "length_penalty": 1.2,
    "stop_token": None,
    "temperature": 1.0,
    "repetition_penalty": 1.2,
    "k": 3,  # forwarded as top_k by generate_prompt()
    "p": 0.9,  # presumably top_p -- TODO confirm in generate_prompt()
    "prefix": "",
    "padding_text": "",
    "xlm_language": "",
    "seed": 42,
    "use_cpu": False,
    "num_return_sequences": 4,
    "fp16": False,
    "jit": False,
}

# Device/process state from accelerate; its .device is where the model is moved.
distributed_state = PartialState(cpu=args["use_cpu"])

if args["seed"] is not None:
    set_seed(args["seed"])

# Populated by load_model_and_components().
tokenizer = None
model = None
zh_en_translator = None
nlp = None
44
+
45
def load_model_and_components():
    """Load the tokenizer, model, translator and spaCy filter into module globals.

    Reads the configuration from the module-level ``args`` dict, normalises
    ``args["model_type"]`` to lower case, moves the model to
    ``distributed_state.device`` and switches it to half precision when
    ``args["fp16"]`` is set.

    Raises:
        KeyError: If ``args["model_type"]`` is not a key of ``MODEL_CLASSES``.
    """
    global tokenizer, model, zh_en_translator, nlp

    # Initialize the model and tokenizer
    model_type = args["model_type"].lower()
    args["model_type"] = model_type
    try:
        model_class, tokenizer_class = MODEL_CLASSES[model_type]
    except KeyError:
        # The message previously contained an unformatted "{}" placeholder.
        raise KeyError(
            f"the model {model_type} you specified is not supported. "
            "You are welcome to add it and open a PR :)"
        ) from None

    tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], padding_side='left')
    # The EOS token is reused as both the pad token and the mask token.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.mask_token = tokenizer.eos_token
    model = model_class.from_pretrained(args["model_name_or_path"])
    print("Model loaded!")

    # translator
    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
    print("Translator loaded!")

    # filter
    nlp = spacy.load('en_core_web_sm')
    print("Filter loaded!")

    # Set the model to the right device
    model.to(distributed_state.device)

    if args["fp16"]:
        model.half()
74
+
75
def chat():
    """Run one interactive round: read a prompt, generate, print the results.

    The model components are loaded on demand if any of them is still
    ``None`` when the prompt has been entered.
    """
    user_prompt = input("Input Prompt >>")

    components = (tokenizer, model, zh_en_translator, nlp)
    if any(component is None for component in components):
        load_model_and_components()

    outputs = generate_prompt(
        prompt_text=user_prompt,
        args=args,
        zh_en_translator=zh_en_translator,
        nlp=nlp,
        model=model,
        tokenizer=tokenizer,
        distributed_state=distributed_state,
    )

    for index, sequence in enumerate(outputs):
        print(f"-----generated sequence {index} -----")
        print(sequence)
    # NOTE(review): the original indentation was lost; the separator is
    # assumed to be printed once per round -- confirm against the real file.
    print("*"*60)
95
+
96
+
97
+
98
# Script entry point: load everything once, then serve prompts until the
# process is interrupted.
if __name__ == '__main__':
    load_model_and_components()
    while True:
        chat()
gpt2_generation.py CHANGED
@@ -2,6 +2,7 @@
2
  # coding=utf-8
3
  import inspect
4
  import logging
 
5
  from typing import Tuple
6
 
7
  import torch
@@ -261,6 +262,26 @@ class _ModelFallbackWrapper(GenerationMixin):
261
  return self._default._reorder_cache(past_key_values, beam_idx)
262
 
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def generate_prompt(
265
  prompt_text,
266
  args,
@@ -326,6 +347,7 @@ def generate_prompt(
326
  repeat_gen_time = repeat_gen_time + 1
327
  generated_sequence = model.generate(
328
  input_ids=input_ids,
 
329
  max_length=args["length"] + len(encoded_prompt[0]),
330
  temperature=args["temperature"],
331
  top_k=args["k"],
@@ -352,13 +374,16 @@ def generate_prompt(
352
  prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
353
  )
354
  # no checking for prompt_text.
355
- docs = nlp(text)
356
- nouns = [token.text for token in docs if token.pos_ == 'NOUN']
357
- nouns = set(nouns)
358
- if nouns.intersection(FORBIDDEN_NOUN) and repeat_gen_time < 10:
359
- continue
360
- else:
361
- break
 
 
 
362
  generated_sequences.append(total_sequence)
363
 
364
  return generated_sequences
 
2
  # coding=utf-8
3
  import inspect
4
  import logging
5
+ import nltk
6
  from typing import Tuple
7
 
8
  import torch
 
262
  return self._default._reorder_cache(past_key_values, beam_idx)
263
 
264
 
265
def remove_tokens_before_copula(text):
    """Trim each comma-separated clause down to what follows its last copula.

    The text is split on ``","``. The first clause is kept verbatim. Every
    later clause is tokenized with ``nltk.word_tokenize``; if it contains
    "is", "are" or "am" (case-insensitive), only the tokens after the last
    such word are kept. The tokens of each later clause are re-joined with
    single spaces, and all clauses are joined with ``","``.
    """
    head, *clauses = text.split(",")
    copulas = ["is", "are", "am"]
    rebuilt = [head]
    for clause in clauses:
        words = nltk.word_tokenize(clause)
        # Index just past the last copula; 0 keeps the whole clause.
        cut = 0
        for position, word in enumerate(words):
            if word.lower() in copulas:
                cut = position + 1
        rebuilt.append(" ".join(words[cut:]))
    return ",".join(rebuilt)
283
+
284
+
285
  def generate_prompt(
286
  prompt_text,
287
  args,
 
347
  repeat_gen_time = repeat_gen_time + 1
348
  generated_sequence = model.generate(
349
  input_ids=input_ids,
350
+ length_penalty=args["length_penalty"],
351
  max_length=args["length"] + len(encoded_prompt[0]),
352
  temperature=args["temperature"],
353
  top_k=args["k"],
 
374
  prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
375
  )
376
  # no checking for prompt_text.
377
+ # 暂时删去关键词检测
378
+ # docs = nlp(text)
379
+ # nouns = [token.text for token in docs if token.pos_ == 'NOUN']
380
+ # nouns = set(nouns)
381
+ # if nouns.intersection(FORBIDDEN_NOUN) and repeat_gen_time < 10:
382
+ # continue
383
+ # else:
384
+ # break
385
+ break
386
+ total_sequence = remove_tokens_before_copula(total_sequence)
387
  generated_sequences.append(total_sequence)
388
 
389
  return generated_sequences
gpt_api.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+
3
+
4
def get_response_create_data(cn_text):
    """Ask the Azure OpenAI deployment for a short interior-design description.

    Args:
        cn_text: Keywords sent as the user message.

    Returns:
        The content of the first choice's message in the response.

    Raises:
        RuntimeError: If the ``AZURE_OPENAI_API_KEY`` environment variable is
            not set.
    """
    import os  # local import: the module's import block does not provide it

    # The key used to be hard-coded in this file. Secrets must not live in
    # source control, so it is read from the environment instead.
    api_key = os.environ.get("AZURE_OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError(
            "AZURE_OPENAI_API_KEY is not set; export the Azure OpenAI key "
            "before calling get_response_create_data()."
        )

    openai.api_type = "azure"
    openai.api_base = "https://poster-pku-gpt4.openai.azure.com/"
    openai.api_version = "2023-07-01-preview"
    openai.api_key = api_key

    # NOTE(review): "no more than words" is missing its number -- confirm the
    # intended limit. The backslash continuation of the original literal
    # embedded source indentation in the prompt; it is now a single space.
    system_prompt = (
        "Now you are a home improvement designer, "
        "I give you some keywords, generate a brief interior design in English, no more than words: "
    )

    response = openai.ChatCompletion.create(
        engine="dragon",
        messages=[{"role": "system", "content": system_prompt},
                  {"role": "user", "content": cn_text}],
        temperature=0.7,
        max_tokens=800,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None)
    return response['choices'][0]["message"]["content"]
22
+
23
+
24
# Manual smoke test: read keywords from stdin and show the model's reply.
if __name__ == '__main__':
    while True:
        input_text = input("输入:")
        # The reply used to be computed and thrown away; print it.
        print(get_response_create_data(input_text))
monitor.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Supervisor loop: run trible.py repeatedly until it exits with status 0.

while true; do

    # Wall-clock seconds, passed to trible.py as its seed argument.
    seed=$(date +%s)

    # Test the command directly instead of inspecting $? afterwards.
    if python trible.py "${seed}"; then
        echo "program complete, no need to restart..."
        break
    fi

    echo "program crash, restarting"
    # Back off briefly so a persistent failure does not spin in a tight loop
    # and so the next attempt gets a different timestamp seed.
    sleep 1
done
sft.py CHANGED
@@ -14,7 +14,7 @@ from utils import (
14
  get_dict_dataset,
15
  get_advance_dataset,)
16
 
17
- base_model = "distilgpt2"
18
  tokenizer, model = get_tok_and_model(f"./models/{base_model}")
19
  tokenizer.pad_token = tokenizer.eos_token
20
  rouge = evaluate.load("rouge")
@@ -53,18 +53,16 @@ print(f"data tokenize done. process time : {t2 - t1}")
53
 
54
 
55
  training_args = TrainingArguments(
56
- output_dir=f"./output/{base_model}_openprpmpt",
57
  evaluation_strategy="steps",
58
  eval_steps=20000,
59
- learning_rate=2e-5,
60
  lr_scheduler_type="constant",
61
  report_to="tensorboard",
62
  per_device_train_batch_size=64,
63
  per_device_eval_batch_size=32,
64
- adam_beta1=0.9,
65
- adam_beta2=0.98,
66
  save_total_limit=1,
67
- num_train_epochs=80,
68
  fp16=True,
69
  push_to_hub=False,
70
  )
 
14
  get_dict_dataset,
15
  get_advance_dataset,)
16
 
17
+ base_model = "gpt2"
18
  tokenizer, model = get_tok_and_model(f"./models/{base_model}")
19
  tokenizer.pad_token = tokenizer.eos_token
20
  rouge = evaluate.load("rouge")
 
53
 
54
 
55
  training_args = TrainingArguments(
56
+ output_dir=f"./output/{base_model}_openprompt",
57
  evaluation_strategy="steps",
58
  eval_steps=20000,
59
+ learning_rate=3e-5,
60
  lr_scheduler_type="constant",
61
  report_to="tensorboard",
62
  per_device_train_batch_size=64,
63
  per_device_eval_batch_size=32,
 
 
64
  save_total_limit=1,
65
+ num_train_epochs=60,
66
  fp16=True,
67
  push_to_hub=False,
68
  )
trible.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import click
3
+ import random
4
+ import jsonlines
5
+
6
+ from tqdm import tqdm
7
+ from gpt_api import get_response_create_data
8
+
9
+
10
# Keyword list read by read_keywords() (first tab-separated field of each line).
KEYWORDS_PATH = "/data/aigc/zw/task2/pg_distilgpt/data/raw_keywords.txt"
# JSON Lines file that create_data() appends generated records to.
TARGET_PATH = "/data/aigc/zw/task2/pg_distilgpt/data/raw_discriptions.jsonl"

# Create an empty target file at import time if it does not exist yet.
if not os.path.exists(TARGET_PATH):
    with open(TARGET_PATH, "w") as f:
        pass
16
+
17
+
18
def read_keywords(path=KEYWORDS_PATH):
    """Return the first tab-separated field of every line in ``path``.

    Each line is stripped of surrounding whitespace before it is split on
    tab characters. Progress is reported through tqdm.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return [
            line.strip().split('\t')[0]
            for line in tqdm(file, desc="reading keywords")
        ]
29
+
30
def keywords_sampler(num, key_words):
    """Yield an endless stream of random keyword samples.

    Args:
        num: Number of keywords per sample.
        key_words: Population to draw from without replacement.

    Yields:
        A list of ``num`` distinct items chosen with ``random.sample``.
    """
    # No random.seed() here: re-seeding from OS entropy when the generator
    # started silently discarded the seed the caller had just set, which
    # made the ``seed`` argument of create_data() ineffective.
    while True:
        yield random.sample(key_words, num)
35
+
36
def create_data(keywords, total_num=10000, n=4, seed=42):
    """Generate descriptions for random keyword groups and append them to TARGET_PATH.

    Args:
        keywords: Pool of keywords to sample from.
        total_num: Number of records to generate.
        n: Number of keywords per record.
        seed: Value passed to ``random.seed`` before sampling starts.
    """
    random.seed(seed)
    sampler = keywords_sampler(n, keywords)

    # Iterating over range(total_num) produces exactly total_num records; the
    # previous ">=" check ran after the write and emitted one extra. It also
    # keeps the loop counter from shadowing the ``n`` parameter.
    for _ in tqdm(range(total_num), desc="generating data"):
        key_words = next(sampler)

        res = get_response_create_data(" ".join(key_words))

        # Re-open in append mode for every record so that everything
        # generated so far is on disk if a later request fails.
        # NOTE(review): the "keywrods" key is misspelled but is kept so new
        # records match those already in the file.
        with jsonlines.open(TARGET_PATH, mode='a') as writer:
            writer.write({"keywrods": key_words, "description": res})

    print("generation data done.")
48
+
49
@click.command()
@click.argument('seed', type=int)
def main(seed):
    """Command-line entry: load the keyword pool and generate data using SEED."""
    create_data(read_keywords(), seed=seed)
54
+
55
# Script entry point; click parses the seed argument from the command line.
if __name__ == '__main__':
    main()