Spaces:
Sleeping
Sleeping
from typing import Optional, List | |
import gradio as gr | |
from datasets import load_dataset | |
from type.dataset_type import TanukiPhase2AnnotationDataset | |
class CurrentData: | |
def __init__(self, gr, annotation_dataset_list: List[str]): | |
# 選択中のデータセット 初期値は"hatakeyama-llm-team/AutoGeneratedJapaneseQA" | |
self.dropdown_dataset = gr.State(value=annotation_dataset_list[0]) | |
self.current_dataset = gr.State(None) # 現在のデータ | |
self.current_idx = gr.State(0) # 現在のインデックス | |
self.target_dataset = gr.State(None) # 出力先データセット | |
self.initial_answer_text = gr.State("") # 回答1を整形したかチェック用 | |
self.initial_answer_text_2 = gr.State("") # 回答2を整形したかチェック用 | |
# アノテーション中のデータセット | |
class AnnotationState (TanukiPhase2AnnotationDataset): | |
def __init__(self, gr): | |
self.id = gr.State(0) # 出力先のデータセットをチェックし、末尾IDを追加 | |
self.dataset = gr.State("") # 編集に使用したデータセット | |
self.dataset_id = gr.State(0) # 加工元データセットのidx | |
self.who = gr.State("") # アノテーション者 | |
self.good = gr.State(False) # 良 | |
self.bad = gr.State(False) # 悪 | |
self.score = gr.State(3) # スコア 初期値は3 | |
self.is_proofreading_1 = gr.State(False) # 回答1を整形したか_1 | |
self.answer_text = gr.State("") # answer_1 回答 | |
self.is_proofreading_2 = gr.State(False) # 回答2を整形したか_2 | |
self.answer_text_2 = gr.State("") # answer_2 回答 | |
# HF保存先 | |
output_dataset = [ | |
"kevineen/Phase2_dataset_annotation" | |
] | |
# アノテーションするデータセット | |
annotation_dataset_list = [ | |
# "hatakeyama-llm-team/WikiBookJa", # 良・悪のみ | |
"hatakeyama-llm-team/AutoGeneratedJapaneseQA", | |
"hatakeyama-llm-team/AutoGeneratedJapaneseQA-other", | |
"kanhatakeyama/AutoWikiQA", | |
"kanhatakeyama/ChatbotArenaJaMixtral8x22b", | |
"kanhatakeyama/OrcaJaMixtral8x22b", | |
# "kanhatakeyama/AutoMultiTurnByMixtral8x22b", # マルチターン | |
"kanhatakeyama/LogicalDatasetsByMixtral8x22b", | |
] | |
current_data = CurrentData(gr, annotation_dataset_list) | |
annotation_state = AnnotationState(gr) | |
# 選択中のデータセット 初期値は"hatakeyama-llm-team/AutoGeneratedJapaneseQA", | |
dropdown_dataset = gr.State(value = annotation_dataset_list[0]) | |
current_dataset = gr.State(None) # 現在のデータ | |
target_dataset = gr.State(None) # データセット | |
current_idx = gr.State(0) # 現在のインデックス | |
score = gr.State(3) # スコア 初期値は3 | |
initial_answer_text = gr.State("") # 整えた答え | |
initial_answer_text_2 = gr.State("") # マルチターン用 整えた答え | |
labeled_output_dataset = gr.State(None) # 出力用 | |
# 後のデザイン変更用 | |
def load_css(): | |
with open("style.css", "r") as file: | |
css_content = file.read() | |
return css_content | |
def hello(profile: gr.OAuthProfile | None) -> str: | |
if profile is None: | |
return "プライベートデータセット取得のためにログインしてください。" | |
return f'{profile.username}さん、よろしくお願いいたします。' | |
# 出力先のデータセットをロード | |
def load_target_dataset(oauth_token: gr.OAuthToken | None) -> str: | |
if oauth_token is None: | |
return "ログインしてからデータセットを表示してください。" | |
try: | |
labeled_output_dataset = load_dataset(output_dataset) | |
print(labeled_output_dataset) | |
return labeled_output_dataset | |
except Exception as e: | |
print(e) | |
return None | |
# データセットをロード | |
def load_data(current_dataset, oauth_token: gr.OAuthToken | None) -> str: | |
if oauth_token is None: | |
return "ログインしてからデータセットを表示してください。" | |
try: | |
dropdown_dataset = load_dataset(current_dataset) | |
return dropdown_dataset | |
except Exception as e: | |
return None | |
# データセットを表示 | |
def display_dataset(current_dataset, profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None): | |
if profile is None: | |
return gr.update(visible=True, value="ログインしてデータセットを表示してください。"), None, None | |
target_dataset = load_data(current_dataset, oauth_token) | |
if target_dataset: | |
# "train" スプリットの最初のデータを取得 | |
current_dataset = target_dataset['train'][current_idx.value] | |
question = current_dataset.get("question", "質問が見つかりません") | |
answer = current_dataset.get("answer", "解答が見つかりません") | |
# 初期値を設定しておく (出力前に、変更があればis_proofreadingをTrueにし、t_proofreadingにanswerを設定する) | |
initial_answer_text = answer | |
return gr.update(visible=False), gr.update(value=question, interactive=False), gr.update(value=answer) | |
else: | |
return gr.update(visible=True, value="データセットのロードに失敗しました。"), None, None | |
def switch_theme(theme): | |
if theme == "Dark": | |
return gr.themes.Default() | |
else: | |
return gr.themes.Monochrome() | |
theme_ = gr.State("Light") | |
with gr.Blocks(theme=theme_, css=load_css()) as demo: | |
gr.Markdown("# データセット アノテーション for Tanuki (Phase2)") | |
with gr.Tab("アノテーション"): | |
def update_theme(): | |
new_theme = "Dark" if theme_.value == "Light" else "Light" | |
theme_.value = new_theme | |
return switch_theme(new_theme) | |
with gr.Row(equal_height=True): | |
gr.LoginButton(value="ログイン",logout_value="ログアウト", scale=1) | |
profile_name = gr.Markdown() | |
# お名前表示 出力データセット用 | |
demo.load(hello, inputs=None, outputs=profile_name) | |
def choice_dataset_fn(choice_dataset): | |
return f"{choice_dataset}" | |
with gr.Row(): | |
gr_current_dataset = gr.Dropdown(label="アノテーションするデータセット", | |
choices=annotation_dataset_list, | |
value=dropdown_dataset.value, | |
elem_id="dataset_sel") | |
data_load_btn = gr.Button("データセットを読み込む") | |
# データセットのロードメッセージ表示 | |
dataset_load_message = gr.Markdown(visible=False) | |
gr_current_dataset.change(choice_dataset_fn, inputs=[gr_current_dataset]) | |
dataset_display = gr.Markdown(visible=False) | |
question_text = gr.Textbox(label="質問: ", interactive=False) | |
with gr.Tab("シンプルモード(良い・悪いのみ選択)"): | |
with gr.Column(): | |
with gr.Row(equal_height=True): | |
good_btn = gr.Button("良い") | |
bad_btn = gr.Button("悪い") | |
answer_text = gr.Textbox(label="回答:",lines=20, interactive=False) | |
data_load_btn.click( | |
display_dataset, | |
inputs=[gr_current_dataset], | |
outputs=[dataset_display, question_text, answer_text], | |
) | |
def good_click(current_dataset, profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None): | |
if profile is None: | |
return gr.update(visible=True, value="評価を行うにはログインしてください。"), None, None | |
# 現在のインデックスを更新 | |
current_idx.value += 1 | |
# データセットの更新と表示 | |
return display_dataset(current_dataset, profile, oauth_token) | |
good_btn.click( | |
good_click, | |
inputs=[gr_current_dataset], # good_click に current_dataset を渡す | |
outputs=[dataset_display, question_text, answer_text] | |
) | |
def bad_click(): | |
print("bad") | |
bad_btn.click( | |
bad_click, | |
inputs=[], | |
outputs=None | |
) | |
with gr.Tab("丁寧モード(5段階評価・文章校正)"): | |
score_slider = gr.Slider(1, 5, label="スコア: 1-5 (1:大変悪い、2:悪い、3:普通、4:良い、5:大変良い)", step=1, value=score.value, interactive=True) | |
answer_text = gr.Textbox(label="回答: 改行はEnterです。 文章を修正して頂けると、さらに高品質になります。", lines=20, elem_id="answer", interactive=True) | |
data_load_btn.click( | |
display_dataset, | |
inputs=[gr_current_dataset], | |
outputs=[dataset_display, question_text, answer_text], | |
) | |
with gr.Tab("アノテ済みデータセット(管理画面)"): | |
gr.Textbox("データセットID", lines=1, placeholder="データセットIDを入力してください。") | |
if __name__ == "__main__": | |
demo.launch() | |