AlzbetaStrompova commited on
Commit
f3898ef
·
1 Parent(s): 2a9fe7e

change output

Browse files
Files changed (2) hide show
  1. app.py +6 -3
  2. website_script.py +13 -1
app.py CHANGED
@@ -6,7 +6,11 @@ tokenizer, model, gazetteers_for_matching = load()
6
  print("Loaded model")
7
 
8
  examples = [
9
- "Masarykova univerzita",
 
 
 
 
10
  ]
11
 
12
  def ner(text):
@@ -15,8 +19,7 @@ def ner(text):
15
 
16
  demo = gr.Interface(ner,
17
  gr.Textbox(placeholder="Enter sentence here..."),
18
- "textbox",
19
- #gr.HighlightedText(), # TODO https://www.gradio.app/guides/named-entity-recognition
20
  examples=examples)
21
 
22
  if __name__ == "__main__":
 
6
  print("Loaded model")
7
 
8
  examples = [
9
+ "Masarykova univerzita se nachází v Brně.",
10
+ "Barack Obama navštívil Prahu minulý týden.",
11
+ "Angela Merkelová se setkala s francouzským prezidentem v Paříži.",
12
+ "Karel Čapek napsal knihu R.U.R., která byla poprvé představena v Praze.",
13
+ "Nobelova cena za fyziku byla udělena týmu vědců z MIT."
14
  ]
15
 
16
  def ner(text):
 
19
 
20
  demo = gr.Interface(ner,
21
  gr.Textbox(placeholder="Enter sentence here..."),
22
+ gr.HighlightedText(show_legend=True,),
 
23
  examples=examples)
24
 
25
  if __name__ == "__main__":
website_script.py CHANGED
@@ -44,4 +44,16 @@ def run(tokenizer, model, gazetteers_for_matching, text):
44
  output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
45
  predictions = torch.argmax(output, dim=2).tolist()
46
  predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
47
- return " ".join(predicted_tags[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
45
  predictions = torch.argmax(output, dim=2).tolist()
46
  predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
47
+
48
+ softmax = torch.nn.Softmax(dim=2)
49
+ scores = softmax(output).squeeze(0).tolist()
50
+ result = []
51
+ for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
52
+ result.append({
53
+ "start": pos[0],
54
+ "end": pos[1],
55
+ "entity": entity,
56
+ "score": max(score),
57
+ "word": text[pos[0]:pos[1]],
58
+ })
59
+ return result