|
from components.vector_db_operations import get_collection_from_vector_db |
|
from components.vector_db_operations import retrieval |
|
from components.english_information_extraction import english_information_extraction |
|
from components.multi_lingual_model import MDFEND , loading_model_and_tokenizer |
|
from components.data_loading import preparing_data , loading_data |
|
from components.language_identification import language_identification |
|
|
|
|
|
|
|
# Defaults for the non-English (multilingual) branch; override them per call
# through the keyword-only parameters of ``run_pipeline``.
_DEFAULT_VECTOR_DB_PATH = "/content/drive/MyDrive/general_domains/vector_database"
_DEFAULT_COLLECTION_NAME = "general_domains"
# Retrieval distances strictly above this value mark the domain "undetermined".
_DEFAULT_DISTANCE_THRESHOLD = 1.45
# Model scores at or above this value are labelled "discriminative".
_DEFAULT_DECISION_THRESHOLD = 0.5


def run_pipeline(
    input_text: str,
    *,
    vector_db_path: str = _DEFAULT_VECTOR_DB_PATH,
    collection_name: str = _DEFAULT_COLLECTION_NAME,
    num_results: int = 1,
    distance_threshold: float = _DEFAULT_DISTANCE_THRESHOLD,
    decision_threshold: float = _DEFAULT_DECISION_THRESHOLD,
):
    """Route *input_text* to the English or the multilingual pipeline.

    Text whose identified language code is ``"en"`` is passed to
    ``english_information_extraction`` and that result is returned as-is.
    Any other text -- including text for which ``language_identification``
    returns no language at all -- goes through the multilingual path:

    1. retrieve the nearest general domain from the vector database;
    2. run the multilingual model on the text and the retrieved domain label;
    3. threshold the model's last output into a discrimination label.

    Args:
        input_text: Text to analyse.
        vector_db_path: Location passed to ``get_collection_from_vector_db``.
        collection_name: Collection to open in that vector database.
        num_results: Number of neighbours requested from ``retrieval``.
        distance_threshold: Retrieval distance above which the reported
            domain becomes ``"undetermined"``.
        decision_threshold: Model score at or above which the text is
            labelled ``"discriminative"``.

    Returns:
        For English input, whatever ``english_information_extraction``
        returns. Otherwise a dict with keys ``"domain_label"``,
        ``"domain_score"`` (the retrieval distance),
        ``"discrimination_label"`` and ``"discrimination_score"`` (the model
        output converted with ``Tensor.item()``).
    """
    # torch is needed for ``no_grad`` below but is not among the module-level
    # imports, so import it here to keep this function self-contained.
    import torch

    language_dict = language_identification(input_text)
    # Use the first key; presumably language_identification orders languages
    # by confidence -- TODO confirm. An empty result yields None, which routes
    # the text to the multilingual path instead of raising StopIteration.
    language_code = next(iter(language_dict), None)

    if language_code == "en":
        return english_information_extraction(input_text)

    # NOTE(review): the collection, model and tokenizer are reloaded on every
    # call; consider caching them if this runs repeatedly in one process.
    collection = get_collection_from_vector_db(vector_db_path, collection_name)
    domain, label_domain, distance = retrieval(input_text, num_results, collection)
    if distance > distance_threshold:
        # Only the reported label is downgraded; the model below still
        # receives the nearest retrieved domain label.
        domain = "undetermined"

    tokenizer, model = loading_model_and_tokenizer()
    df = preparing_data(input_text, label_domain)
    input_ids, input_masks, input_domains = loading_data(tokenizer, df)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model.forward(input_ids, input_masks, input_domains)

    # Label and score both come from the LAST prediction. Indexing with [-1]
    # works for any non-empty prediction tensor, whereas slicing with [1:]
    # and calling .item() only works when there are exactly two entries.
    score = pred[-1].item()
    label = "discriminative" if score >= decision_threshold else "not discriminative"

    return {
        "domain_label": domain,
        "domain_score": distance,
        "discrimination_label": label,
        "discrimination_score": score,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|