"""Gradio sentiment demo whose manually flagged samples are logged to Argilla."""

from typing import Any, List, Optional

import argilla as rg
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent


class ArgillaLogger(FlaggingCallback):
    """Flagging callback that forwards flagged samples to Argilla.

    Each flagged sample is logged as a ``TextClassificationRecord`` under the
    name ``"sentiment_feedback"``.
    """

    def __init__(self, api_url, api_key):
        """Initialise the Argilla client.

        Args:
            api_url: URL passed to ``rg.init``.
            api_key: API key passed to ``rg.init``.
        """
        rg.init(api_url=api_url, api_key=api_key)
        # Running total of samples logged by this callback instance; this is
        # the value returned by ``flag``.
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """No-op: nothing is prepared locally, records are sent via ``rg.log``."""

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to Argilla.

        Args:
            flag_data: Component values; index 0 is used as the record text
                and index 1 as the model output.
            flag_option: Accepted for interface compatibility; unused.
            flag_index: Accepted for interface compatibility; unused.
            username: Accepted for interface compatibility; unused.

        Returns:
            The number of samples this callback instance has logged so far.
        """
        text = flag_data[0]
        inference = flag_data[1]

        # The output may be empty (e.g. the sample was flagged before any
        # prediction was produced). In that case still log the text, just
        # without a prediction, instead of failing on the subscript.
        confidences = (
            inference.get("confidences") if isinstance(inference, dict) else None
        )
        if confidences is not None:
            prediction = [(pred["label"], pred["confidence"]) for pred in confidences]
        else:
            prediction = None

        rg.log(
            name="sentiment_feedback",
            records=rg.TextClassificationRecord(text=text, prediction=prediction),
        )

        # The signature promises an int; report the running total rather
        # than implicitly returning None.
        self._flag_count += 1
        return self._flag_count


# NOTE(review): the API URL and key are hard-coded here; consider loading
# them from configuration or the environment instead.
gr.Interface.load(
    "models/cardiffnlp/twitter-roberta-base-sentiment-latest",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://dvilasuero-argilla-template-space.hf.space",
        api_key="team.apikey",
    ),
).launch()