Atharva committed on
Commit
aa726b0
1 Parent(s): 5bc0741

pipeline update

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. src/__init__.py +18 -0
app.py CHANGED
@@ -2,7 +2,7 @@ import pandas as pd
2
  import streamlit as st
3
  from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
4
 
5
- from src import GBRT, wikipedia_search, wikidata_search
6
 
7
  TYPE = {
8
  'LOC': ' location',
@@ -53,6 +53,7 @@ def get_candidates(mentions_tags):
53
  candidates.append((mention, cache[(mention, tag)]))
54
  else:
55
  cands = wikidata_search(mention, limit=3)
 
56
  if cands == []:
57
  cands = wikipedia_search(mention, limit=3)
58
  cache[(mention, tag)] = cands
 
2
  import streamlit as st
3
  from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
4
 
5
+ from src import GBRT, wikipedia_search, wikidata_search, google_search
6
 
7
  TYPE = {
8
  'LOC': ' location',
 
53
  candidates.append((mention, cache[(mention, tag)]))
54
  else:
55
  cands = wikidata_search(mention, limit=3)
56
+ cands = list(set(cands + google_search(mention)))
57
  if cands == []:
58
  cands = wikipedia_search(mention, limit=3)
59
  cache[(mention, tag)] = cands
src/__init__.py CHANGED
@@ -121,6 +121,24 @@ def wikipedia_search(query, limit=3):
121
  return [i for i in results if is_disamb_page(i) == False]
122
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def get_entity_extract(entity_title, num_sentences=0):
125
  service_url = 'https://en.wikipedia.org/w/api.php'
126
  params = {
 
121
  return [i for i in results if is_disamb_page(i) == False]
122
 
123
 
124
def google_search(query, limit=10):
    """Return candidate entity titles for *query* from Google Custom Search.

    Sends the query to the site-restricted Custom Search endpoint using the
    credentials found in the ``APIKEY`` and ``CESCX`` environment variables,
    removes the ``' - Wikipedia'`` marker from each result title, drops
    titles that ``is_disamb_page`` reports as disambiguation pages, and
    replaces spaces with underscores in the remaining titles.

    Args:
        query: Search string (the entity mention).
        limit: Value sent as the ``num`` request parameter.

    Returns:
        A list of underscore-separated titles. An empty list is returned
        when the request fails, times out, or the response cannot be parsed
        (e.g. it has no ``"items"`` key), so callers can treat this source
        as best-effort.
    """
    service_url = "https://www.googleapis.com/customsearch/v1/siterestrict"
    params = {
        'q': query,
        'num': limit,
        # NOTE(review): the API's start index appears to be 1-based; confirm
        # that 0 is accepted before relying on this value.
        'start': 0,
        'key': os.environ.get('APIKEY'),
        'cx': os.environ.get('CESCX'),
    }
    try:
        # The request lives inside the try so that network failures degrade
        # to "no candidates" exactly like malformed responses do; the
        # timeout keeps a stalled connection from blocking the caller.
        res = requests.get(service_url, params=params, timeout=10)
        titles = [
            item['title'].replace(' - Wikipedia', '')
            for item in res.json()["items"]
        ]
        # The disambiguation check runs on the space-separated title, before
        # the underscore conversion. `== False` is kept on purpose to match
        # the filter used by wikipedia_search.
        return [
            title.replace(' ', '_')
            for title in titles
            if is_disamb_page(title) == False
        ]
    except Exception:
        # Best-effort source: swallow ordinary errors, but unlike a bare
        # `except:` let KeyboardInterrupt / SystemExit propagate.
        return []
139
+
140
+
141
+
142
  def get_entity_extract(entity_title, num_sentences=0):
143
  service_url = 'https://en.wikipedia.org/w/api.php'
144
  params = {