# Source: Hugging Face Space file app.py - last commit "Update app.py" (df494d8) by merve (HF staff); 2.41 kB.
import os
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent
from datasets import load_dataset
from typing import List, Optional, Any
import argilla as rg
import os
# Lazily populated cache so the dataset is loaded and converted only once.
_INSTRUCTIONS_FRAME = None


def _get_instructions_frame():
    """Return the instructions dataset as a DataFrame, loading it on first use."""
    global _INSTRUCTIONS_FRAME
    if _INSTRUCTIONS_FRAME is None:
        _INSTRUCTIONS_FRAME = load_dataset(
            "merve/turkish_instructions", split="train"
        ).to_pandas()
    return _INSTRUCTIONS_FRAME


def load_data(idx):
    """Fetch one row of the instructions dataset for display.

    Args:
        idx: Row position. It is coerced with ``int()``, so a float such as
            ``3.0`` is accepted.

    Returns:
        A ``(instruction, input_sample, response)`` tuple read from the
        second, third and fourth columns of the selected row.
        ``input_sample`` is ``"-"`` when the stored value is falsy
        (for example an empty string).
    """
    sample = _get_instructions_frame().iloc[int(idx)]
    # Use explicit positional access: integer keys on a label-indexed Series
    # rely on a deprecated pandas fallback.
    instruction = sample.iloc[1]
    input_sample = sample.iloc[2] or "-"
    response = sample.iloc[3]
    return instruction, input_sample, response
def create_record(instruction, inpu_sample, response, feedback):
    """Build an Argilla text-classification record for one flagged sample.

    Args:
        instruction: Instruction text, stored under the ``"talimat"`` field.
        inpu_sample: Input text, stored under the ``"input"`` field.
        response: Response text, stored under the ``"response"`` field.
        feedback: Flag option chosen by the reviewer; it decides the record
            status and is kept in the record metadata.

    Returns:
        The constructed ``rg.TextClassificationRecord``.
    """
    # Only samples flagged as "Doğru" are marked as validated.
    status = "Validated" if feedback == "Doğru" else "Default"
    fields = {
        "talimat": instruction,
        # Use the argument itself. ``input_sample`` is not a parameter of this
        # function, so that name would resolve to an unrelated global.
        "input": inpu_sample,
        "response": response,
    }
    # the label will come from the flag object in Gradio
    label = "True"
    record = rg.TextClassificationRecord(
        inputs=fields,
        annotation=label,
        status=status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record
class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs flagged samples to an Argilla dataset."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL of the Argilla server.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the Argilla dataset records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples flagged through this instance; returned by flag().
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """No-op: records go to Argilla, so no local setup is performed."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to Argilla.

        Args:
            flag_data: Values of the flagged components.
            flag_option: Option the user picked when flagging; forwarded to
                ``create_record`` as the feedback.
            flag_index: Unused.
            username: Unused.

        Returns:
            The number of samples flagged through this callback instance.

        Raises:
            ValueError: If ``flag_data`` holds fewer than three values.
        """
        # NOTE(review): assumes Gradio passes the interface's input values
        # followed by its three outputs (instruction, input, response), so the
        # record fields are the last three entries - confirm against the
        # installed Gradio version.
        instruction, input_sample, response = flag_data[-3:]
        rg.log(
            name=self.dataset_name,
            records=create_record(instruction, input_sample, response, flag_option),
        )
        self._flag_count += 1
        return self._flag_count
# Row selector. An explicit ``step=1`` makes every row index selectable; when
# no step is given Gradio derives one from the slider range, which for a range
# this large is coarser than a single row.
idx_input = gr.Slider(minimum=0, maximum=51564, step=1, label="Satır")

# Output boxes for the three values returned by load_data.
instruction = gr.Textbox(label="Talimat")
input_sample = gr.Textbox(label="Girdi")
response = gr.Textbox(label="Çıktı")

# Build the review UI: pick a row, inspect it, and flag it manually with one
# of the three options. Flags are sent to Argilla through ArgillaLogger.
gr.Interface(
    load_data,
    title="ALPACA Veriseti Düzeltme Arayüzü",
    description="Bir satır sayısı verip örnek alın. Çeviride gözünüze doğru gelmeyen bir şey olursa işaretleyin.",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://sandbox.argilla.io",
        # The API key is read from the environment rather than hard-coded.
        api_key=os.getenv("API_KEY"),
        dataset_name="alpaca-flags",
    ),
    inputs=[idx_input],
    outputs=[instruction, input_sample, response],
    flagging_options=["Doğru", "Yanlış", "Belirsiz"],
).launch()