Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
9813f91
1
Parent(s):
6a49812
feat: 加入GPT 模型微调功能
Browse files- ChuanhuChatbot.py +16 -2
- modules/index_func.py +2 -15
- modules/train_func.py +116 -0
- modules/utils.py +15 -5
- requirements.txt +2 -1
ChuanhuChatbot.py
CHANGED
@@ -5,6 +5,7 @@ logging.basicConfig(
|
|
5 |
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
6 |
)
|
7 |
|
|
|
8 |
import gradio as gr
|
9 |
|
10 |
from modules import config
|
@@ -15,6 +16,7 @@ from modules.overwrites import *
|
|
15 |
from modules.webui import *
|
16 |
from modules.repo import *
|
17 |
from modules.models.models import get_model
|
|
|
18 |
|
19 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
20 |
|
@@ -34,6 +36,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
|
|
34 |
assert type(my_api_key)==str
|
35 |
user_api_key = gr.State(my_api_key)
|
36 |
current_model = gr.State(create_new_model)
|
|
|
37 |
|
38 |
topic = gr.State(i18n("未命名对话历史记录"))
|
39 |
|
@@ -188,14 +191,17 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
|
|
188 |
with gr.Tab(label=i18n("训练")):
|
189 |
with gr.Column(variant="panel"):
|
190 |
dataset_preview_json = gr.JSON(label=i18n("数据集预览"), readonly=True)
|
191 |
-
|
|
|
|
|
192 |
with gr.Column(variant="panel"):
|
|
|
193 |
openai_train_epoch_slider = gr.Slider(label=i18n("训练轮数"), minimum=1, maximum=100, value=3, step=1, interactive=True)
|
194 |
openai_start_train_btn = gr.Button(i18n("开始训练"))
|
195 |
with gr.Column(variant="panel"):
|
196 |
openai_train_status = gr.Markdown(label=i18n("训练状态"), value=i18n("未开始训练"))
|
197 |
openai_status_refresh_btn = gr.Button(i18n("刷新状态"))
|
198 |
-
add_to_models_btn = gr.Button(i18n("
|
199 |
|
200 |
with gr.Tab(label=i18n("高级")):
|
201 |
gr.HTML(get_html("appearance_switcher.html").format(label=i18n("切换亮暗色主题")), elem_classes="insert-block")
|
@@ -485,6 +491,14 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
|
|
485 |
historyFileSelectDropdown.change(**load_history_from_file_args)
|
486 |
downloadFile.change(upload_chat_history, [current_model, downloadFile, user_name], [saveFileName, systemPromptTxt, chatbot])
|
487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
# Advanced
|
489 |
max_context_length_slider.change(set_token_upper_limit, [current_model, max_context_length_slider], None)
|
490 |
temperature_slider.change(set_temperature, [current_model, temperature_slider], None)
|
|
|
5 |
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
6 |
)
|
7 |
|
8 |
+
import colorama
|
9 |
import gradio as gr
|
10 |
|
11 |
from modules import config
|
|
|
16 |
from modules.webui import *
|
17 |
from modules.repo import *
|
18 |
from modules.models.models import get_model
|
19 |
+
from modules.train_func import handle_dataset_selection, handle_dataset_clear, upload_to_openai, start_training, get_training_status, add_to_models
|
20 |
|
21 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
22 |
|
|
|
36 |
assert type(my_api_key)==str
|
37 |
user_api_key = gr.State(my_api_key)
|
38 |
current_model = gr.State(create_new_model)
|
39 |
+
openai_ft_file_id = gr.State("")
|
40 |
|
41 |
topic = gr.State(i18n("未命名对话历史记录"))
|
42 |
|
|
|
191 |
with gr.Tab(label=i18n("训练")):
|
192 |
with gr.Column(variant="panel"):
|
193 |
dataset_preview_json = gr.JSON(label=i18n("数据集预览"), readonly=True)
|
194 |
+
dataset_selection = gr.Files(label = i18n("选择数据集"), file_types=[".xlsx", ".jsonl"], file_count="single")
|
195 |
+
upload_to_openai_btn = gr.Button(i18n("上传到OpenAI"), interactive=False)
|
196 |
+
|
197 |
with gr.Column(variant="panel"):
|
198 |
+
openai_ft_suffix = gr.Textbox(label=i18n("模型名称后缀"), value="", lines=1, placeholder=i18n("可选,用于区分不同的模型"))
|
199 |
openai_train_epoch_slider = gr.Slider(label=i18n("训练轮数"), minimum=1, maximum=100, value=3, step=1, interactive=True)
|
200 |
openai_start_train_btn = gr.Button(i18n("开始训练"))
|
201 |
with gr.Column(variant="panel"):
|
202 |
openai_train_status = gr.Markdown(label=i18n("训练状态"), value=i18n("未开始训练"))
|
203 |
openai_status_refresh_btn = gr.Button(i18n("刷新状态"))
|
204 |
+
add_to_models_btn = gr.Button(i18n("添加训练好的模型到模型列表"), interactive=False)
|
205 |
|
206 |
with gr.Tab(label=i18n("高级")):
|
207 |
gr.HTML(get_html("appearance_switcher.html").format(label=i18n("切换亮暗色主题")), elem_classes="insert-block")
|
|
|
491 |
historyFileSelectDropdown.change(**load_history_from_file_args)
|
492 |
downloadFile.change(upload_chat_history, [current_model, downloadFile, user_name], [saveFileName, systemPromptTxt, chatbot])
|
493 |
|
494 |
+
# Train
|
495 |
+
dataset_selection.upload(handle_dataset_selection, dataset_selection, [dataset_preview_json, upload_to_openai_btn, status_display])
|
496 |
+
dataset_selection.clear(handle_dataset_clear, [], [dataset_preview_json, upload_to_openai_btn])
|
497 |
+
upload_to_openai_btn.click(upload_to_openai, [dataset_selection], [openai_ft_file_id, status_display], show_progress=True)
|
498 |
+
openai_start_train_btn.click(start_training, [openai_ft_file_id, openai_ft_suffix, openai_train_epoch_slider], [openai_train_status])
|
499 |
+
openai_status_refresh_btn.click(get_training_status, [], [openai_train_status, add_to_models_btn])
|
500 |
+
add_to_models_btn.click(add_to_models, [], [model_select_dropdown, status_display], show_progress=True)
|
501 |
+
|
502 |
# Advanced
|
503 |
max_context_length_slider.change(set_token_upper_limit, [current_model, max_context_length_slider], None)
|
504 |
temperature_slider.change(set_temperature, [current_model, temperature_slider], None)
|
modules/index_func.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
-
import
|
5 |
import PyPDF2
|
6 |
from tqdm import tqdm
|
7 |
|
@@ -10,19 +10,6 @@ from modules.utils import *
|
|
10 |
from modules.config import local_embedding
|
11 |
|
12 |
|
13 |
-
def get_index_name(file_src):
|
14 |
-
file_paths = [x.name for x in file_src]
|
15 |
-
file_paths.sort(key=lambda x: os.path.basename(x))
|
16 |
-
|
17 |
-
md5_hash = hashlib.md5()
|
18 |
-
for file_path in file_paths:
|
19 |
-
with open(file_path, "rb") as f:
|
20 |
-
while chunk := f.read(8192):
|
21 |
-
md5_hash.update(chunk)
|
22 |
-
|
23 |
-
return md5_hash.hexdigest()
|
24 |
-
|
25 |
-
|
26 |
def get_documents(file_src):
|
27 |
from langchain.schema import Document
|
28 |
from langchain.text_splitter import TokenTextSplitter
|
@@ -113,7 +100,7 @@ def construct_index(
|
|
113 |
embedding_limit = None if embedding_limit == 0 else embedding_limit
|
114 |
separator = " " if separator == "" else separator
|
115 |
|
116 |
-
index_name =
|
117 |
index_path = f"./index/{index_name}"
|
118 |
if local_embedding:
|
119 |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
|
4 |
+
import hashlib
|
5 |
import PyPDF2
|
6 |
from tqdm import tqdm
|
7 |
|
|
|
10 |
from modules.config import local_embedding
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def get_documents(file_src):
|
14 |
from langchain.schema import Document
|
15 |
from langchain.text_splitter import TokenTextSplitter
|
|
|
100 |
embedding_limit = None if embedding_limit == 0 else embedding_limit
|
101 |
separator = " " if separator == "" else separator
|
102 |
|
103 |
+
index_name = get_file_hash(file_src)
|
104 |
index_path = f"./index/{index_name}"
|
105 |
if local_embedding:
|
106 |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
modules/train_func.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import openai
|
6 |
+
import gradio as gr
|
7 |
+
import ujson as json
|
8 |
+
|
9 |
+
import modules.presets as presets
|
10 |
+
from modules.utils import get_file_hash
|
11 |
+
from modules.presets import i18n
|
12 |
+
|
13 |
+
def excel_to_jsonl(filepath, preview=False):
|
14 |
+
jsonl = []
|
15 |
+
with open(filepath, "rb") as f:
|
16 |
+
import pandas as pd
|
17 |
+
df = pd.read_excel(f)
|
18 |
+
for row in df.iterrows():
|
19 |
+
jsonl.append(row[1].to_dict())
|
20 |
+
if preview:
|
21 |
+
break
|
22 |
+
return jsonl
|
23 |
+
|
24 |
+
def jsonl_save_to_disk(jsonl, filepath):
|
25 |
+
file_hash = get_file_hash(file_paths = [filepath])
|
26 |
+
os.makedirs("files", exist_ok=True)
|
27 |
+
save_path = f"files/{file_hash}.jsonl"
|
28 |
+
with open(save_path, "w") as f:
|
29 |
+
f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in jsonl]))
|
30 |
+
return save_path
|
31 |
+
|
32 |
+
def handle_dataset_selection(file_src):
|
33 |
+
logging.info(f"Loading dataset {file_src.name}...")
|
34 |
+
preview = ""
|
35 |
+
if file_src.name.endswith(".jsonl"):
|
36 |
+
with open(file_src.name, "r") as f:
|
37 |
+
preview = f.readline()
|
38 |
+
else:
|
39 |
+
preview = excel_to_jsonl(file_src.name)[0]
|
40 |
+
return preview, gr.update(interactive=True), "预估数据集 token 数量: 这个功能还没实现"
|
41 |
+
|
42 |
+
def upload_to_openai(file_src):
|
43 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
44 |
+
dspath = file_src.name
|
45 |
+
msg = ""
|
46 |
+
logging.info(f"Uploading dataset {dspath}...")
|
47 |
+
if dspath.endswith(".xlsx"):
|
48 |
+
jsonl = excel_to_jsonl(dspath)
|
49 |
+
tmp_jsonl = []
|
50 |
+
for i in jsonl:
|
51 |
+
if "提问" in i and "答案" in i:
|
52 |
+
if "系统" in i :
|
53 |
+
tmp_jsonl.append({
|
54 |
+
"messages":[
|
55 |
+
{"role": "system", "content": i["系统"]},
|
56 |
+
{"role": "user", "content": i["提问"]},
|
57 |
+
{"role": "assistant", "content": i["答案"]}
|
58 |
+
]
|
59 |
+
})
|
60 |
+
else:
|
61 |
+
tmp_jsonl.append({
|
62 |
+
"messages":[
|
63 |
+
{"role": "user", "content": i["提问"]},
|
64 |
+
{"role": "assistant", "content": i["答案"]}
|
65 |
+
]
|
66 |
+
})
|
67 |
+
else:
|
68 |
+
logging.warning(f"跳过一行数据,因为没有找到提问和答案: {i}")
|
69 |
+
jsonl = tmp_jsonl
|
70 |
+
dspath = jsonl_save_to_disk(jsonl, dspath)
|
71 |
+
try:
|
72 |
+
uploaded = openai.File.create(
|
73 |
+
file=open(dspath, "rb"),
|
74 |
+
purpose='fine-tune'
|
75 |
+
)
|
76 |
+
return uploaded.id, f"上传成功,文件ID: {uploaded.id}"
|
77 |
+
except Exception as e:
|
78 |
+
traceback.print_exc()
|
79 |
+
return "", f"上传失败,原因:{ e }"
|
80 |
+
|
81 |
+
def build_event_description(id, status, trained_tokens, name=i18n("暂时未知")):
|
82 |
+
# convert to markdown
|
83 |
+
return f"""
|
84 |
+
#### 训练任务 {id}
|
85 |
+
|
86 |
+
模型名称:{name}
|
87 |
+
|
88 |
+
状态:{status}
|
89 |
+
|
90 |
+
已经训练了 {trained_tokens} 个token
|
91 |
+
"""
|
92 |
+
|
93 |
+
def start_training(file_id, suffix, epochs):
|
94 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
95 |
+
try:
|
96 |
+
job = openai.FineTuningJob.create(training_file=file_id, model="gpt-3.5-turbo", suffix=suffix, hyperparameters={"n_epochs": epochs})
|
97 |
+
return build_event_description(job.id, job.status, job.trained_tokens)
|
98 |
+
except Exception as e:
|
99 |
+
traceback.print_exc()
|
100 |
+
if "is not ready" in str(e):
|
101 |
+
return "训练出错,因为文件还没准备好。OpenAI 需要一点时间准备文件,过几分钟再来试试。"
|
102 |
+
return f"训练失败,原因:{ e }"
|
103 |
+
|
104 |
+
def get_training_status():
|
105 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
106 |
+
active_jobs = [build_event_description(job["id"], job["status"], job["trained_tokens"], job["fine_tuned_model"]) for job in openai.FineTuningJob.list(limit=10)["data"] if job["status"] != "cancelled"]
|
107 |
+
return "\n\n".join(active_jobs), gr.update(interactive=True) if len(active_jobs) > 0 else gr.update(interactive=False)
|
108 |
+
|
109 |
+
def handle_dataset_clear():
|
110 |
+
return gr.update(value=None), gr.update(interactive=False)
|
111 |
+
|
112 |
+
def add_to_models():
|
113 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
114 |
+
succeeded_jobs = [job for job in openai.FineTuningJob.list(limit=10)["data"] if job["status"] == "succeeded"]
|
115 |
+
presets.MODELS.extend([job["fine_tuned_model"] for job in succeeded_jobs])
|
116 |
+
return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
|
modules/utils.py
CHANGED
@@ -5,14 +5,11 @@ import logging
|
|
5 |
import commentjson as json
|
6 |
import os
|
7 |
import datetime
|
8 |
-
from datetime import timezone
|
9 |
-
import hashlib
|
10 |
import csv
|
11 |
import requests
|
12 |
import re
|
13 |
import html
|
14 |
-
import
|
15 |
-
import subprocess
|
16 |
|
17 |
import gradio as gr
|
18 |
from pypinyin import lazy_pinyin
|
@@ -241,7 +238,7 @@ def convert_bot_before_marked(chat_message):
|
|
241 |
code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
|
242 |
code_blocks = code_block_pattern.findall(chat_message)
|
243 |
non_code_parts = code_block_pattern.split(chat_message)[::2]
|
244 |
-
result = []
|
245 |
for non_code, code in zip(non_code_parts, code_blocks + [""]):
|
246 |
if non_code.strip():
|
247 |
result.append(non_code)
|
@@ -670,3 +667,16 @@ def auth_from_conf(username, password):
|
|
670 |
return False
|
671 |
except:
|
672 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import commentjson as json
|
6 |
import os
|
7 |
import datetime
|
|
|
|
|
8 |
import csv
|
9 |
import requests
|
10 |
import re
|
11 |
import html
|
12 |
+
import hashlib
|
|
|
13 |
|
14 |
import gradio as gr
|
15 |
from pypinyin import lazy_pinyin
|
|
|
238 |
code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
|
239 |
code_blocks = code_block_pattern.findall(chat_message)
|
240 |
non_code_parts = code_block_pattern.split(chat_message)[::2]
|
241 |
+
result = []
|
242 |
for non_code, code in zip(non_code_parts, code_blocks + [""]):
|
243 |
if non_code.strip():
|
244 |
result.append(non_code)
|
|
|
667 |
return False
|
668 |
except:
|
669 |
return False
|
670 |
+
|
671 |
+
def get_file_hash(file_src=None, file_paths=None):
|
672 |
+
if file_src:
|
673 |
+
file_paths = [x.name for x in file_src]
|
674 |
+
file_paths.sort(key=lambda x: os.path.basename(x))
|
675 |
+
|
676 |
+
md5_hash = hashlib.md5()
|
677 |
+
for file_path in file_paths:
|
678 |
+
with open(file_path, "rb") as f:
|
679 |
+
while chunk := f.read(8192):
|
680 |
+
md5_hash.update(chunk)
|
681 |
+
|
682 |
+
return md5_hash.hexdigest()
|
requirements.txt
CHANGED
@@ -21,7 +21,8 @@ duckduckgo-search
|
|
21 |
arxiv
|
22 |
wikipedia
|
23 |
google.generativeai
|
24 |
-
openai
|
25 |
unstructured
|
26 |
google-api-python-client
|
27 |
tabulate
|
|
|
|
21 |
arxiv
|
22 |
wikipedia
|
23 |
google.generativeai
|
24 |
+
openai>=0.27.9
|
25 |
unstructured
|
26 |
google-api-python-client
|
27 |
tabulate
|
28 |
+
ujson
|