Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
•
eb87ba4
1
Parent(s):
5e8fd04
feat: 训练支持预估 token 消耗
Browse files- modules/train_func.py +37 -26
- modules/utils.py +3 -2
modules/train_func.py
CHANGED
@@ -8,7 +8,7 @@ import ujson as json
|
|
8 |
import commentjson
|
9 |
|
10 |
import modules.presets as presets
|
11 |
-
from modules.utils import get_file_hash
|
12 |
from modules.presets import i18n
|
13 |
|
14 |
def excel_to_jsonl(filepath, preview=False):
|
@@ -20,7 +20,27 @@ def excel_to_jsonl(filepath, preview=False):
|
|
20 |
jsonl.append(row[1].to_dict())
|
21 |
if preview:
|
22 |
break
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def jsonl_save_to_disk(jsonl, filepath):
|
26 |
file_hash = get_file_hash(file_paths = [filepath])
|
@@ -30,15 +50,27 @@ def jsonl_save_to_disk(jsonl, filepath):
|
|
30 |
f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in jsonl]))
|
31 |
return save_path
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
def handle_dataset_selection(file_src):
|
34 |
logging.info(f"Loading dataset {file_src.name}...")
|
35 |
preview = ""
|
36 |
if file_src.name.endswith(".jsonl"):
|
37 |
with open(file_src.name, "r") as f:
|
38 |
-
|
39 |
else:
|
40 |
-
|
41 |
-
|
|
|
|
|
42 |
|
43 |
def upload_to_openai(file_src):
|
44 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
@@ -47,27 +79,6 @@ def upload_to_openai(file_src):
|
|
47 |
logging.info(f"Uploading dataset {dspath}...")
|
48 |
if dspath.endswith(".xlsx"):
|
49 |
jsonl = excel_to_jsonl(dspath)
|
50 |
-
tmp_jsonl = []
|
51 |
-
for i in jsonl:
|
52 |
-
if "提问" in i and "答案" in i:
|
53 |
-
if "系统" in i :
|
54 |
-
tmp_jsonl.append({
|
55 |
-
"messages":[
|
56 |
-
{"role": "system", "content": i["系统"]},
|
57 |
-
{"role": "user", "content": i["提问"]},
|
58 |
-
{"role": "assistant", "content": i["答案"]}
|
59 |
-
]
|
60 |
-
})
|
61 |
-
else:
|
62 |
-
tmp_jsonl.append({
|
63 |
-
"messages":[
|
64 |
-
{"role": "user", "content": i["提问"]},
|
65 |
-
{"role": "assistant", "content": i["答案"]}
|
66 |
-
]
|
67 |
-
})
|
68 |
-
else:
|
69 |
-
logging.warning(f"跳过一行数据,因为没有找到提问和答案: {i}")
|
70 |
-
jsonl = tmp_jsonl
|
71 |
dspath = jsonl_save_to_disk(jsonl, dspath)
|
72 |
try:
|
73 |
uploaded = openai.File.create(
|
|
|
8 |
import commentjson
|
9 |
|
10 |
import modules.presets as presets
|
11 |
+
from modules.utils import get_file_hash, count_token
|
12 |
from modules.presets import i18n
|
13 |
|
14 |
def excel_to_jsonl(filepath, preview=False):
|
|
|
20 |
jsonl.append(row[1].to_dict())
|
21 |
if preview:
|
22 |
break
|
23 |
+
formatted_jsonl = []
|
24 |
+
for i in jsonl:
|
25 |
+
if "提问" in i and "答案" in i:
|
26 |
+
if "系统" in i :
|
27 |
+
formatted_jsonl.append({
|
28 |
+
"messages":[
|
29 |
+
{"role": "system", "content": i["系统"]},
|
30 |
+
{"role": "user", "content": i["提问"]},
|
31 |
+
{"role": "assistant", "content": i["答案"]}
|
32 |
+
]
|
33 |
+
})
|
34 |
+
else:
|
35 |
+
formatted_jsonl.append({
|
36 |
+
"messages":[
|
37 |
+
{"role": "user", "content": i["提问"]},
|
38 |
+
{"role": "assistant", "content": i["答案"]}
|
39 |
+
]
|
40 |
+
})
|
41 |
+
else:
|
42 |
+
logging.warning(f"跳过一行数据,因为没有找到提问和答案: {i}")
|
43 |
+
return formatted_jsonl
|
44 |
|
45 |
def jsonl_save_to_disk(jsonl, filepath):
|
46 |
file_hash = get_file_hash(file_paths = [filepath])
|
|
|
50 |
f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in jsonl]))
|
51 |
return save_path
|
52 |
|
53 |
+
def estimate_cost(ds):
|
54 |
+
dialogues = []
|
55 |
+
for l in ds:
|
56 |
+
for m in l["messages"]:
|
57 |
+
dialogues.append(m["content"])
|
58 |
+
dialogues = "\n".join(dialogues)
|
59 |
+
tokens = count_token(dialogues)
|
60 |
+
return f"Token 数约为 {tokens},预估每轮(epoch)费用约为 {tokens / 1000 * 0.008} 美元。"
|
61 |
+
|
62 |
+
|
63 |
def handle_dataset_selection(file_src):
|
64 |
logging.info(f"Loading dataset {file_src.name}...")
|
65 |
preview = ""
|
66 |
if file_src.name.endswith(".jsonl"):
|
67 |
with open(file_src.name, "r") as f:
|
68 |
+
ds = [json.loads(l) for l in f.readlines()]
|
69 |
else:
|
70 |
+
ds = excel_to_jsonl(file_src.name)
|
71 |
+
preview = ds[0]
|
72 |
+
|
73 |
+
return preview, gr.update(interactive=True), estimate_cost(ds)
|
74 |
|
75 |
def upload_to_openai(file_src):
|
76 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
|
|
79 |
logging.info(f"Uploading dataset {dspath}...")
|
80 |
if dspath.endswith(".xlsx"):
|
81 |
jsonl = excel_to_jsonl(dspath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
dspath = jsonl_save_to_disk(jsonl, dspath)
|
83 |
try:
|
84 |
uploaded = openai.File.create(
|
modules/utils.py
CHANGED
@@ -126,9 +126,10 @@ def dislike(current_model, *args):
|
|
126 |
return current_model.dislike(*args)
|
127 |
|
128 |
|
129 |
-
def count_token(
|
130 |
encoding = tiktoken.get_encoding("cl100k_base")
|
131 |
-
input_str
|
|
|
132 |
length = len(encoding.encode(input_str))
|
133 |
return length
|
134 |
|
|
|
126 |
return current_model.dislike(*args)
|
127 |
|
128 |
|
129 |
+
def count_token(input_str):
|
130 |
encoding = tiktoken.get_encoding("cl100k_base")
|
131 |
+
if type(input_str) == dict:
|
132 |
+
input_str = f"role: {input_str['role']}, content: {input_str['content']}"
|
133 |
length = len(encoding.encode(input_str))
|
134 |
return length
|
135 |
|