lihuigu committed
Commit de0c71d · 1 Parent(s): 8607d4b

update config

configs/datasets.yaml CHANGED
@@ -9,6 +9,9 @@ ARTICLE:
   summarizing_prompt: ./assets/prompt/summarizing.xml
 
 RETRIEVE:
+  retriever_name: "SNKG"
+  use_cocite: True
+  use_cluster_to_filter: True # use a clustering algorithm in the filter
   cite_type: "all_cite_id_list"
   limit_num: 100 # limit on the number of papers per entity
   sn_num_for_entity: 5 # number of papers from the SN search, used to expand entities
@@ -19,8 +22,6 @@ RETRIEVE:
   sum_paper_num: 100 # maximum number of retrieved papers
   sn_retrieve_paper_num: 55 # number of papers retrieved via SN
   cocite_top_k: 1
-  use_cocite: True
-  use_cluster_to_filter: True # use a clustering algorithm in the filter
   need_normalize: True
   alpha: 1
   beta: 0
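
With this commit, the retriever selection and its two switches live alongside the other RETRIEVE keys. A minimal sketch of reading them with OmegaConf (the project already uses OmegaConf elsewhere in this commit; loading the file directly here is illustrative rather than the project's ConfigReader path):

from omegaconf import OmegaConf

# Illustrative: load the dataset config and read the keys this commit adds/moves.
config = OmegaConf.load("./configs/datasets.yaml")
print(config.RETRIEVE.retriever_name)          # "SNKG"
print(config.RETRIEVE.use_cocite)              # True
print(config.RETRIEVE.use_cluster_to_filter)   # True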
src/generator.py CHANGED
@@ -342,7 +342,7 @@ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue
     logger.info("\nretrieve name : {}".format(retriever_name))
     logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
     api_helper = APIHelper(config)
-    paper_client = PaperClient(config)
+    paper_client = PaperClient()
     eval_data = []
     processed_ids = set()
     cur_num = 0
src/pages/button_interface.py CHANGED
@@ -8,17 +8,15 @@ class Backend(object):
     def __init__(self) -> None:
         CONFIG_PATH = "./configs/datasets.yaml"
         EXAMPLE_PATH = "./assets/data/example.json"
-        RETRIEVER_NAME = "SNKG"
         USE_INSPIRATION = True
         BRAINSTORM_MODE = "mode_c"
 
         self.config = ConfigReader.load(CONFIG_PATH)
+        RETRIEVER_NAME = self.config.RETRIEVE.retriever_name
         self.api_helper = APIHelper(self.config)
         self.retriever_factory = RetrieverFactory.get_retriever_factory().create_retriever(
             RETRIEVER_NAME,
-            self.config,
-            use_cocite=self.config.RETRIEVE.use_cocite,
-            use_cluster_to_filter=self.config.RETRIEVE.use_cluster_to_filter,
+            self.config
         )
         self.idea_generator = IdeaGenerator(self.config, None)
         self.use_inspiration = USE_INSPIRATION
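
With RETRIEVER_NAME now read from the config, callers pass only the retriever name and the config object; the co-citation and clustering switches are picked up inside the retriever itself. A minimal sketch of the resulting call site (ConfigReader, RetrieverFactory, and the config path are taken from this diff; the snippet is illustrative, not additional project code):

# Illustrative: the simplified factory call after this commit.
config = ConfigReader.load("./configs/datasets.yaml")
retriever = RetrieverFactory.get_retriever_factory().create_retriever(
    config.RETRIEVE.retriever_name,  # e.g. "SNKG"
    config,  # use_cocite / use_cluster_to_filter are read from config.RETRIEVE downstream
)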
src/retriever.py CHANGED
@@ -38,31 +38,16 @@ def main(ctx):
     required=True,
     help="Dataset configuration file in YAML",
 )
-@click.option(
-    "-r",
-    "--retriever-name",
-    default="SNKG",
-    type=str,
-    required=True,
-    help="Retrieve method",
-)
-@click.option(
-    "--co-cite",
-    is_flag=True,
-    help="Whether to use co-citation, defaults to False",
-)
-@click.option(
-    "--cluster-to-filter",
-    is_flag=True,
-    help="Whether to use cluster-to-filter, defaults to False",
-)
 def retrieve(
-    config_path, ids_path, retriever_name, co_cite, cluster_to_filter, **kwargs
+    config_path, ids_path, **kwargs
 ):
     check_env()
     check_embedding()
     config = ConfigReader.load(config_path, **kwargs)
     log_dir = config.DEFAULT.log_dir
+    retriever_name = config.RETRIEVE.retriever_name
+    cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
+    co_cite = config.RETRIEVE.use_cocite
     if not os.path.exists(log_dir):
         os.makedirs(log_dir)
         print(f"Created log directory: {log_dir}")
@@ -73,10 +58,10 @@ def retrieve(
         ),
     )
     logger.add(log_file, level=config.DEFAULT.log_level)
-    logger.info("\nretrieve name : {}".format(retriever_name))
+    logger.info("=== Retriever name : {} ===".format(retriever_name))
     logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
     api_helper = APIHelper(config)
-    paper_client = PaperClient(config)
+    paper_client = PaperClient()
     precision = 0
     filtered_precision = 0
     recall = 0
@@ -90,9 +75,7 @@ def retrieve(
     # Init Retriever
     rt = RetrieverFactory.get_retriever_factory().create_retriever(
         retriever_name,
-        config,
-        use_cocite=co_cite,
-        use_cluster_to_filter=cluster_to_filter,
+        config
     )
     for line in ids_path:
         paper = json.loads(line)
@@ -108,7 +91,7 @@ def retrieve(
             entities = paper["entities"]
         else:
             entities = api_helper.generate_entity_list(bg)
-        logger.info("origin entities from background: {}".format(entities))
+        logger.info("\norigin entities from background: {}".format(entities))
         cite_type = config.RETRIEVE.cite_type
         if cite_type in paper and len(paper[cite_type]) >= 5:
             target_paper_id_list = paper[cite_type]
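
The removed --retriever-name, --co-cite, and --cluster-to-filter options have no CLI replacement; the equivalent knobs are now the RETRIEVE keys in the YAML file. A hedged sketch of overriding them programmatically with OmegaConf (OmegaConf.load and OmegaConf.update are standard OmegaConf calls; whether ConfigReader.load forwards such overrides through **kwargs is not shown in this diff):

from omegaconf import OmegaConf

# Illustrative: the settings that used to be CLI flags are now plain config keys.
config = OmegaConf.load("./configs/datasets.yaml")
OmegaConf.update(config, "RETRIEVE.retriever_name", "SN")  # "SN", "KG", or "SNKG"
OmegaConf.update(config, "RETRIEVE.use_cocite", False)
OmegaConf.update(config, "RETRIEVE.use_cluster_to_filter", False)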
src/utils/hash.py CHANGED
@@ -62,10 +62,12 @@ class EmbeddingModel:
     def __new__(cls, config):
         if cls._instance is None:
             cls._instance = super(EmbeddingModel, cls).__new__(cls)
+            device = "cuda" if torch.cuda.is_available() else "cpu"
             cls._instance.embedding_model = SentenceTransformer(
                 model_name_or_path=get_dir(config.DEFAULT.embedding),
-                device="cuda" if torch.cuda.is_available() else "cpu",
+                device=device,
             )
+            print(f"==== using device {device} ====")
         return cls._instance
 
     def get_embedding_model(config):
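
The device string is now computed once, passed to SentenceTransformer, and echoed to stdout. A self-contained sketch of the same pattern (the model name below is a placeholder; the project resolves its model path from config.DEFAULT.embedding):

import torch
from sentence_transformers import SentenceTransformer

# Illustrative: choose the device once and reuse it, as in the change above.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer("all-MiniLM-L6-v2", device=device)  # placeholder model name
print(f"==== using device {device} ====")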
src/utils/paper_retriever.py CHANGED
@@ -84,10 +84,10 @@ class Retriever(object):
     __metaclass__ = ABCMeta
     retriever_name = "BASE"
 
-    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
+    def __init__(self, config):
         self.config = config
-        self.use_cocite = use_cocite
-        self.use_cluster_to_filter = use_cluster_to_filter
+        self.use_cocite = config.RETRIEVE.use_cocite
+        self.use_cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
         self.paper_client = PaperClient()
         self.cocite = CoCite()
         self.api_helper = APIHelper(config=config)
@@ -389,7 +389,9 @@ class Retriever(object):
         logger.debug(f"target label counts : {target_label_counts}")
         target_label_list = list(target_label_counts.keys())
         max_k = max(self.config.RETRIEVE.top_k_list)
+        logger.info("=== Begin filter related paper ===")
         max_k_paper_id_list = self.filter_related_paper(score_all_dict, top_k=max_k)
+        logger.info("=== End filter related paper ===")
         for k in self.config.RETRIEVE.top_k_list:
             # top-k papers
             top_k = min(k, len(max_k_paper_id_list))
@@ -507,8 +509,8 @@ class autoregister:
 
 @autoregister("SN")
 class SNRetriever(Retriever):
-    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
-        super().__init__(config, use_cocite, use_cluster_to_filter)
+    def __init__(self, config):
+        super().__init__(config)
 
     def retrieve_paper(self, bg):
         entities = []
@@ -590,8 +592,8 @@ class SNRetriever(Retriever):
 
 @autoregister("KG")
 class KGRetriever(Retriever):
-    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
-        super().__init__(config, use_cocite, use_cluster_to_filter)
+    def __init__(self, config):
+        super().__init__(config)
 
     def retrieve_paper(self, entities):
         new_entities = self.retrieve_entities_by_enties(entities)
@@ -669,8 +671,8 @@ class KGRetriever(Retriever):
 
 @autoregister("SNKG")
 class SNKGRetriever(Retriever):
-    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
-        super().__init__(config, use_cocite, use_cluster_to_filter)
+    def __init__(self, config):
+        super().__init__(config)
 
     def retrieve_paper(self, bg, entities):
         sn_entities = []
@@ -721,9 +723,11 @@ class SNKGRetriever(Retriever):
         retrieve_result = self.retrieve_paper(bg, entities)
         related_paper_id_list = retrieve_result["paper"]
         retrieve_paper_num = len(related_paper_id_list)
+        logger.info("=== Begin cal related paper score ===")
         _, _, score_all_dict = self.cal_related_score(
             bg, related_paper_id_list=related_paper_id_list, entities=entities
         )
+        logger.info("=== End cal related paper score ===")
         top_k_matrix = {}
         recall = 0
         precision = 0
@@ -738,7 +742,9 @@ class SNKGRetriever(Retriever):
         logger.debug("before filter:")
         logger.debug(f"Recall: {recall:.3f}")
         logger.debug(f"Precision: {precision:.3f}")
+        logger.info("=== Begin filter related paper score ===")
         related_paper = self.filter_related_paper(score_all_dict, top_k)
+        logger.info("=== End filter related paper score ===")
         related_paper = self.update_related_paper(related_paper)
         result = {
             "recall": recall,
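
For context, retriever_name values such as "SN", "KG", and "SNKG" map to these classes through the @autoregister decorator. A generic sketch of that decorator-registry pattern, assuming RetrieverFactory resolves names in roughly this way (this is not the project's actual factory code):

# Illustrative registry pattern; not the project's RetrieverFactory implementation.
_REGISTRY = {}

def autoregister(name):
    def decorator(cls):
        _REGISTRY[name] = cls
        return cls
    return decorator

def create_retriever(name, config):
    # name would come from config.RETRIEVE.retriever_name after this commit.
    return _REGISTRY[name](config)

@autoregister("SNKG")
class SNKGRetriever:
    def __init__(self, config):
        self.use_cocite = config.RETRIEVE.use_cocite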