Atharva committed on
Commit
5bc0741
1 Parent(s): 5ae066c

pipeline update

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. src/__init__.py +11 -14
app.py CHANGED
@@ -2,7 +2,7 @@ import pandas as pd
2
  import streamlit as st
3
  from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
4
 
5
- from src import GBRT, google_search, wikidata_search
6
 
7
  TYPE = {
8
  'LOC': ' location',
@@ -52,11 +52,12 @@ def get_candidates(mentions_tags):
52
  if (mention, tag) in cache.keys():
53
  candidates.append((mention, cache[(mention, tag)]))
54
  else:
55
- res1 = google_search(mention + TYPE[tag], limit=3)
56
- res2 = wikidata_search(mention, limit=3)
57
- cands = list(set(res1 + res2))
58
  cache[(mention, tag)] = cands
59
  candidates.append((mention, cands))
 
60
  return candidates
61
 
62
 
 
2
  import streamlit as st
3
  from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
4
 
5
+ from src import GBRT, wikipedia_search, wikidata_search
6
 
7
  TYPE = {
8
  'LOC': ' location',
 
52
  if (mention, tag) in cache.keys():
53
  candidates.append((mention, cache[(mention, tag)]))
54
  else:
55
+ cands = wikidata_search(mention, limit=3)
56
+ if cands == []:
57
+ cands = wikipedia_search(mention, limit=3)
58
  cache[(mention, tag)] = cands
59
  candidates.append((mention, cands))
60
+ print(mention, cands)
61
  return candidates
62
 
63
 
src/__init__.py CHANGED
@@ -106,22 +106,19 @@ def wikidata_search(query, limit=3):
106
  return [i for i in candidates if is_disamb_page(i) == False]
107
 
108
 
109
- def google_search(query, limit=3):
110
- service_url = "https://www.googleapis.com/customsearch/v1/siterestrict"
111
  params = {
112
- 'q': query,
113
- 'num': limit,
114
- 'start': 0,
115
- 'key': os.environ.get('APIKEY'),
116
- 'cx': os.environ.get('CESCX')
117
  }
118
- res = requests.get(service_url, params=params)
119
- try:
120
- cands = [i['title'].replace(' - Wikipedia', '') for i in res.json()["items"]]
121
- cands = [i for i in cands if is_disamb_page(i) == False]
122
- return [i.replace(' ', '_') for i in cands]
123
- except:
124
- return []
125
 
126
 
127
  def get_entity_extract(entity_title, num_sentences=0):
 
106
  return [i for i in candidates if is_disamb_page(i) == False]
107
 
108
 
109
+ def wikipedia_search(query, limit=3):
110
+ service_url = 'https://en.wikipedia.org/w/api.php'
111
  params = {
112
+ 'action': 'opensearch',
113
+ 'search': query,
114
+ 'namespace': 0,
115
+ 'limit': limit,
116
+ 'redirects': 'resolve',
117
  }
118
+
119
+ results = requests.get(service_url, params=params).json()[1]
120
+ results = [i.replace(' ', '_') for i in results if 'disambiguation' not in i.lower()]
121
+ return [i for i in results if is_disamb_page(i) == False]
 
 
 
122
 
123
 
124
  def get_entity_extract(entity_title, num_sentences=0):