imamnurby committed
Commit d4ebedd · 1 Parent(s): 624af83

Update backend_utils.py

Files changed (1):
1. backend_utils.py (+85 -9)
backend_utils.py CHANGED
@@ -8,6 +8,12 @@ import pickle
 import torch
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.ensemble import RandomForestClassifier
+import spacy
+
+# nlp = spacy.load("en_core_web_trf")
+nlp = spacy.load("en_core_web_sm")
+
+
 
 class wrappedTokenizer(RobertaTokenizer):
     def __call__(self, text_input):
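Note on the new import block: the transformer pipeline (en_core_web_trf) is left commented out in favor of the lighter en_core_web_sm. A minimal sketch of a fallback-download guard, assuming the small model may not be pre-installed in the target environment (the guard itself is not part of this commit):

import spacy

try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # model not installed yet; fetch it once, then load (hypothetical hardening)
    from spacy.cli import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")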
 
@@ -24,7 +30,7 @@ def generate_index(db):
         })
     return index_list
 
-def load_db(db_metadata_path, db_constructor_path, db_params_path):
+def load_db(db_metadata_path, db_constructor_path, db_params_path, exclusion_list_path):
     '''
     Function to load dataframe
 
 
 
@@ -42,7 +48,11 @@ def load_db(db_metadata_path, db_constructor_path, db_params_path):
     db_constructor.dropna(inplace=True)
     db_params = pd.read_csv(db_params_path)
     db_params.dropna(inplace=True)
-    return db_metadata, db_constructor, db_params
+    with open(exclusion_list_path, 'r') as f:
+        ex_list = f.read()
+    ex_list = ex_list.split("\n")
+
+    return db_metadata, db_constructor, db_params, ex_list
 
 
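One caveat in the new load_db block: f.read().split("\n") keeps empty strings when the exclusion file has blank or trailing newlines, and those empty strings are later compared against every token in extract_keywords. A slightly more defensive parse would be (a sketch, not part of the commit):

with open(exclusion_list_path, 'r') as f:
    # strip whitespace and drop empty entries so a trailing newline
    # cannot put "" on the exclusion list
    ex_list = [line.strip() for line in f if line.strip()]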
 
 
@@ -177,8 +187,9 @@ def retrieve_libraries(retrieval_model, model_input, db_metadata):
     '''
     results = retrieval_model(model_input)
     library_ids = [item.get('id') for item in results]
+    scores = [item.get('similarity') for item in results]
     library_names = [id_to_libname(item, db_metadata) for item in library_ids]
-    return library_ids, library_names
+    return library_ids, library_names, scores
 
 def prepare_input_generative_model(library_ids, db_constructor):
     '''
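retrieve_libraries now returns a three-tuple, so the old two-value unpacking at any call site raises ValueError. A usage sketch, assuming the retrieval model yields dicts with 'id' and 'similarity' keys as the comprehensions above imply (the query string is illustrative):

library_ids, library_names, scores = retrieve_libraries(
    model_retrieval, "dht22 temperature sensor", db_metadata
)
for id_, name, score in zip(library_ids, library_names, scores):
    # ordering is whatever the retrieval model returns
    print(f"{id_}\t{name}\t{score:.3f}")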
 
@@ -423,10 +434,11 @@ def initialize_all_components(config):
         classifier_head: a random forest model
     '''
     # load db
-    db_metadata, db_constructor, db_params = load_db(
+    db_metadata, db_constructor, db_params, ex_list = load_db(
         config.get('db_metadata_path'),
         config.get('db_constructor_path'),
-        config.get('db_params_path')
+        config.get('db_params_path'),
+        config.get('exclusion_list_path')
     )
 
     # load model
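initialize_all_components now reads one extra key from config, and make_predictions reads max_k further down. A sketch of the entries this diff implies; the paths and the max_k value are placeholders, not taken from the repository:

config = {
    "db_metadata_path": "data/db_metadata.csv",
    "db_constructor_path": "data/db_constructor.csv",
    "db_params_path": "data/db_params.csv",
    "exclusion_list_path": "data/exclusion_list.txt",  # new in this commit
    "max_k": 5,  # cap on merged retrieval results in make_predictions
    # ... model, tokenizer, and classifier paths as before
}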
 
@@ -443,14 +455,14 @@ def initialize_all_components(config):
         config.get('classifier_head_path')
     )
 
-    return db_metadata, db_constructor, db_params, model_retrieval, model_generative, tokenizer_generative, model_classifier, classifier_head, tokenizer_classifier
+    return db_metadata, db_constructor, db_params, ex_list, model_retrieval, model_generative, tokenizer_generative, model_classifier, classifier_head, tokenizer_classifier
 
 def make_predictions(input_query,
                      model_retrieval,
                      model_generative,
                      model_classifier, classifier_head,
                      tokenizer_generative, tokenizer_classifier,
-                     db_metadata, db_constructor, db_params,
+                     db_metadata, db_constructor, db_params, ex_list,
                      config):
     '''
     Function to retrieve relevant libraries, generate API usage patterns, and predict the hw configs
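Both changed signatures grow by one element (ex_list), so callers have to be updated in lockstep; stale call sites fail with unpacking or arity errors. A sketch of the updated wiring, using only names visible in this diff:

(db_metadata, db_constructor, db_params, ex_list,
 model_retrieval, model_generative, tokenizer_generative,
 model_classifier, classifier_head, tokenizer_classifier) = initialize_all_components(config)

predictions = make_predictions(input_query,
                               model_retrieval,
                               model_generative,
                               model_classifier, classifier_head,
                               tokenizer_generative, tokenizer_classifier,
                               db_metadata, db_constructor, db_params, ex_list,
                               config)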
 
@@ -467,9 +479,28 @@ def make_predictions(input_query,
     Returns:
         predictions (list): a list of dictionary containing the prediction details
     '''
-    library_ids, library_names = retrieve_libraries(model_retrieval, input_query, db_metadata)
+    print("retrieve libraries")
+    queries = extract_keywords(input_query.lower(), ex_list)
+
+    temp_list = []
+    for query in queries:
+        temp_library_ids, temp_library_names, temp_scores = retrieve_libraries(model_retrieval, query, db_metadata)
+
+        if len(temp_library_ids) > 0:
+            for id_, name, score in zip(temp_library_ids, temp_library_names, temp_scores):
+                temp_list.append((id_, name, score))
+
+    library_ids = []
+    library_names = []
+    if len(temp_list) > 0:
+        sorted_list = sorted(temp_list, key=lambda tup: tup[2], reverse=True)
+        sorted_list = sorted_list[:config.get('max_k')]
+        for item in sorted_list:
+            library_ids.append(item[0])
+            library_names.append(item[1])
 
     if len(library_ids) == 0:
+        print("null libraries")
         return "null"
 
     print("generate usage patterns")
 
@@ -500,4 +531,49 @@ def make_predictions(input_query,
     print("finished the predictions")
     predictions = get_metadata_library(predictions, db_metadata)
 
-    return predictions
+    return predictions
+
+def extract_series(x):
+    name = x.replace("-", " ").replace("_", " ")
+    name = name.split()
+    series = []
+    for token in name:
+        if token.isalnum() and not(token.isalpha()) and not(token.isdigit()):
+            series.append(token)
+    if len(series) > 0:
+        return series
+    else:
+        return [x]
+
+def extract_keywords(query, ex_list):
+    doc = nlp(query)
+    keyword_candidates = []
+
+    # extract keywords
+    for chunk in doc.noun_chunks:
+        temp_list = []
+
+        for token in chunk:
+            if token.text not in ex_list and token.pos_ not in ("DET", "PRON", "CCONJ", "NUM"):
+                temp_list.append(token.text)
+
+        if len(temp_list) > 0:
+            keyword_candidates.append(" ".join(temp_list))
+
+    filtered_keyword_candidates = []
+    for keyword in keyword_candidates:
+        temp_candidates = extract_series(keyword)
+
+        for keyword in temp_candidates:
+
+            if len(keyword.split()) > 1:
+                doc = nlp(keyword)
+                for chunk in doc.noun_chunks:
+                    filtered_keyword_candidates.append(chunk.root.text)
+            else:
+                filtered_keyword_candidates.append(keyword)
+
+    if len(filtered_keyword_candidates) == 0:
+        filtered_keyword_candidates.append(query)
+
+    return filtered_keyword_candidates
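A worked example of the two new helpers. extract_series keeps only mixed alphanumeric tokens (part numbers such as "dht22"): when a keyword contains one, the surrounding plain words are dropped, otherwise the input comes back unchanged. extract_keywords then collapses any remaining multi-word keyword to its chunk root. Exact output depends on the spaCy model's noun chunking, so the values below are illustrative:

print(extract_series("adafruit-dht22-driver"))   # ['dht22']  (only the mixed token survives)
print(extract_series("temperature sensor"))      # ['temperature sensor']  (no mixed token)

# With en_core_web_sm, a query like this typically yields the noun chunk
# "a dht22 temperature sensor"; "a" is dropped as DET, and extract_series
# then reduces the candidate to its part number:
print(extract_keywords("read a dht22 temperature sensor", ex_list=[]))  # e.g. ['dht22']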