merve HF staff commited on
Commit
03ff95e
β€’
1 Parent(s): 052dcc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -1
app.py CHANGED
@@ -1,3 +1,72 @@
 
1
  import gradio as gr
 
 
2
 
3
- gr.Interface.load("huggingface/deprem-ml/deprem-ner").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
+ from gradio import FlaggingCallback
4
+ from gradio.components import IOComponent
5
 
6
+ from transformers import pipeline
7
+
8
+ from typing import List, Optional, Any
9
+
10
+ import argilla as rg
11
+
12
+ nlp = pipeline("ner", model="deprem-ml/deprem-ner")
13
+
14
+ examples = [
15
+ ["Lütfen yardım Akevler mahallesi Rüzgar sokak Tuncay apartmanı zemin kat Antakya akrabalarım gâçük altında #hatay #Afad"]
16
+ ]
17
+
18
+ def create_record(input_text):
19
+ # Making the prediction
20
+ predictions = nlp(input_text, aggregation_strategy="first")
21
+
22
+ # Creating the predicted entities as a list of tuples (entity, start_char, end_char, score)
23
+ prediction = [(pred["entity_group"], pred["start"], pred["end"], pred["score"]) for pred in predictions]
24
+
25
+ # Create word tokens
26
+ batch_encoding = nlp.tokenizer(input_text)
27
+ word_ids = sorted(set(batch_encoding.word_ids()) - {None})
28
+ words = []
29
+ for word_id in word_ids:
30
+ char_span = batch_encoding.word_to_chars(word_id)
31
+ words.append(input_text[char_span.start:char_span.end])
32
+
33
+ # Building a TokenClassificationRecord
34
+ record = rg.TokenClassificationRecord(
35
+ text=input_text,
36
+ tokens=words,
37
+ prediction=prediction,
38
+ prediction_agent="deprem-ml/deprem-ner",
39
+ )
40
+ print(record)
41
+ return record
42
+
43
+ class ArgillaLogger(FlaggingCallback):
44
+ def __init__(self, api_url, api_key, dataset_name):
45
+ rg.init(api_url=api_url, api_key=api_key)
46
+ self.dataset_name = dataset_name
47
+ def setup(self, components: List[IOComponent], flagging_dir: str):
48
+ pass
49
+ def flag(
50
+ self,
51
+ flag_data: List[Any],
52
+ flag_option: Optional[str] = None,
53
+ flag_index: Optional[int] = None,
54
+ username: Optional[str] = None,
55
+ ) -> int:
56
+ text = flag_data[0]
57
+ inference = flag_data[1]
58
+ rg.log(name=self.dataset_name, records=create_record(text))
59
+
60
+
61
+
62
+ gr.Interface.load(
63
+ "models/deprem-ml/deprem-ner",
64
+ examples=examples,
65
+ allow_flagging="manual",
66
+ flagging_callback=ArgillaLogger(
67
+ api_url="https://dvilasuero-argilla-template-1-3.hf.space",
68
+ api_key="team.apikey",
69
+ dataset_name="ner-flags"
70
+ ),
71
+ flaggging_options=["Correct", "Incorrect", "Ambiguous"]
72
+ ).launch()