import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-scico')
model = AutoModelForSequenceClassification.from_pretrained('allenai/longformer-scico')

start_token = tokenizer.convert_tokens_to_ids("<m>")
end_token = tokenizer.convert_tokens_to_ids("</m>")
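
# Sanity check (an added, hedged illustration, not part of the original app):
# the SciCo checkpoint registers "<m>" and "</m>" as additional special tokens,
# so both should resolve to real vocabulary ids rather than the unknown-token id.
assert start_token != tokenizer.unk_token_id
assert end_token != tokenizer.unk_token_id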

def get_global_attention(input_ids):
    """Build a Longformer global-attention mask: 1 at the CLS token and at every
    <m>/</m> mention marker, 0 elsewhere."""
    global_attention_mask = torch.zeros(input_ids.shape)
    global_attention_mask[:, 0] = 1  # global attention on the CLS token
    start = torch.nonzero(input_ids == start_token)  # positions of <m> tokens
    end = torch.nonzero(input_ids == end_token)  # positions of </m> tokens
    globs = torch.cat((start, end))
    value = torch.ones(globs.shape[0])
    global_attention_mask.index_put_(tuple(globs.t()), value)
    return global_attention_mask
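
# Worked example (an illustrative addition, assuming the tokenizer keeps
# <m>/</m> as single tokens): for one short marked mention, global attention
# should land on exactly three positions: the leading <s> (CLS) token plus
# the <m> and </m> markers.
_demo_ids = tokenizer("<m> information extraction </m>", return_tensors='pt')['input_ids']
_demo_mask = get_global_attention(_demo_ids)
assert _demo_mask[0, 0] == 1 and _demo_mask.sum() == 3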
  
def inference(m1, m2):
    # Concatenate the two marked mentions with the separator format the model
    # was trained on: "{m1} </s></s> {m2}".
    inputs = m1 + " </s></s> " + m2

    tokens = tokenizer(inputs, return_tensors='pt')
    global_attention_mask = get_global_attention(tokens['input_ids'])

    with torch.no_grad():
        output = model(tokens['input_ids'], tokens['attention_mask'],
                       global_attention_mask=global_attention_mask)

    # Softmax over the 4 relation classes for the single input pair.
    scores = torch.softmax(output.logits, dim=-1)[0].tolist()
    return {
        'not related': scores[0],
        'coref': scores[1],
        'parent': scores[2],
        'child': scores[3],
    }
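
# Direct usage (an illustrative sketch; the Gradio interface below is the
# original entry point). Each input wraps the mention of interest in <m> ... </m>,
# exactly as in the `examples` list further down; the sample sentences here are
# made up for illustration:
#
#     inference(
#         "We use <m> word embeddings </m> to represent tokens.",
#         "<m> GloVe vectors </m> are trained on corpus co-occurrence statistics.",
#     )
#
# returns a dict mapping the four relation labels to their softmax probabilities.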

title = "Longformer-scico"
description = """Gradio demo for Longformer-scico. To use it, simply add your text, or click one of the examples to load them. Read more at the links below. The model takes as input two mentions m1 and m2 with their corresponding context and outputs 4 scores:

0: not related
1: m1 and m2 corefer
2: m1 is a parent of m2
3: m1 is a child of m2."""

article = "<p style='text-align: center'><a href='https://openreview.net/forum?id=OFLbgUP04nC' target='_blank'>SciCo: Hierarchical Cross-Document Coreference for Scientific Concepts</a> | <a href='https://github.com/ariecattan/SciCo' target='_blank'>Github Repo</a></p>"

examples = [["In this paper we present the results of an experiment in <m> automatic concept and definition extraction </m> from written sources of law using relatively simple natural methods.","This task is important since many natural language processing (NLP) problems, such as <m> information extraction </m>, summarization and dialogue."]]
gr.Interface(
    inference,
    [gr.inputs.Textbox(label="m1"), gr.inputs.Textbox(label="m2")],
    gr.outputs.Label(label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
).launch(debug=True)