import torch

from components.vector_db_operations import get_collection_from_vector_db
from components.vector_db_operations import retrieval
from components.english_information_extraction import english_information_extraction
from components.multi_lingual_model import MDFEND, loading_model_and_tokenizer
from components.data_loading import preparing_data, loading_data
from components.language_identification import language_identification

# Location and name of the vector-database collection queried for the domain.
_VECTOR_DB_PATH = "/content/drive/MyDrive/general_domains/vector_database"
_COLLECTION_NAME = "general_domains"
# Number of neighbours requested from the vector database.
_NUM_RESULTS = 1
# Retrieval distances strictly above this value are reported as "undetermined".
_MAX_DOMAIN_DISTANCE = 1.45
# Model outputs at or above this value are labelled "discriminative".
_DISCRIMINATION_THRESHOLD = 0.5


def run_pipeline(input_text: str):
    """Run the analysis pipeline on a single piece of text.

    The first key of the mapping returned by ``language_identification`` is
    used as the language code. When it is ``"en"`` the text is handed to
    ``english_information_extraction`` and that result is returned unchanged.

    For any other language code the text is looked up in the vector database
    to obtain a domain, then scored by the model returned from
    ``loading_model_and_tokenizer``.

    Args:
        input_text: The text to analyse.

    Returns:
        For English input, whatever ``english_information_extraction``
        returns. Otherwise a dict with the keys ``"domain_label"``,
        ``"domain_score"``, ``"discrimination_label"`` and
        ``"discrimination_score"``.
    """
    language_dict = language_identification(input_text)
    language_code = next(iter(language_dict))

    if language_code == "en":
        return english_information_extraction(input_text)

    collection = get_collection_from_vector_db(_VECTOR_DB_PATH, _COLLECTION_NAME)
    domain, label_domain, distance = retrieval(input_text, _NUM_RESULTS, collection)
    if distance > _MAX_DOMAIN_DISTANCE:
        # Only the reported domain name is overridden; label_domain is still
        # passed to preparing_data below, exactly as before.
        # NOTE(review): confirm that feeding the retrieved label to the model
        # is intended when the match is this distant.
        domain = "undetermined"

    # NOTE(review): the model and tokenizer are loaded on every call; consider
    # caching them if loading_model_and_tokenizer is expensive.
    tokenizer, model = loading_model_and_tokenizer()
    df = preparing_data(input_text, label_domain)
    input_ids, input_masks, input_domains = loading_data(tokenizer, df)

    with torch.no_grad():
        pred = model.forward(input_ids, input_masks, input_domains)

    # The label and the score are both taken from the LAST prediction so they
    # always describe the same element. The previous code read the score via
    # pred[1:].item(), which is the same value when pred has exactly two
    # elements and raises otherwise.
    # NOTE(review): presumably preparing_data yields two rows — confirm.
    score = pred[-1].item()
    if score >= _DISCRIMINATION_THRESHOLD:
        discrimination_label = "discriminative"
    else:
        discrimination_label = "not discriminative"

    return {
        "domain_label": domain,
        "domain_score": distance,
        "discrimination_label": discrimination_label,
        "discrimination_score": score,
    }