Tuchuanhuhuhu commited on
Commit
eb87ba4
1 Parent(s): 5e8fd04

feat: 训练支持预估 token 消耗

Browse files
Files changed (2) hide show
  1. modules/train_func.py +37 -26
  2. modules/utils.py +3 -2
modules/train_func.py CHANGED
@@ -8,7 +8,7 @@ import ujson as json
8
  import commentjson
9
 
10
  import modules.presets as presets
11
- from modules.utils import get_file_hash
12
  from modules.presets import i18n
13
 
14
  def excel_to_jsonl(filepath, preview=False):
@@ -20,7 +20,27 @@ def excel_to_jsonl(filepath, preview=False):
20
  jsonl.append(row[1].to_dict())
21
  if preview:
22
  break
23
- return jsonl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def jsonl_save_to_disk(jsonl, filepath):
26
  file_hash = get_file_hash(file_paths = [filepath])
@@ -30,15 +50,27 @@ def jsonl_save_to_disk(jsonl, filepath):
30
  f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in jsonl]))
31
  return save_path
32
 
 
 
 
 
 
 
 
 
 
 
33
  def handle_dataset_selection(file_src):
34
  logging.info(f"Loading dataset {file_src.name}...")
35
  preview = ""
36
  if file_src.name.endswith(".jsonl"):
37
  with open(file_src.name, "r") as f:
38
- preview = f.readline()
39
  else:
40
- preview = excel_to_jsonl(file_src.name)[0]
41
- return preview, gr.update(interactive=True), "预估数据集 token 数量: 这个功能还没实现"
 
 
42
 
43
  def upload_to_openai(file_src):
44
  openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -47,27 +79,6 @@ def upload_to_openai(file_src):
47
  logging.info(f"Uploading dataset {dspath}...")
48
  if dspath.endswith(".xlsx"):
49
  jsonl = excel_to_jsonl(dspath)
50
- tmp_jsonl = []
51
- for i in jsonl:
52
- if "提问" in i and "答案" in i:
53
- if "系统" in i :
54
- tmp_jsonl.append({
55
- "messages":[
56
- {"role": "system", "content": i["系统"]},
57
- {"role": "user", "content": i["提问"]},
58
- {"role": "assistant", "content": i["答案"]}
59
- ]
60
- })
61
- else:
62
- tmp_jsonl.append({
63
- "messages":[
64
- {"role": "user", "content": i["提问"]},
65
- {"role": "assistant", "content": i["答案"]}
66
- ]
67
- })
68
- else:
69
- logging.warning(f"跳过一行数据,因为没有找到提问和答案: {i}")
70
- jsonl = tmp_jsonl
71
  dspath = jsonl_save_to_disk(jsonl, dspath)
72
  try:
73
  uploaded = openai.File.create(
 
8
  import commentjson
9
 
10
  import modules.presets as presets
11
+ from modules.utils import get_file_hash, count_token
12
  from modules.presets import i18n
13
 
14
  def excel_to_jsonl(filepath, preview=False):
 
20
  jsonl.append(row[1].to_dict())
21
  if preview:
22
  break
23
+ formatted_jsonl = []
24
+ for i in jsonl:
25
+ if "提问" in i and "答案" in i:
26
+ if "系统" in i :
27
+ formatted_jsonl.append({
28
+ "messages":[
29
+ {"role": "system", "content": i["系统"]},
30
+ {"role": "user", "content": i["提问"]},
31
+ {"role": "assistant", "content": i["答案"]}
32
+ ]
33
+ })
34
+ else:
35
+ formatted_jsonl.append({
36
+ "messages":[
37
+ {"role": "user", "content": i["提问"]},
38
+ {"role": "assistant", "content": i["答案"]}
39
+ ]
40
+ })
41
+ else:
42
+ logging.warning(f"跳过一行数据,因为没有找到提问和答案: {i}")
43
+ return formatted_jsonl
44
 
45
  def jsonl_save_to_disk(jsonl, filepath):
46
  file_hash = get_file_hash(file_paths = [filepath])
 
50
  f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in jsonl]))
51
  return save_path
52
 
53
+ def estimate_cost(ds):
54
+ dialogues = []
55
+ for l in ds:
56
+ for m in l["messages"]:
57
+ dialogues.append(m["content"])
58
+ dialogues = "\n".join(dialogues)
59
+ tokens = count_token(dialogues)
60
+ return f"Token 数约为 {tokens},预估每轮(epoch)费用约为 {tokens / 1000 * 0.008} 美元。"
61
+
62
+
63
  def handle_dataset_selection(file_src):
64
  logging.info(f"Loading dataset {file_src.name}...")
65
  preview = ""
66
  if file_src.name.endswith(".jsonl"):
67
  with open(file_src.name, "r") as f:
68
+ ds = [json.loads(l) for l in f.readlines()]
69
  else:
70
+ ds = excel_to_jsonl(file_src.name)
71
+ preview = ds[0]
72
+
73
+ return preview, gr.update(interactive=True), estimate_cost(ds)
74
 
75
  def upload_to_openai(file_src):
76
  openai.api_key = os.getenv("OPENAI_API_KEY")
 
79
  logging.info(f"Uploading dataset {dspath}...")
80
  if dspath.endswith(".xlsx"):
81
  jsonl = excel_to_jsonl(dspath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  dspath = jsonl_save_to_disk(jsonl, dspath)
83
  try:
84
  uploaded = openai.File.create(
modules/utils.py CHANGED
@@ -126,9 +126,10 @@ def dislike(current_model, *args):
126
  return current_model.dislike(*args)
127
 
128
 
129
- def count_token(message):
130
  encoding = tiktoken.get_encoding("cl100k_base")
131
- input_str = f"role: {message['role']}, content: {message['content']}"
 
132
  length = len(encoding.encode(input_str))
133
  return length
134
 
 
126
  return current_model.dislike(*args)
127
 
128
 
129
+ def count_token(input_str):
130
  encoding = tiktoken.get_encoding("cl100k_base")
131
+ if type(input_str) == dict:
132
+ input_str = f"role: {input_str['role']}, content: {input_str['content']}"
133
  length = len(encoding.encode(input_str))
134
  return length
135