Spaces:
Runtime error
Runtime error
File size: 5,371 Bytes
16188ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from .clustering import *
from typing import List
import textdistance as td
from .utils import UnionFind, ArticleList
from .academic_query import AcademicQuery
import streamlit as st
from tokenizers import Tokenizer
from .clustering.clusters import KeyphraseCount
class LiteratureResearchTool:
    """End-to-end literature-research pipeline.

    For a user query, fetches papers from one or more academic platforms
    (IEEE / Arxiv / Papers with Code), clusters the article abstracts, and
    annotates each cluster with its top-5 representative keyphrases.
    """

    def __init__(self, cluster_config: Configuration = None):
        """
        :param cluster_config: configuration forwarded to ClusterPipeline;
                               None lets the pipeline use its defaults.
        """
        # NOTE: AcademicQuery is stored as the class itself (not an
        # instance) — its query methods are used without instantiation.
        self.literature_search = AcademicQuery
        self.cluster_pipeline = ClusterPipeline(cluster_config)

    def __postprocess_clusters__(self, clusters: ClusterList, query: str) -> ClusterList:
        '''
        Attach the top-5 keyphrases to each cluster.

        Keyphrases are filtered (non-empty, longer than one character, and
        not equal to the query itself), then near-duplicates are merged via
        union-find using Ratcliff-Obershelp similarity (> 0.8), and the five
        most frequent merged keyphrases are kept.

        :param clusters: clusters produced by the clustering pipeline
        :param query: the original search query (excluded from keyphrases)
        :return: the same ClusterList, with `top_5_keyphrases` set on each cluster
        '''
        def similar(x: KeyphraseCount, y: KeyphraseCount) -> bool:
            # Two keyphrases are considered duplicates when their string
            # similarity exceeds 0.8.
            return td.ratcliff_obershelp(x.keyphrase, y.keyphrase) > 0.8

        def valid_keyphrase(x: KeyphraseCount) -> bool:
            phrase = x.keyphrase
            return (phrase is not None and phrase != '' and not phrase.isspace()
                    and len(phrase) != 1 and phrase != query)

        for cluster in clusters:
            keyphrases = [kc for kc in cluster.get_keyphrases() if valid_keyphrase(kc)]
            unionfind = UnionFind(keyphrases, similar)
            unionfind.union_step()
            # get_unions(): dict(root_id -> [KeyphraseCount]); merge each
            # group into a single KeyphraseCount, then keep the 5 most common.
            merged = [KeyphraseCount.reduce(group) for group in unionfind.get_unions().values()]
            top5 = sorted(merged, key=lambda kc: kc.count, reverse=True)[:5]
            cluster.top_5_keyphrases = [kc.keyphrase for kc in top5]
        return clusters

    def __call__(self,
                 query: str,
                 num_papers: int,
                 start_year: int,
                 end_year: int,
                 max_k: int,
                 platforms: List[str] = None,
                 loading_ctx_manager=None,
                 standardization=False
                 ):
        """Run the pipeline for each platform, yielding results one at a time.

        :param query: search query string
        :param num_papers: number of papers to fetch per platform
        :param start_year: earliest publication year (IEEE only)
        :param end_year: latest publication year (IEEE only)
        :param max_k: maximum number of clusters
        :param platforms: platforms to query; defaults to all three
                          ('IEEE', 'Arxiv', 'Paper with Code')
        :param loading_ctx_manager: optional context-manager factory (e.g. a
                                    Streamlit spinner) wrapped around each
                                    platform's processing
        :param standardization: forwarded to the clustering pipeline
        :yields: (ClusterList, ArticleList) per platform, clusters sorted
        """
        # Avoid a mutable default argument: build the default list per call.
        if platforms is None:
            platforms = ['IEEE', 'Arxiv', 'Paper with Code']
        for platform in platforms:
            if loading_ctx_manager:
                with loading_ctx_manager():
                    clusters, articles = self.__platformPipeline__(
                        platform, query, num_papers, start_year, end_year, max_k, standardization)
            else:
                clusters, articles = self.__platformPipeline__(
                    platform, query, num_papers, start_year, end_year, max_k, standardization)
            clusters.sort()
            yield clusters, articles

    def __platformPipeline__(self, platforn_name: str,
                             query: str,
                             num_papers: int,
                             start_year: int,
                             end_year: int,
                             max_k: int,
                             standardization
                             ) -> (ClusterList, ArticleList):
        """Fetch, cluster and post-process articles for a single platform.

        :param platforn_name: one of 'IEEE', 'Arxiv', 'Paper with Code'
                              (parameter name keeps the original spelling for
                              backward compatibility with keyword callers)
        :raises RuntimeError: if the platform is not supported
        :return: (clusters with top-5 keyphrases, parsed article list)
        """
        # Each inner function is cached by Streamlit so repeated queries with
        # identical arguments skip the network round-trip and re-clustering.
        @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__}, allow_output_mutation=True)
        def ieee_process(
                query: str,
                num_papers: int,
                start_year: int,
                end_year: int,
        ):
            articles = ArticleList.parse_ieee_articles(
                self.literature_search.ieee(query, start_year, end_year, num_papers))  # ArticleList
            abstracts = articles.getAbstracts()  # List[str]
            clusters = self.cluster_pipeline(abstracts, max_k, standardization)
            clusters = self.__postprocess_clusters__(clusters, query)
            return clusters, articles

        @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__}, allow_output_mutation=True)
        def arxiv_process(
                query: str,
                num_papers: int,
        ):
            articles = ArticleList.parse_arxiv_articles(
                self.literature_search.arxiv(query, num_papers))  # ArticleList
            abstracts = articles.getAbstracts()  # List[str]
            clusters = self.cluster_pipeline(abstracts, max_k, standardization)
            clusters = self.__postprocess_clusters__(clusters, query)
            return clusters, articles

        @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__}, allow_output_mutation=True)
        def pwc_process(
                query: str,
                num_papers: int,
        ):
            articles = ArticleList.parse_pwc_articles(
                self.literature_search.paper_with_code(query, num_papers))  # ArticleList
            abstracts = articles.getAbstracts()  # List[str]
            clusters = self.cluster_pipeline(abstracts, max_k, standardization)
            clusters = self.__postprocess_clusters__(clusters, query)
            return clusters, articles

        if platforn_name == 'IEEE':
            return ieee_process(query, num_papers, start_year, end_year)
        elif platforn_name == 'Arxiv':
            return arxiv_process(query, num_papers)
        elif platforn_name == 'Paper with Code':
            return pwc_process(query, num_papers)
        else:
            raise RuntimeError('This platform is not supported. Please open an issue on the GitHub.')
|