Atharva commited on
Commit
5230da2
1 Parent(s): fe0139a

pipeline update

Browse files
Files changed (2) hide show
  1. app.py +2 -3
  2. src/__init__.py +43 -12
app.py CHANGED
@@ -1,9 +1,8 @@
1
- from turtle import color
2
  import pandas as pd
3
  import streamlit as st
4
  from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
5
 
6
- from src import GBRT, wikipedia_search, google_search
7
 
8
  TYPE = {
9
  'LOC': ' location',
@@ -55,7 +54,7 @@ def get_candidates(mentions_tags):
55
  candidates.append((mention, cache[(mention, tag)]))
56
  else:
57
  res1 = google_search(mention + TYPE[tag], limit=3)
58
- res2 = wikipedia_search(mention, limit=3)
59
  cands = list(set(res1 + res2))
60
  cache[(mention, tag)] = cands
61
  candidates.append((mention, cands))
 
 
1
  import pandas as pd
2
  import streamlit as st
3
  from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
4
 
5
+ from src import GBRT, google_search, wikidata_search
6
 
7
  TYPE = {
8
  'LOC': ' location',
 
54
  candidates.append((mention, cache[(mention, tag)]))
55
  else:
56
  res1 = google_search(mention + TYPE[tag], limit=3)
57
+ res2 = wikidata_search(mention, limit=3)
58
  cands = list(set(res1 + res2))
59
  cache[(mention, tag)] = cands
60
  candidates.append((mention, cands))
src/__init__.py CHANGED
@@ -60,23 +60,53 @@ def cosine_similarity(v1, v2):
60
  return np.dot(v2, v1) / v1v2
61
 
62
 
63
- def wikipedia_search(query, limit=20):
64
- service_url = 'https://en.wikipedia.org/w/api.php'
65
  params = {
66
- 'action': 'opensearch',
67
- 'search': query,
68
- 'namespace': 0,
69
- 'limit': limit,
70
- 'redirects': 'resolve',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  }
72
 
73
- results = requests.get(service_url, params=params).json()[1]
74
- results = [i.replace(' ', '_')
75
- for i in results if 'disambiguation' not in i.lower()]
76
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
79
- def google_search(query, limit=10):
80
  service_url = "https://www.googleapis.com/customsearch/v1/siterestrict"
81
  params = {
82
  'q': query,
@@ -88,6 +118,7 @@ def google_search(query, limit=10):
88
  res = requests.get(service_url, params=params)
89
  try:
90
  cands = [i['title'].replace(' - Wikipedia', '') for i in res.json()["items"]]
 
91
  return [i.replace(' ', '_') for i in cands]
92
  except:
93
  return []
 
60
  return np.dot(v2, v1) / v1v2
61
 
62
 
63
+ def is_disamb_page(title):
64
+ service_url = "https://en.wikipedia.org/w/api.php"
65
  params = {
66
+ "action": "query",
67
+ "prop": "pageprops",
68
+ "ppprop" : "disambiguation",
69
+ "redirects":'',
70
+ "format": "json",
71
+ "titles": title
72
+ }
73
+ results = requests.get(service_url, params=params).json()
74
+ return 'disambiguation' in str(results)
75
+
76
+
77
+ def wikidata_search(query, limit=3):
78
+ service_url = 'https://www.wikidata.org/w/api.php'
79
+ params1 = {
80
+ "action": "wbsearchentities",
81
+ "search": query,
82
+ "language": "en",
83
+ "limit": limit,
84
+ "format": "json"
85
  }
86
 
87
+ params2 = {
88
+ "action": "wbgetentities",
89
+ "language": "en",
90
+ "props": "sitelinks",
91
+ "sitefilter": "enwiki",
92
+ "format": "json"
93
+ }
94
+
95
+ results = requests.get(service_url, params=params1).json()
96
+ entities = [i['id'] for i in results['search']]
97
+
98
+ params2['ids'] = '|'.join(entities)
99
+ results = requests.get(service_url, params=params2).json()
100
+ candidates = []
101
+ for i in entities:
102
+ try:
103
+ candidates.append(results['entities'][i]['sitelinks']['enwiki']['title'].replace(' ', '_'))
104
+ except:
105
+ pass
106
+ return [i for i in candidates if is_disamb_page(i) == False]
107
 
108
 
109
+ def google_search(query, limit=3):
110
  service_url = "https://www.googleapis.com/customsearch/v1/siterestrict"
111
  params = {
112
  'q': query,
 
118
  res = requests.get(service_url, params=params)
119
  try:
120
  cands = [i['title'].replace(' - Wikipedia', '') for i in res.json()["items"]]
121
+ cands = [i for i in cands if is_disamb_page(i) == False]
122
  return [i.replace(' ', '_') for i in cands]
123
  except:
124
  return []