"""Gradio front-end for reviewing Turkish instruction samples.

A sample from the ``merve/turkish_instructions`` dataset is shown in a Gradio
interface; pressing one of the flag buttons logs the sample, together with the
chosen flag option, to an Argilla dataset.
"""
import os
from typing import Any, List, Optional

import argilla as rg
import gradio as gr
from datasets import load_dataset
from gradio import FlaggingCallback
from gradio.components import IOComponent


def load_data():
    """Return one record from the train split of the instruction dataset.

    The dataset is opened in streaming mode and a fresh iterator is created on
    every call, so each call returns the first record the stream yields.
    """
    ds = load_dataset("merve/turkish_instructions", split="train", streaming=True)
    return next(iter(ds))


def create_record(sample, feedback):
    """Build an Argilla text-classification record for a flagged sample.

    Args:
        sample: Mapping read with the keys ``"talimat"``, ``"giriş"`` and
            ``"çıktı"``.
            NOTE(review): assumes these match the dataset's column names
            exactly -- confirm against the dataset schema.
        feedback: The flag option chosen in the UI. ``"Doğru"`` marks the
            record as ``"Validated"``; any other value (including ``None``)
            leaves it as ``"Default"``.

    Returns:
        The ``rg.TextClassificationRecord`` that was built.
    """
    status = "Validated" if feedback == "Doğru" else "Default"
    fields = {
        "talimat": sample["talimat"],
        "input": sample["giriş"],
        "response": sample["çıktı"],
    }

    # The annotation is currently fixed; the reviewer's choice is carried in
    # ``status`` and in the ``feedback`` metadata entry instead.
    label = "True"

    record = rg.TextClassificationRecord(
        inputs=fields,
        annotation=label,
        status=status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record


class ArgillaLogger(FlaggingCallback):
    """Flagging callback that logs every flagged sample to an Argilla dataset."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL passed to ``rg.init``.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the Argilla dataset that records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples logged through this callback; returned by ``flag``.
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Do nothing: records go to Argilla, so no local set-up is performed."""

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log the flagged sample to Argilla.

        Args:
            flag_data: Values of the flagged components. Only the first entry
                is used, and it is handed to ``create_record`` as the sample.
            flag_option: The flag button that was pressed.
            flag_index: Unused.
            username: Unused.

        Returns:
            The total number of samples flagged through this callback so far.
        """
        sample = flag_data[0]
        rg.log(name=self.dataset_name, records=create_record(sample, flag_option))
        self._flag_count += 1
        return self._flag_count


# NOTE(review): the original call passed no ``fn``/``inputs``/``outputs``, which
# ``gr.Interface`` requires. The minimal output-only wiring below shows one
# sample as JSON so that ``flag_data[0]`` is the sample mapping expected by
# ``create_record`` -- confirm this matches the intended UI.
gr.Interface(
    fn=load_data,
    inputs=None,
    outputs=gr.JSON(),
    title="ALPACA Veriseti Düzeltme Arayüzü",
    description="",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://sandbox.argilla.io",
        api_key=os.getenv("TEAM_API_KEY"),
        dataset_name="alpaca-flags",
    ),
    flagging_options=["Doğru", "Yanlış", "Belirsiz"],
).launch()