kevineen's picture
class 使用使用
c6abd78
raw
history blame
No virus
9.51 kB
from typing import Optional, List
import gradio as gr
from datasets import load_dataset
from type.dataset_type import TanukiPhase2AnnotationDataset
class CurrentData:
def __init__(self, gr, annotation_dataset_list: List[str]):
# 選択中のデータセット 初期値は"hatakeyama-llm-team/AutoGeneratedJapaneseQA"
self.dropdown_dataset = gr.State(value=annotation_dataset_list[0])
self.current_dataset = gr.State(None) # 現在のデータ
self.current_idx = gr.State(0) # 現在のインデックス
self.target_dataset = gr.State(None) # 出力先データセット
self.initial_answer_text = gr.State("") # 回答1を整形したかチェック用
self.initial_answer_text_2 = gr.State("") # 回答2を整形したかチェック用
# アノテーション中のデータセット
class AnnotationState (TanukiPhase2AnnotationDataset):
def __init__(self, gr):
self.id = gr.State(0) # 出力先のデータセットをチェックし、末尾IDを追加
self.dataset = gr.State("") # 編集に使用したデータセット
self.dataset_id = gr.State(0) # 加工元データセットのidx
self.who = gr.State("") # アノテーション者
self.good = gr.State(False) # 良
self.bad = gr.State(False) # 悪
self.score = gr.State(3) # スコア 初期値は3
self.is_proofreading_1 = gr.State(False) # 回答1を整形したか_1
self.answer_text = gr.State("") # answer_1 回答
self.is_proofreading_2 = gr.State(False) # 回答2を整形したか_2
self.answer_text_2 = gr.State("") # answer_2 回答
# HF保存先
output_dataset = [
"kevineen/Phase2_dataset_annotation"
]
# アノテーションするデータセット
annotation_dataset_list = [
# "hatakeyama-llm-team/WikiBookJa", # 良・悪のみ
"hatakeyama-llm-team/AutoGeneratedJapaneseQA",
"hatakeyama-llm-team/AutoGeneratedJapaneseQA-other",
"kanhatakeyama/AutoWikiQA",
"kanhatakeyama/ChatbotArenaJaMixtral8x22b",
"kanhatakeyama/OrcaJaMixtral8x22b",
# "kanhatakeyama/AutoMultiTurnByMixtral8x22b", # マルチターン
"kanhatakeyama/LogicalDatasetsByMixtral8x22b",
]
current_data = CurrentData(gr, annotation_dataset_list)
annotation_state = AnnotationState(gr)
# 選択中のデータセット 初期値は"hatakeyama-llm-team/AutoGeneratedJapaneseQA",
dropdown_dataset = gr.State(value = annotation_dataset_list[0])
current_dataset = gr.State(None) # 現在のデータ
target_dataset = gr.State(None) # データセット
current_idx = gr.State(0) # 現在のインデックス
score = gr.State(3) # スコア 初期値は3
initial_answer_text = gr.State("") # 整えた答え
initial_answer_text_2 = gr.State("") # マルチターン用 整えた答え
labeled_output_dataset = gr.State(None) # 出力用
# 後のデザイン変更用
def load_css():
with open("style.css", "r") as file:
css_content = file.read()
return css_content
def hello(profile: gr.OAuthProfile | None) -> str:
if profile is None:
return "プライベートデータセット取得のためにログインしてください。"
return f'{profile.username}さん、よろしくお願いいたします。'
# 出力先のデータセットをロード
def load_target_dataset(oauth_token: gr.OAuthToken | None) -> str:
if oauth_token is None:
return "ログインしてからデータセットを表示してください。"
try:
labeled_output_dataset = load_dataset(output_dataset)
print(labeled_output_dataset)
return labeled_output_dataset
except Exception as e:
print(e)
return None
# データセットをロード
def load_data(current_dataset, oauth_token: gr.OAuthToken | None) -> str:
if oauth_token is None:
return "ログインしてからデータセットを表示してください。"
try:
dropdown_dataset = load_dataset(current_dataset)
return dropdown_dataset
except Exception as e:
return None
# データセットを表示
def display_dataset(current_dataset, profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None):
if profile is None:
return gr.update(visible=True, value="ログインしてデータセットを表示してください。"), None, None
target_dataset = load_data(current_dataset, oauth_token)
if target_dataset:
# "train" スプリットの最初のデータを取得
current_dataset = target_dataset['train'][current_idx.value]
question = current_dataset.get("question", "質問が見つかりません")
answer = current_dataset.get("answer", "解答が見つかりません")
# 初期値を設定しておく (出力前に、変更があればis_proofreadingをTrueにし、t_proofreadingにanswerを設定する)
initial_answer_text = answer
return gr.update(visible=False), gr.update(value=question, interactive=False), gr.update(value=answer)
else:
return gr.update(visible=True, value="データセットのロードに失敗しました。"), None, None
def switch_theme(theme):
if theme == "Dark":
return gr.themes.Default()
else:
return gr.themes.Monochrome()
theme_ = gr.State("Light")
with gr.Blocks(theme=theme_, css=load_css()) as demo:
gr.Markdown("# データセット アノテーション for Tanuki (Phase2)")
with gr.Tab("アノテーション"):
def update_theme():
new_theme = "Dark" if theme_.value == "Light" else "Light"
theme_.value = new_theme
return switch_theme(new_theme)
with gr.Row(equal_height=True):
gr.LoginButton(value="ログイン",logout_value="ログアウト", scale=1)
profile_name = gr.Markdown()
# お名前表示 出力データセット用
demo.load(hello, inputs=None, outputs=profile_name)
def choice_dataset_fn(choice_dataset):
return f"{choice_dataset}"
with gr.Row():
gr_current_dataset = gr.Dropdown(label="アノテーションするデータセット",
choices=annotation_dataset_list,
value=dropdown_dataset.value,
elem_id="dataset_sel")
data_load_btn = gr.Button("データセットを読み込む")
# データセットのロードメッセージ表示
dataset_load_message = gr.Markdown(visible=False)
gr_current_dataset.change(choice_dataset_fn, inputs=[gr_current_dataset])
dataset_display = gr.Markdown(visible=False)
question_text = gr.Textbox(label="質問: ", interactive=False)
with gr.Tab("シンプルモード(良い・悪いのみ選択)"):
with gr.Column():
with gr.Row(equal_height=True):
good_btn = gr.Button("良い")
bad_btn = gr.Button("悪い")
answer_text = gr.Textbox(label="回答:",lines=20, interactive=False)
data_load_btn.click(
display_dataset,
inputs=[gr_current_dataset],
outputs=[dataset_display, question_text, answer_text],
)
def good_click(current_dataset, profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None):
if profile is None:
return gr.update(visible=True, value="評価を行うにはログインしてください。"), None, None
# 現在のインデックスを更新
current_idx.value += 1
# データセットの更新と表示
return display_dataset(current_dataset, profile, oauth_token)
good_btn.click(
good_click,
inputs=[gr_current_dataset], # good_click に current_dataset を渡す
outputs=[dataset_display, question_text, answer_text]
)
def bad_click():
print("bad")
bad_btn.click(
bad_click,
inputs=[],
outputs=None
)
with gr.Tab("丁寧モード(5段階評価・文章校正)"):
score_slider = gr.Slider(1, 5, label="スコア: 1-5 (1:大変悪い、2:悪い、3:普通、4:良い、5:大変良い)", step=1, value=score.value, interactive=True)
answer_text = gr.Textbox(label="回答: 改行はEnterです。 文章を修正して頂けると、さらに高品質になります。", lines=20, elem_id="answer", interactive=True)
data_load_btn.click(
display_dataset,
inputs=[gr_current_dataset],
outputs=[dataset_display, question_text, answer_text],
)
with gr.Tab("アノテ済みデータセット(管理画面)"):
gr.Textbox("データセットID", lines=1, placeholder="データセットIDを入力してください。")
if __name__ == "__main__":
demo.launch()