import os
from typing import Any, List, Optional

import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent
from transformers import pipeline
import argilla as rg

# Token-classification pipeline used to recompute predictions for flagged text.
nlp = pipeline("ner", model="deprem-ml/deprem-ner")

# Example inputs shown in the demo UI.
examples = [
    ["Lütfen yardım Akevler mahallesi Rüzgar sokak Tuncay apartmanı zemin kat Antakya akrabalarım göçük altında #hatay #Afad"]
]


def create_record(input_text: str, feedback: Optional[str]):
    """Build an Argilla token-classification record for a flagged input.

    The NER pipeline is run on ``input_text`` and its aggregated entities are
    attached to the record as the prediction, together with the user feedback.

    Args:
        input_text: The text that was flagged in the demo.
        feedback: The flagging option chosen by the user. ``"Correct"`` marks
            the record as validated; any other value leaves it in the default
            (needs review) status.

    Returns:
        The ``rg.TokenClassificationRecord`` built from the text, its word
        tokens and the model prediction.
    """
    # "Validated" means the prediction was checked and is correct; "Default"
    # means the record still needs review ("Incorrect" or "Ambiguous").
    status = "Validated" if feedback == "Correct" else "Default"

    # Run the model; each aggregated prediction becomes a tuple of
    # (entity, start_char, end_char, score).
    predictions = nlp(input_text, aggregation_strategy="first")
    prediction = [
        (pred["entity_group"], pred["start"], pred["end"], pred["score"])
        for pred in predictions
    ]

    # Rebuild word-level tokens from the tokenizer's word ids. ``None`` entries
    # are dropped because they do not map to a word.
    batch_encoding = nlp.tokenizer(input_text)
    word_ids = sorted(set(batch_encoding.word_ids()) - {None})
    char_spans = [batch_encoding.word_to_chars(word_id) for word_id in word_ids]
    words = [input_text[span.start:span.end] for span in char_spans]

    record = rg.TokenClassificationRecord(
        text=input_text,
        tokens=words,
        prediction=prediction,
        prediction_agent="deprem-ml/deprem-ner",
        status=status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record


class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs flagged samples to an Argilla dataset."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL of the Argilla server passed to ``rg.init``.
            api_key: API key passed to ``rg.init``.
            dataset_name: Name of the dataset that flagged records are logged to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """No-op: nothing is prepared locally for this callback."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to the configured Argilla dataset.

        Args:
            flag_data: Values of the flagged components; the first element is
                used as the input text. The remaining elements are ignored
                because the prediction is recomputed in ``create_record``.
            flag_option: The flagging option selected by the user, forwarded
                to ``create_record`` as the feedback.
            flag_index: Unused.
            username: Unused.

        Returns:
            ``1``, the number of records logged by this call.
        """
        text = flag_data[0]
        # The feedback belongs to the record, not to ``rg.log``: passing it as
        # a trailing positional argument after keywords is a syntax error and
        # left ``create_record`` without its required ``feedback`` argument.
        rg.log(name=self.dataset_name, records=create_record(text, flag_option))
        # The method is annotated ``-> int``; report the single record logged.
        return 1
# Callback that sends every manually flagged sample to the "ner-flags" dataset.
argilla_flagger = ArgillaLogger(
    api_url="https://merve-argilla.hf.space",
    api_key=os.getenv("TEAM_API_KEY"),
    dataset_name="ner-flags",
)

# Load the hosted model as an interface with manual flagging, then serve it.
demo = gr.Interface.load(
    "models/deprem-ml/deprem-ner",
    examples=examples,
    allow_flagging="manual",
    flagging_callback=argilla_flagger,
    flagging_options=["Correct", "Incorrect", "Ambiguous"],
)
demo.launch()