khaerens commited on
Commit
70303d6
1 Parent(s): 390b2d8
.gitattributes CHANGED
@@ -1,27 +1,27 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bin.* filter=lfs diff=lfs merge=lfs -text
5
- *.bz2 filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.model filter=lfs diff=lfs merge=lfs -text
12
- *.msgpack filter=lfs diff=lfs merge=lfs -text
13
- *.onnx filter=lfs diff=lfs merge=lfs -text
14
- *.ot filter=lfs diff=lfs merge=lfs -text
15
- *.parquet filter=lfs diff=lfs merge=lfs -text
16
- *.pb filter=lfs diff=lfs merge=lfs -text
17
- *.pt filter=lfs diff=lfs merge=lfs -text
18
- *.pth filter=lfs diff=lfs merge=lfs -text
19
- *.rar filter=lfs diff=lfs merge=lfs -text
20
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
- *.tar.* filter=lfs diff=lfs merge=lfs -text
22
- *.tflite filter=lfs diff=lfs merge=lfs -text
23
- *.tgz filter=lfs diff=lfs merge=lfs -text
24
- *.xz filter=lfs diff=lfs merge=lfs -text
25
- *.zip filter=lfs diff=lfs merge=lfs -text
26
- *.zstandard filter=lfs diff=lfs merge=lfs -text
27
- *tfevents* filter=lfs diff=lfs merge=lfs -text
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
1
+ venv
2
+ test.html
.vscode/launch.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python: Current File",
9
+ "type": "python",
10
+ "request": "launch",
11
+ "program": "${file}",
12
+ "console": "integratedTerminal",
13
+ "justMyCode": false
14
+ }
15
+ ]
16
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ {
2
+ "workbench.colorCustomizations": {
3
+ "activityBar.background": "#09323E",
4
+ "titleBar.activeBackground": "#0C4656",
5
+ "titleBar.activeForeground": "#F6FCFE"
6
+ }
7
+ }
README.md CHANGED
@@ -1,37 +1,37 @@
1
- ---
2
- title: REBEL
3
- emoji: 🏢
4
- colorFrom: green
5
- colorTo: blue
6
- sdk: streamlit
7
- app_file: app.py
8
- pinned: false
9
- ---
10
-
11
- # Configuration
12
-
13
- `title`: _string_
14
- Display title for the Space
15
-
16
- `emoji`: _string_
17
- Space emoji (emoji-only character allowed)
18
-
19
- `colorFrom`: _string_
20
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
-
22
- `colorTo`: _string_
23
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
-
25
- `sdk`: _string_
26
- Can be either `gradio`, `streamlit`, or `static`
27
-
28
- `sdk_version` : _string_
29
- Only applicable for `streamlit` SDK.
30
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
-
32
- `app_file`: _string_
33
- Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
34
- Path is relative to the root of the repository.
35
-
36
- `pinned`: _boolean_
37
- Whether the Space stays on top of your list.
1
+ ---
2
+ title: REBEL
3
+ emoji: 🏢
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # Configuration
12
+
13
+ `title`: _string_
14
+ Display title for the Space
15
+
16
+ `emoji`: _string_
17
+ Space emoji (emoji-only character allowed)
18
+
19
+ `colorFrom`: _string_
20
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
+
22
+ `colorTo`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `sdk`: _string_
26
+ Can be either `gradio`, `streamlit`, or `static`
27
+
28
+ `sdk_version` : _string_
29
+ Only applicable for `streamlit` SDK.
30
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
+
32
+ `app_file`: _string_
33
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
34
+ Path is relative to the root of the repository.
35
+
36
+ `pinned`: _boolean_
37
+ Whether the Space stays on top of your list.
__pycache__/app.cpython-38.pyc ADDED
Binary file (3.55 kB). View file
__pycache__/rebel.cpython-38.pyc ADDED
Binary file (3.65 kB). View file
app.py CHANGED
@@ -1,4 +1,122 @@
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
1
+ from logging import disable
2
+ from pkg_resources import EggMetadata
3
  import streamlit as st
4
+ import streamlit.components.v1 as components
5
+ import networkx as nx
6
+ import matplotlib.pyplot as plt
7
+ from pyvis.network import Network
8
+ from streamlit.state.session_state import SessionState
9
+ from streamlit.type_util import Key
10
+ import rebel
11
+ import wikipedia
12
+
13
+ network_filename = "test.html"
14
+
15
+ state_variables = {
16
+ 'has_run':False,
17
+ 'wiki_suggestions': "",
18
+ 'wiki_text' : [],
19
+ 'nodes':[]
20
+ }
21
+
22
+ for k, v in state_variables.items():
23
+ if k not in st.session_state:
24
+ st.session_state[k] = v
25
+
26
+ def clip_text(t, lenght = 5):
27
+ return ".".join(t.split(".")[:lenght]) + "."
28
+
29
+
30
+
31
+ def generate_graph():
32
+ if 'wiki_text' not in st.session_state:
33
+ return
34
+ if len(st.session_state['wiki_text']) == 0:
35
+ st.error("please enter a topic and select a wiki page first")
36
+ return
37
+ with st.spinner(text="Generating graph..."):
38
+ texts = st.session_state['wiki_text']
39
+ nodes = rebel.generate_knowledge_graph(texts, network_filename)
40
+ st.session_state['nodes'] = nodes
41
+ st.session_state['has_run'] = True
42
+ st.success('Done!')
43
+
44
+ def show_suggestion():
45
+ with st.spinner(text="fetching wiki topics..."):
46
+ if st.session_state['input_method'] == "wikipedia":
47
+ text = st.session_state.text
48
+ if text is not None:
49
+ st.session_state['wiki_suggestions'] = wikipedia.search(text, results = 3)
50
+
51
+ def show_wiki_text(page_title):
52
+ with st.spinner(text="fetching wiki page..."):
53
+ try:
54
+ page = wikipedia.page(title=page_title, auto_suggest=False)
55
+ st.session_state['wiki_text'].append(clip_text(page.summary))
56
+ except wikipedia.DisambiguationError as e:
57
+ with st.spinner(text="Woops, ambigious term, recalculating options..."):
58
+ st.session_state['wiki_suggestions'].remove(page_title)
59
+ temp = st.session_state['wiki_suggestions'] + e.options[:3]
60
+ st.session_state['wiki_suggestions'] = list(set(temp))
61
+
62
+ def add_text(term):
63
+ try:
64
+ extra_text = clip_text(wikipedia.page(title=term, auto_suggest=True).summary)
65
+ st.session_state['wiki_text'].append(extra_text)
66
+ except wikipedia.DisambiguationError as e:
67
+ st.session_state["nodes"].remove(term)
68
+
69
+
70
+ def reset_session():
71
+ for k in state_variables:
72
+ del st.session_state[k]
73
+
74
+ st.title('REBELious knowledge graph generation')
75
+ st.session_state['input_method'] = "wikipedia"
76
+
77
+ # st.selectbox(
78
+ # 'input method',
79
+ # ('wikipedia', 'free text'), key="input_method")
80
+
81
+ if st.session_state['input_method'] != "wikipedia":
82
+ st.text_area("Your text", key="text")
83
+ else:
84
+ st.text_input("wikipedia search term",on_change=show_suggestion, key="text")
85
+
86
+ if len(st.session_state['wiki_suggestions']) != 0:
87
+ columns = st.columns([1] * len(st.session_state['wiki_suggestions']))
88
+ for i, (c, s) in enumerate(zip(columns, st.session_state['wiki_suggestions'])):
89
+ with c:
90
+ st.button(s, on_click=show_wiki_text, args=(s,), key=i)
91
+
92
+ if len(st.session_state['wiki_text']) != 0:
93
+ for t in st.session_state['wiki_text']:
94
+ new_expander = st.expander(label=t[:30] + "...")
95
+ with new_expander:
96
+ st.markdown(t)
97
+
98
+ if st.session_state['input_method'] != "wikipedia":
99
+ st.button("find wiki pages")
100
+ if "wiki_suggestions" in st.session_state:
101
+ st.button("generate", on_click=generate_graph, key="gen_graph")
102
+
103
+ else:
104
+ st.button("generate", on_click=generate_graph, key="gen_graph2")
105
+
106
+
107
+ if st.session_state['has_run']:
108
+ cols = st.columns([4, 1])
109
+ with cols[0]:
110
+ HtmlFile = open(network_filename, 'r', encoding='utf-8')
111
+ source_code = HtmlFile.read()
112
+ components.html(source_code, height=1500,width=1500)
113
+ with cols[1]:
114
+ st.text("expand")
115
+ for s in st.session_state["nodes"]:
116
+ st.button(s, on_click=add_text, args=(s,))
117
+
118
+
119
+
120
+
121
+
122
 
 
 
rebel.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from transformers import pipeline
3
+ from pyvis.network import Network
4
+ from functools import lru_cache
5
+ from app import generate_graph
6
+ import spacy
7
+ from spacy import displacy
8
+
9
+
10
+ DEFAULT_LABEL_COLORS = {
11
+ "ORG": "#7aecec",
12
+ "PRODUCT": "#bfeeb7",
13
+ "GPE": "#feca74",
14
+ "LOC": "#ff9561",
15
+ "PERSON": "#aa9cfc",
16
+ "NORP": "#c887fb",
17
+ "FACILITY": "#9cc9cc",
18
+ "EVENT": "#ffeb80",
19
+ "LAW": "#ff8197",
20
+ "LANGUAGE": "#ff8197",
21
+ "WORK_OF_ART": "#f0d0ff",
22
+ "DATE": "#bfe1d9",
23
+ "TIME": "#bfe1d9",
24
+ "MONEY": "#e4e7d2",
25
+ "QUANTITY": "#e4e7d2",
26
+ "ORDINAL": "#e4e7d2",
27
+ "CARDINAL": "#e4e7d2",
28
+ "PERCENT": "#e4e7d2",
29
+ }
30
+
31
+ def generate_knowledge_graph(texts: List[str], filename: str):
32
+ nlp = spacy.load("en_core_web_sm")
33
+ doc = nlp("\n".join(texts))
34
+ NERs = [ent.text for ent in doc.ents]
35
+ NER_types = [ent.label_ for ent in doc.ents]
36
+ for nr, nrt in zip(NERs, NER_types):
37
+ print(nr, nrt)
38
+
39
+ triplets = []
40
+ for triplet in texts:
41
+ triplets.extend(generate_partial_graph(triplet))
42
+ print(generate_partial_graph.cache_info())
43
+ heads = [ t["head"] for t in triplets]
44
+ tails = [ t["tail"] for t in triplets]
45
+
46
+ nodes = set(heads + tails)
47
+ net = Network(directed=True)
48
+
49
+ for n in nodes:
50
+ if n in NERs:
51
+ NER_type = NER_types[NERs.index(n)]
52
+ color = DEFAULT_LABEL_COLORS[NER_type]
53
+ net.add_node(n, title=NER_type, shape="circle", color=color)
54
+ else:
55
+ net.add_node(n, shape="circle")
56
+
57
+ unique_triplets = set()
58
+ stringify_trip = lambda x : x["tail"] + x["head"] + x["type"]
59
+ for triplet in triplets:
60
+ if stringify_trip(triplet) not in unique_triplets:
61
+ net.add_edge(triplet["tail"], triplet["head"], title=triplet["type"], label=triplet["type"])
62
+ unique_triplets.add(stringify_trip(triplet))
63
+
64
+ net.repulsion(
65
+ node_distance=200,
66
+ central_gravity=0.2,
67
+ spring_length=200,
68
+ spring_strength=0.05,
69
+ damping=0.09
70
+ )
71
+ net.set_edge_smooth('dynamic')
72
+ net.show(filename)
73
+ return nodes
74
+
75
+
76
+ @lru_cache
77
+ def generate_partial_graph(text):
78
+ triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')
79
+ a = triplet_extractor(text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
80
+ extracted_text = triplet_extractor.tokenizer.batch_decode(a)
81
+ extracted_triplets = extract_triplets(extracted_text[0])
82
+ return extracted_triplets
83
+
84
+
85
+ def extract_triplets(text):
86
+ """
87
+ Function to parse the generated text and extract the triplets
88
+ """
89
+ triplets = []
90
+ relation, subject, relation, object_ = '', '', '', ''
91
+ text = text.strip()
92
+ current = 'x'
93
+ for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
94
+ if token == "<triplet>":
95
+ current = 't'
96
+ if relation != '':
97
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
98
+ relation = ''
99
+ subject = ''
100
+ elif token == "<subj>":
101
+ current = 's'
102
+ if relation != '':
103
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
104
+ object_ = ''
105
+ elif token == "<obj>":
106
+ current = 'o'
107
+ relation = ''
108
+ else:
109
+ if current == 't':
110
+ subject += ' ' + token
111
+ elif current == 's':
112
+ object_ += ' ' + token
113
+ elif current == 'o':
114
+ relation += ' ' + token
115
+ if subject != '' and relation != '' and object_ != '':
116
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
117
+
118
+ return triplets
119
+
requirements.txt ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==4.2.0
2
+ argon2-cffi==21.3.0
3
+ argon2-cffi-bindings==21.2.0
4
+ astor==0.8.1
5
+ attrs==21.4.0
6
+ backcall==0.2.0
7
+ backports.zoneinfo==0.2.1
8
+ base58==2.1.1
9
+ beautifulsoup4==4.10.0
10
+ bleach==4.1.0
11
+ blinker==1.4
12
+ blis==0.7.5
13
+ cachetools==5.0.0
14
+ catalogue==2.0.6
15
+ certifi==2021.10.8
16
+ cffi==1.15.0
17
+ charset-normalizer==2.0.10
18
+ click==7.1.2
19
+ cycler==0.11.0
20
+ cymem==2.0.6
21
+ debugpy==1.5.1
22
+ decorator==5.1.1
23
+ defusedxml==0.7.1
24
+ en-core-web-sm==3.2.0
25
+ entrypoints==0.3
26
+ filelock==3.4.2
27
+ fonttools==4.28.5
28
+ gitdb==4.0.9
29
+ GitPython==3.1.26
30
+ got==0.0.1
31
+ huggingface-hub==0.2.1
32
+ idna==3.3
33
+ importlib-resources==5.4.0
34
+ ipykernel==6.6.1
35
+ ipython==7.31.0
36
+ ipython-genutils==0.2.0
37
+ ipywidgets==7.6.5
38
+ jedi==0.18.1
39
+ Jinja2==3.0.3
40
+ joblib==1.1.0
41
+ jsonpickle==2.0.0
42
+ jsonschema==4.3.3
43
+ jupyter-client==7.1.0
44
+ jupyter-core==4.9.1
45
+ jupyterlab-pygments==0.1.2
46
+ jupyterlab-widgets==1.0.2
47
+ kiwisolver==1.3.2
48
+ langcodes==3.3.0
49
+ MarkupSafe==2.0.1
50
+ matplotlib==3.5.1
51
+ matplotlib-inline==0.1.3
52
+ mistune==0.8.4
53
+ murmurhash==1.0.6
54
+ nbclient==0.5.9
55
+ nbconvert==6.4.0
56
+ nbformat==5.1.3
57
+ nest-asyncio==1.5.4
58
+ networkx==2.6.3
59
+ notebook==6.4.6
60
+ numpy==1.22.0
61
+ packaging==21.3
62
+ pandas==1.3.5
63
+ pandocfilters==1.5.0
64
+ parso==0.8.3
65
+ pathy==0.6.1
66
+ pexpect==4.8.0
67
+ pickleshare==0.7.5
68
+ Pillow==9.0.0
69
+ preshed==3.0.6
70
+ prometheus-client==0.12.0
71
+ prompt-toolkit==3.0.24
72
+ protobuf==3.19.3
73
+ ptyprocess==0.7.0
74
+ pyarrow==6.0.1
75
+ pycparser==2.21
76
+ pydantic==1.8.2
77
+ pydeck==0.7.1
78
+ Pygments==2.11.2
79
+ Pympler==1.0.1
80
+ pyparsing==3.0.6
81
+ pyrsistent==0.18.0
82
+ python-dateutil==2.8.2
83
+ pytz==2021.3
84
+ pytz-deprecation-shim==0.1.0.post0
85
+ pyvis==0.1.9
86
+ PyYAML==6.0
87
+ pyzmq==22.3.0
88
+ regex==2021.11.10
89
+ requests==2.27.1
90
+ sacremoses==0.0.47
91
+ Send2Trash==1.8.0
92
+ six==1.16.0
93
+ smart-open==5.2.1
94
+ smmap==5.0.0
95
+ soupsieve==2.3.1
96
+ spacy==3.2.1
97
+ spacy-legacy==3.0.8
98
+ spacy-loggers==1.0.1
99
+ srsly==2.4.2
100
+ streamlit==1.3.1
101
+ terminado==0.12.1
102
+ testpath==0.5.0
103
+ thinc==8.0.13
104
+ tokenizers==0.10.3
105
+ toml==0.10.2
106
+ toolz==0.11.2
107
+ torch==1.10.1
108
+ tornado==6.1
109
+ tqdm==4.62.3
110
+ traitlets==5.1.1
111
+ transformers==4.15.0
112
+ typer==0.4.0
113
+ typing-extensions==4.0.1
114
+ tzdata==2021.5
115
+ tzlocal==4.1
116
+ urllib3==1.26.8
117
+ validators==0.18.2
118
+ wasabi==0.9.0
119
+ watchdog==2.1.6
120
+ wcwidth==0.2.5
121
+ webencodings==0.5.1
122
+ widgetsnbextension==3.5.2
123
+ wikipedia==1.4.0
124
+ zipp==3.7.0
125
+ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm