import os
from typing import Any, List, Optional

import argilla as rg
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent
from transformers import pipeline

# NER pipeline used to (re)compute predictions for every flagged text.
nlp = pipeline("ner", model="mrm8488/bert-spanish-cased-finetuned-ner")

examples = [
    ["Mi nombre es Juan y vivo en Barcelona"]
]

# Flagging option meaning "the prediction is correct". It is shared by the UI
# (flagging_options) and by create_record so the two can never drift apart;
# previously create_record compared against a string that was not one of the
# offered options, so no record was ever marked as validated.
CORRECT_FEEDBACK = "Correcto"
FLAGGING_OPTIONS = [CORRECT_FEEDBACK, "Incorrecto", "Ambiguo"]


def create_record(input_text: str, feedback: Optional[str]):
    """Build an Argilla token-classification record for a flagged text.

    The text is run through the module-level ``nlp`` pipeline and split into
    word tokens with the pipeline's tokenizer. The record is also printed to
    stdout before being returned.

    Args:
        input_text: The text the user submitted and flagged.
        feedback: The flagging option chosen by the user (one of
            ``FLAGGING_OPTIONS``), or ``None`` when no option was given.

    Returns:
        An ``rg.TokenClassificationRecord`` whose status is ``"Validated"``
        when ``feedback`` is the "correct" option and ``"Default"`` (needs
        review) for any other value.
    """
    # "Validated": checked and correct. "Default": still needs review
    # (the user flagged it as incorrect or ambiguous).
    status = "Validated" if feedback == CORRECT_FEEDBACK else "Default"

    # Predicted entities as (entity, start_char, end_char, score) tuples.
    predictions = nlp(input_text, aggregation_strategy="first")
    prediction = [
        (pred["entity_group"], pred["start"], pred["end"], pred["score"])
        for pred in predictions
    ]

    # Word-level tokens: map every distinct word id back to its character
    # span in the original text. ``None`` word ids are discarded.
    batch_encoding = nlp.tokenizer(input_text)
    word_ids = sorted(set(batch_encoding.word_ids()) - {None})
    char_spans = (batch_encoding.word_to_chars(word_id) for word_id in word_ids)
    words = [input_text[span.start:span.end] for span in char_spans]

    record = rg.TokenClassificationRecord(
        text=input_text,
        tokens=words,
        prediction=prediction,
        prediction_agent="gradio_crowd",
        status=status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record


class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs flagged samples to Argilla."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL passed to ``rg.init``.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the dataset records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """No-op: this callback needs no per-interface setup."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to the configured Argilla dataset.

        Only the first element of ``flag_data`` (the input text) is used; the
        prediction is recomputed inside ``create_record``.

        NOTE(review): the signature is annotated ``-> int`` but nothing is
        returned here - confirm whether the caller uses the return value.
        """
        text = flag_data[0]
        rg.log(name=self.dataset_name, records=create_record(text, flag_option))


gr.Interface.load(
    "mrm8488/bert-spanish-cased-finetuned-ner",
    examples=examples,
    title="NER en Español, crowdsource con Argilla",
    description="Ayudanos a mejorar este model introduciendo un ejemplo clasificandolo como correcto, incorrecto o ambiguo",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://dvilasuero-taller-somosnlp.hf.space",
        api_key=os.getenv("TEAM_API_KEY"),
        dataset_name="ner-flags",
    ),
    flagging_options=FLAGGING_OPTIONS,
).launch()