# TestApp / run.py
# menikev's picture
# Upload run.py
# 57a2c61 verified
# raw
# history blame
# 2.02 kB
import torch

from components.vector_db_operations import get_collection_from_vector_db
from components.vector_db_operations import retrieval
from components.english_information_extraction import english_information_extraction
from components.multi_lingual_model import MDFEND , loading_model_and_tokenizer
from components.data_loading import preparing_data , loading_data
from components.language_identification import language_identification
def run_pipeline(
    input_text: str,
    *,
    vector_db_path: str = "/content/drive/MyDrive/general_domains/vector_database",
    collection_name: str = "general_domains",
    max_domain_distance: float = 1.45,
):
    """Route ``input_text`` to the English or the multilingual analysis path.

    The language of the text is identified first. English text is handed to
    ``english_information_extraction`` and its result is returned unchanged.
    Any other language goes through domain retrieval from the vector
    database followed by a forward pass of the multilingual model.

    Args:
        input_text: The text to analyse.
        vector_db_path: Location of the vector database passed to
            ``get_collection_from_vector_db``.
        collection_name: Name of the collection to open in that database.
        max_domain_distance: Retrieval distances strictly greater than this
            value cause the reported domain to be ``"undetermined"``.

    Returns:
        For English input, whatever ``english_information_extraction``
        returns. Otherwise a dict with the keys ``"domain_label"``,
        ``"domain_score"`` (the retrieval distance),
        ``"discrimination_label"`` (``"discriminative"`` or
        ``"not discriminative"``) and ``"discrimination_score"`` (the score
        of the last model prediction).

    Raises:
        IndexError: If the model returns no predictions.
    """
    language_dict = language_identification(input_text)
    # Only the first key is inspected.
    # NOTE(review): presumably the first key is the top-ranked language code;
    # confirm against language_identification.
    language_code = next(iter(language_dict))
    if language_code == "en":
        return english_information_extraction(input_text)

    num_results = 1
    collection = get_collection_from_vector_db(vector_db_path, collection_name)
    domain, label_domain, distance = retrieval(input_text, num_results, collection)
    if distance > max_domain_distance:
        # NOTE(review): only the reported name is replaced; label_domain is
        # still fed to the model below. Confirm this is intended.
        domain = "undetermined"

    tokenizer, model = loading_model_and_tokenizer()
    df = preparing_data(input_text, label_domain)
    input_ids, input_masks, input_domains = loading_data(tokenizer, df)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model.forward(input_ids, input_masks, input_domains)

    # Each element of pred is converted to a Python number.
    scores = [output.item() for output in pred]

    # Label and score are both taken from the last prediction so they always
    # describe the same element, whatever the number of predictions.
    final_score = scores[-1]
    discrimination_label = (
        "discriminative" if final_score >= 0.5 else "not discriminative"
    )

    return {
        "domain_label": domain,
        "domain_score": distance,
        "discrimination_label": discrimination_label,
        "discrimination_score": final_score,
    }