File size: 2,756 Bytes
a4f1003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51d268c
a4f1003
 
51d268c
a4f1003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51d268c
a4f1003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51d268c
a4f1003
51d268c
 
a4f1003
 
51d268c
43183c7
a4f1003
 
51d268c
a4f1003
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent

from transformers import pipeline

from typing import List, Optional, Any

import argilla as rg

import os



# Hugging Face checkpoint used for Spanish named-entity recognition.
NER_MODEL_ID = "mrm8488/bert-spanish-cased-finetuned-ner"

# Local token-classification pipeline; create_record() uses it to recompute
# predictions (and to reach the tokenizer) for every flagged text.
nlp = pipeline("ner", model=NER_MODEL_ID)

# Sample inputs shown in the demo UI (one single-field row per example).
examples = [["Mi nombre es Juan y vivo en Barcelona"]]

def create_record(input_text, feedback):
    """Build an Argilla token-classification record for a flagged example.

    Re-runs the module-level ``nlp`` NER pipeline on ``input_text`` and
    packages the predicted entities together with the user's feedback.

    Args:
        input_text: Raw text the user submitted to the demo.
        feedback: Flagging option selected by the user; expected to be one of
            the options offered by the interface ("Correcto", "Incorrecto",
            "Ambiguo"), or ``None`` when no option was given.

    Returns:
        An ``rg.TokenClassificationRecord`` carrying the text, its word
        tokens, the model predictions, the review status and the feedback.
    """
    # "Validated" means the user confirmed the prediction is correct; any
    # other feedback leaves the record as "Default", i.e. still to be reviewed.
    # The comparison must use the Spanish label the UI actually offers,
    # otherwise no record is ever marked as validated.
    status = "Validated" if feedback == "Correcto" else "Default"

    # Run the model; "first" aggregation merges sub-word pieces into entities.
    predictions = nlp(input_text, aggregation_strategy="first")

    # Predicted entities as (label, start_char, end_char, score) tuples.
    prediction = [
        (pred["entity_group"], pred["start"], pred["end"], pred["score"])
        for pred in predictions
    ]

    # Rebuild word-level tokens from the tokenizer's word/char alignment.
    # word_ids() yields None for special tokens, hence the set difference.
    batch_encoding = nlp.tokenizer(input_text)
    word_ids = sorted(set(batch_encoding.word_ids()) - {None})
    char_spans = (batch_encoding.word_to_chars(word_id) for word_id in word_ids)
    words = [input_text[span.start:span.end] for span in char_spans]

    record = rg.TokenClassificationRecord(
        text=input_text,
        tokens=words,
        prediction=prediction,
        prediction_agent="gradio_crowd",
        status=status,
        metadata={"feedback": feedback},
    )
    # Echo the record to stdout so it shows up in the app logs.
    print(record)
    return record

class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that sends flagged examples to Argilla.

    Each flagged input is turned into a token-classification record via
    ``create_record`` and logged to the configured Argilla dataset.
    """

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL of the Argilla server.
            api_key: API key used to authenticate against the server.
            dataset_name: Name of the dataset that receives the records.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Number of samples flagged through this callback instance so far;
        # reported back to Gradio from flag().
        self._flagged_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """No-op: records go to Argilla, so no local flagging setup is done."""

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to Argilla.

        Args:
            flag_data: Flagged component values; the first element is the
                input text. Any further elements are ignored because the
                prediction is recomputed in ``create_record``.
            flag_option: Feedback label chosen by the user, if any.
            flag_index: Unused.
            username: Unused.

        Returns:
            The number of samples flagged through this instance so far.
        """
        text = flag_data[0]
        rg.log(name=self.dataset_name, records=create_record(text, flag_option))
        self._flagged_count += 1
        return self._flagged_count


        
# Build the demo from the hosted model and start it. Manual flagging routes
# every flagged sample, plus the chosen option, through ArgillaLogger.
gr.Interface.load(
    "mrm8488/bert-spanish-cased-finetuned-ner",
    examples=examples,
    # Title previously contained mis-decoded bytes ("EspaΓ±ol").
    title="NER en Español, crowdsource con Argilla",
    description="Ayudanos a mejorar este model introduciendo un ejemplo clasificandolo como correcto, incorrecto o ambiguo",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://dvilasuero-taller-somosnlp.hf.space",
        # The API key is read from the environment (e.g. a Space secret).
        api_key=os.getenv("TEAM_API_KEY"),
        dataset_name="ner-flags",
    ),
    # create_record() treats "Correcto" as the validated label.
    flagging_options=["Correcto", "Incorrecto", "Ambiguo"],
).launch()