vonewman committed · Commit f5b6e30 · 1 Parent(s): bbf7868

use predict_ner_labels

Files changed (1): app.py +53 -72
app.py CHANGED
@@ -1,125 +1,106 @@
  import streamlit as st
  import pandas as pd
- import numpy as np
  import re
  import json
- import base64
- import uuid
-
  import transformers
- from datasets import Dataset,load_dataset, load_from_disk
  from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer

-
  st.set_page_config(
-     page_title="Named Entity Recognition Wolof", page_icon="📘"
  )

- def convert_df(df:pd.DataFrame):
-     return df.to_csv(index=False).encode('utf-8')
-
- #@st.cache
- def convert_json(df:pd.DataFrame):
      result = df.to_json(orient="index")
      parsed = json.loads(result)
      json_string = json.dumps(parsed)
-     #st.json(json_string, expanded=True)
      return json_string

- st.title("📘Named Entity Recognition Wolof")
-
- @st.cache(allow_output_mutation=True)
  def load_model():
-
      model = AutoModelForTokenClassification.from_pretrained("vonewman/wolof-finetuned-ner")
      trainer = Trainer(model=model)
-
      tokenizer = AutoTokenizer.from_pretrained("vonewman/wolof-finetuned-ner")
-
      return trainer, model, tokenizer

- id2tag = {0: 'O',
-           1: 'B-LOC',
-           2: 'B-PER',
-           3: 'I-PER',
-           4: 'B-ORG',
-           5: 'I-DATE',
-           6: 'B-DATE',
-           7: 'I-ORG',
-           8: 'I-LOC'
-           }
-
- def tag_sentence(text:str):
-     # convert our text to a tokenized sequence
-     inputs = tokenizer(text, truncation=True, return_tensors="pt")
-     # get outputs
-     outputs = model(**inputs)
-     # convert to probabilities with softmax
-     probs = outputs[0][0].softmax(1)
-     # get the tags with the highest probability
-     word_tags = [(tokenizer.decode(inputs['input_ids'][0][i].item()), id2tag[tagid.item()], np.round(probs[i][tagid].item() *100,2))
-                  for i, tagid in enumerate(probs.argmax(axis=1))]
-
-     df=pd.DataFrame(word_tags, columns=['word', 'tag', 'probability'])
-     return df

- with st.form(key='my_form'):
      x1 = st.text_input(label='Enter a sentence:', max_chars=250)
-     print(x1)
      submit_button = st.form_submit_button(label='🏷️ Create tags')

-
  if submit_button:
-     if re.sub('\s+','',x1)=='':
          st.error('Please enter a non-empty sentence.')
-
      elif re.match(r'\A\s*\w+\s*\Z', x1):
          st.error("Please enter a sentence with at least one word")
-
      else:
          st.markdown("### Tagged Sentence")
          st.header("")

-         Trainer, model, tokenizer = load_model()
-         results=tag_sentence(x1)
-
          cs, c1, c2, c3, cLast = st.columns([0.75, 1.5, 1.5, 1.5, 0.75])

          with c1:
-             #csvbutton = download_button(results, "results.csv", "📥 Download .csv")
-             csvbutton = st.download_button(label="📥 Download .csv", data=convert_df(results),
-                                            file_name= "results.csv", mime='text/csv', key='csv')
          with c2:
-             #textbutton = download_button(results, "results.txt", "📥 Download .txt")
-             textbutton = st.download_button(label="📥 Download .txt", data=convert_df(results),
-                                             file_name= "results.text", mime='text/plain', key='text')
          with c3:
-             #jsonbutton = download_button(results, "results.json", "📥 Download .json")
-             jsonbutton = st.download_button(label="📥 Download .json", data=convert_json(results),
-                                             file_name= "results.json", mime='application/json', key='json')

          st.header("")
-
          c1, c2, c3 = st.columns([1, 3, 1])
-
-         with c2:
-             st.table(results.style.background_gradient(subset=['probability']).format(precision=2))

  st.header("")
  st.header("")
  st.header("")
  with st.expander("ℹ️ - About this app", expanded=True):
-
-
      st.write(
          """
  - The **Named Entity Recognition Wolof** app is a tool that performs named entity recognition in Wolof.
- - The available entitites are: *corporation*, *location*, *person* and *date*.
- - The app uses the [XLMRoberta model](https://huggingface.co/xlm-roberta-base), fine-tuned on the [masakhaNER](https://huggingface.co/datasets/masakhane/masakhaner2) dataset.
- - The model uses the **byte-level BPE tokenizer**. Each sentece is first tokenized.
- """
- )
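The removed tag_sentence above decoded and tagged every sub-word piece the tokenizer produced, so a single Wolof word could span several rows of the results table. A minimal sketch of that behaviour, using the app's published tokenizer and a hypothetical input sentence (illustrative only, not part of the commit):

from transformers import AutoTokenizer

# One word can be split into several sub-word pieces; the old tag_sentence()
# decoded and tagged each piece (plus the <s>/</s> specials) as its own row.
tokenizer = AutoTokenizer.from_pretrained("vonewman/wolof-finetuned-ner")
print(tokenizer.tokenize("Ndakaaru la dëkk"))  # hypothetical Wolof sentence

The rewritten app.py below replaces this with the word-level predict_ner_labels helper.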
 
  import streamlit as st
  import pandas as pd
  import re
  import json
  import transformers
+ import torch
  from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer

  st.set_page_config(
+     page_title="Named Entity Recognition Wolof",
+     page_icon="📘"
  )

+ def convert_df(df: pd.DataFrame):
+     return df.to_csv(index=False).encode('utf-8')

+ def convert_json(df: pd.DataFrame):
      result = df.to_json(orient="index")
      parsed = json.loads(result)
      json_string = json.dumps(parsed)
      return json_string

  def load_model():
      model = AutoModelForTokenClassification.from_pretrained("vonewman/wolof-finetuned-ner")
      trainer = Trainer(model=model)
      tokenizer = AutoTokenizer.from_pretrained("vonewman/wolof-finetuned-ner")
      return trainer, model, tokenizer

+ def predict_ner_labels(model, tokenizer, sentence):
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+
+     if use_cuda:
+         model = model.cuda()
+
+     text = tokenizer(sentence, padding='max_length', max_length=218, truncation=True, return_tensors="pt")
+     mask = text['attention_mask'].to(device)
+     input_id = text['input_ids'].to(device)
+     label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
+
+     logits = model(input_id, mask, None)
+     logits_clean = logits[0][label_ids != -100]
+
+     predictions = logits_clean.argmax(dim=1).tolist()
+     prediction_label = [id2tag[i] for i in predictions]
+
+     return prediction_label
+
+ id2tag = {0: 'O', 1: 'B-LOC', 2: 'B-PER', 3: 'I-PER', 4: 'B-ORG', 5: 'I-DATE', 6: 'B-DATE', 7: 'I-ORG', 8: 'I-LOC'}

+ def tag_sentence(text):
+     trainer, model, tokenizer = load_model()
+     predictions = predict_ner_labels(model, tokenizer, text)
+     df = pd.DataFrame(predictions, columns=['tag'])
+     df['word'] = text.split()
+     df['probability'] = 100.0  # you can adjust this value as needed
+     return df
+
+ st.title("📘 Named Entity Recognition Wolof")
+
+ with st.form(key='my_form'):
      x1 = st.text_input(label='Enter a sentence:', max_chars=250)
      submit_button = st.form_submit_button(label='🏷️ Create tags')

  if submit_button:
+     if re.sub('\s+', '', x1) == '':
          st.error('Please enter a non-empty sentence.')
      elif re.match(r'\A\s*\w+\s*\Z', x1):
          st.error("Please enter a sentence with at least one word")
      else:
          st.markdown("### Tagged Sentence")
          st.header("")

+         results = tag_sentence(x1)
+
          cs, c1, c2, c3, cLast = st.columns([0.75, 1.5, 1.5, 1.5, 0.75])

          with c1:
+             csvbutton = st.download_button(label="📥 Download .csv", data=convert_df(results),
+                                            file_name="results.csv", mime='text/csv', key='csv')
          with c2:
+             textbutton = st.download_button(label="📥 Download .txt", data=convert_df(results),
+                                             file_name="results.text", mime='text/plain', key='text')
          with c3:
+             jsonbutton = st.download_button(label="📥 Download .json", data=convert_json(results),
+                                             file_name="results.json", mime='application/json', key='json')

          st.header("")
+
          c1, c2, c3 = st.columns([1, 3, 1])

+         with c2:
+             st.table(results.style.background_gradient(subset=['probability']).format(precision=2))

  st.header("")
  st.header("")
  st.header("")
  with st.expander("ℹ️ - About this app", expanded=True):
      st.write(
          """
  - The **Named Entity Recognition Wolof** app is a tool that performs named entity recognition in Wolof.
+ - The available entities are: *corporation*, *location*, *person*, and *date*.
+ - The app uses the [XLMRoberta model](https://huggingface.co/xlm-roberta-base), fine-tuned on the [masakhaNER](https://huggingface.co/datasets/masakhane/masakhaner2) dataset.
+ - The model uses the **byte-level BPE tokenizer**. Each sentence is first tokenized.
+ """
+ )
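Note that the new predict_ner_labels calls align_word_ids(sentence), which is not defined anywhere in the file shown above. A minimal sketch of such a helper, assuming the usual fast-tokenizer word_ids() alignment for token classification; the app's actual implementation may differ, and only the function name, the max_length=218, and the -100 convention are taken from how predict_ner_labels consumes it:

from transformers import AutoTokenizer

def align_word_ids(sentence, max_length=218):
    # -100 marks special tokens, padding, and sub-word continuations, so that
    # predict_ner_labels keeps exactly one logit per input word.
    tokenizer = AutoTokenizer.from_pretrained("vonewman/wolof-finetuned-ner")
    encoded = tokenizer(sentence, padding='max_length', max_length=max_length, truncation=True)
    label_ids, previous = [], None
    for word_id in encoded.word_ids():  # None for specials and padding
        label_ids.append(1 if word_id is not None and word_id != previous else -100)
        previous = word_id
    return label_ids

With such a helper in scope, the prediction path can be smoke-tested outside Streamlit (hypothetical sentence, illustrative output):

trainer, model, tokenizer = load_model()
print(predict_ner_labels(model, tokenizer, "Aminata dem na Ndakaaru"))
# e.g. ['B-PER', 'O', 'O', 'B-LOC'], one tag per whitespace-separated word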