Spaces:
Build error
Build error
Update backend_utils.py
Browse files- backend_utils.py +85 -9
backend_utils.py
CHANGED
@@ -8,6 +8,12 @@ import pickle
|
|
8 |
import torch
|
9 |
from sklearn.multiclass import OneVsRestClassifier
|
10 |
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class wrappedTokenizer(RobertaTokenizer):
|
13 |
def __call__(self, text_input):
|
@@ -24,7 +30,7 @@ def generate_index(db):
|
|
24 |
})
|
25 |
return index_list
|
26 |
|
27 |
-
def load_db(db_metadata_path, db_constructor_path, db_params_path):
|
28 |
'''
|
29 |
Function to load dataframe
|
30 |
|
@@ -42,7 +48,11 @@ def load_db(db_metadata_path, db_constructor_path, db_params_path):
|
|
42 |
db_constructor.dropna(inplace=True)
|
43 |
db_params = pd.read_csv(db_params_path)
|
44 |
db_params.dropna(inplace=True)
|
45 |
-
|
|
|
|
|
|
|
|
|
46 |
|
47 |
|
48 |
|
@@ -177,8 +187,9 @@ def retrieve_libraries(retrieval_model, model_input, db_metadata):
|
|
177 |
'''
|
178 |
results = retrieval_model(model_input)
|
179 |
library_ids = [item.get('id') for item in results]
|
|
|
180 |
library_names = [id_to_libname(item, db_metadata) for item in library_ids]
|
181 |
-
return library_ids, library_names
|
182 |
|
183 |
def prepare_input_generative_model(library_ids, db_constructor):
|
184 |
'''
|
@@ -423,10 +434,11 @@ def initialize_all_components(config):
|
|
423 |
classifier_head: a random forest model
|
424 |
'''
|
425 |
# load db
|
426 |
-
db_metadata, db_constructor, db_params = load_db(
|
427 |
config.get('db_metadata_path'),
|
428 |
config.get('db_constructor_path'),
|
429 |
-
config.get('db_params_path')
|
|
|
430 |
)
|
431 |
|
432 |
# load model
|
@@ -443,14 +455,14 @@ def initialize_all_components(config):
|
|
443 |
config.get('classifier_head_path')
|
444 |
)
|
445 |
|
446 |
-
return db_metadata, db_constructor, db_params, model_retrieval, model_generative, tokenizer_generative, model_classifier, classifier_head, tokenizer_classifier
|
447 |
|
448 |
def make_predictions(input_query,
|
449 |
model_retrieval,
|
450 |
model_generative,
|
451 |
model_classifier, classifier_head,
|
452 |
tokenizer_generative, tokenizer_classifier,
|
453 |
-
db_metadata, db_constructor, db_params,
|
454 |
config):
|
455 |
'''
|
456 |
Function to retrieve relevant libraries, generate API usage patterns, and predict the hw configs
|
@@ -467,9 +479,28 @@ def make_predictions(input_query,
|
|
467 |
Returns:
|
468 |
predictions (list): a list of dictionary containing the prediction details
|
469 |
'''
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
|
472 |
if len(library_ids) == 0:
|
|
|
473 |
return "null"
|
474 |
|
475 |
print("generate usage patterns")
|
@@ -500,4 +531,49 @@ def make_predictions(input_query,
|
|
500 |
print("finished the predictions")
|
501 |
predictions = get_metadata_library(predictions, db_metadata)
|
502 |
|
503 |
-
return predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
import torch
|
9 |
from sklearn.multiclass import OneVsRestClassifier
|
10 |
from sklearn.ensemble import RandomForestClassifier
|
11 |
+
import spacy
|
12 |
+
|
13 |
+
# nlp = spacy.load("en_core_web_trf")
|
14 |
+
nlp = spacy.load("en_core_web_sm")
|
15 |
+
|
16 |
+
|
17 |
|
18 |
class wrappedTokenizer(RobertaTokenizer):
|
19 |
def __call__(self, text_input):
|
|
|
30 |
})
|
31 |
return index_list
|
32 |
|
33 |
+
def load_db(db_metadata_path, db_constructor_path, db_params_path, exclusion_list_path):
|
34 |
'''
|
35 |
Function to load dataframe
|
36 |
|
|
|
48 |
db_constructor.dropna(inplace=True)
|
49 |
db_params = pd.read_csv(db_params_path)
|
50 |
db_params.dropna(inplace=True)
|
51 |
+
with open(exclusion_list_path, 'r') as f:
|
52 |
+
ex_list = f.read()
|
53 |
+
ex_list = ex_list.split("\n")
|
54 |
+
|
55 |
+
return db_metadata, db_constructor, db_params, ex_list
|
56 |
|
57 |
|
58 |
|
|
|
187 |
'''
|
188 |
results = retrieval_model(model_input)
|
189 |
library_ids = [item.get('id') for item in results]
|
190 |
+
scores = [item.get('similarity') for item in results]
|
191 |
library_names = [id_to_libname(item, db_metadata) for item in library_ids]
|
192 |
+
return library_ids, library_names, scores
|
193 |
|
194 |
def prepare_input_generative_model(library_ids, db_constructor):
|
195 |
'''
|
|
|
434 |
classifier_head: a random forest model
|
435 |
'''
|
436 |
# load db
|
437 |
+
db_metadata, db_constructor, db_params, ex_list = load_db(
|
438 |
config.get('db_metadata_path'),
|
439 |
config.get('db_constructor_path'),
|
440 |
+
config.get('db_params_path'),
|
441 |
+
config.get('exclusion_list_path')
|
442 |
)
|
443 |
|
444 |
# load model
|
|
|
455 |
config.get('classifier_head_path')
|
456 |
)
|
457 |
|
458 |
+
return db_metadata, db_constructor, db_params, ex_list, model_retrieval, model_generative, tokenizer_generative, model_classifier, classifier_head, tokenizer_classifier
|
459 |
|
460 |
def make_predictions(input_query,
|
461 |
model_retrieval,
|
462 |
model_generative,
|
463 |
model_classifier, classifier_head,
|
464 |
tokenizer_generative, tokenizer_classifier,
|
465 |
+
db_metadata, db_constructor, db_params, ex_list,
|
466 |
config):
|
467 |
'''
|
468 |
Function to retrieve relevant libraries, generate API usage patterns, and predict the hw configs
|
|
|
479 |
Returns:
|
480 |
predictions (list): a list of dictionary containing the prediction details
|
481 |
'''
|
482 |
+
print("retrieve libraries")
|
483 |
+
queries = extract_keywords(input_query.lower(), ex_list)
|
484 |
+
|
485 |
+
temp_list = []
|
486 |
+
for query in queries:
|
487 |
+
temp_library_ids, temp_library_names, temp_scores = retrieve_libraries(model_retrieval, query, db_metadata)
|
488 |
+
|
489 |
+
if len(temp_library_ids) > 0:
|
490 |
+
for id_, name, score in zip(temp_library_ids, temp_library_names, temp_scores):
|
491 |
+
temp_list.append((id_, name, score))
|
492 |
+
|
493 |
+
library_ids = []
|
494 |
+
library_names = []
|
495 |
+
if len(temp_list) > 0:
|
496 |
+
sorted_list = sorted(temp_list, key=lambda tup: tup[2], reverse=True)
|
497 |
+
sorted_list = sorted_list[:config.get('max_k')]
|
498 |
+
for item in sorted_list:
|
499 |
+
library_ids.append(item[0])
|
500 |
+
library_names.append(item[1])
|
501 |
|
502 |
if len(library_ids) == 0:
|
503 |
+
print("null libraries")
|
504 |
return "null"
|
505 |
|
506 |
print("generate usage patterns")
|
|
|
531 |
print("finished the predictions")
|
532 |
predictions = get_metadata_library(predictions, db_metadata)
|
533 |
|
534 |
+
return predictions
|
535 |
+
|
536 |
+
def extract_series(x):
    '''
    Function to pull series-like tokens (e.g. "stm32") out of a keyword

    Hyphens and underscores are treated as word separators. A token is kept
    when it is alphanumeric but neither purely alphabetic nor purely numeric.

    Params:
        x (string): the keyword to inspect

    Returns:
        series (list): the matching tokens in order of appearance, or [x]
                       (the keyword exactly as passed in) when nothing matches
    '''
    tokens = x.replace("-", " ").replace("_", " ").split()
    series = [
        token for token in tokens
        # alphanumeric, but not letters-only and not digits-only
        if token.isalnum() and not token.isalpha() and not token.isdigit()
    ]
    # fall back to the untouched input so the caller always gets a keyword
    return series if series else [x]
|
547 |
+
|
548 |
+
def extract_keywords(query, ex_list):
    '''
    Function to extract search keywords from a query

    The query is parsed with the module-level spaCy pipeline `nlp`. Each noun
    chunk becomes one candidate after dropping excluded words and tokens
    tagged DET, PRON, CCONJ or NUM. Every candidate is then passed through
    extract_series; a resulting keyword that still has more than one word is
    re-parsed and reduced to the root word of each of its noun chunks.

    Params:
        query (string): the text to extract keywords from
        ex_list (list): words that must never be part of a keyword

    Returns:
        filtered_keyword_candidates (list): the extracted keywords, or
                                            [query] when none were found
    '''
    # build the lookup once; membership is tested for every token in the query
    excluded = set(ex_list)
    skipped_pos = ("DET", "PRON", "CCONJ", "NUM")

    # extract keywords
    keyword_candidates = []
    for chunk in nlp(query).noun_chunks:
        kept_tokens = [
            token.text for token in chunk
            if token.text not in excluded and token.pos_ not in skipped_pos
        ]
        if kept_tokens:
            keyword_candidates.append(" ".join(kept_tokens))

    filtered_keyword_candidates = []
    for candidate in keyword_candidates:
        for keyword in extract_series(candidate):
            if len(keyword.split()) > 1:
                # multi-word keyword: keep only the root of each noun chunk
                # NOTE(review): a multi-word keyword with no noun chunk on
                # re-parse contributes nothing here - confirm this is intended
                filtered_keyword_candidates.extend(
                    chunk.root.text for chunk in nlp(keyword).noun_chunks
                )
            else:
                filtered_keyword_candidates.append(keyword)

    # nothing usable was extracted: fall back to the query itself
    if not filtered_keyword_candidates:
        filtered_keyword_candidates.append(query)

    return filtered_keyword_candidates
|