menikev commited on
Commit
57a2c61
1 Parent(s): 740e228

Upload run.py

Browse files
Files changed (1) hide show
  1. run.py +76 -0
run.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from components.vector_db_operations import get_collection_from_vector_db
2
+ from components.vector_db_operations import retrieval
3
+ from components.english_information_extraction import english_information_extraction
4
+ from components.multi_lingual_model import MDFEND , loading_model_and_tokenizer
5
+ from components.data_loading import preparing_data , loading_data
6
+ from components.language_identification import language_identification
7
+
8
+
9
+
10
+ def run_pipeline(input_text:str):
11
+
12
+ language_dict = language_identification(input_text)
13
+ language_code = next(iter(language_dict))
14
+
15
+ if language_code == "en":
16
+
17
+ output_english = english_information_extraction(input_text)
18
+
19
+ return output_english
20
+
21
+ else:
22
+
23
+
24
+ num_results = 1
25
+ path = "/content/drive/MyDrive/general_domains/vector_database"
26
+ collection_name = "general_domains"
27
+
28
+
29
+ collection = get_collection_from_vector_db(path , collection_name)
30
+
31
+ domain , label_domain , distance = retrieval(input_text , num_results , collection )
32
+
33
+ if distance >1.45:
34
+ domain = "undetermined"
35
+
36
+ tokenizer , model = loading_model_and_tokenizer()
37
+
38
+ df = preparing_data(input_text , label_domain)
39
+
40
+ input_ids , input_masks , input_domains = loading_data(tokenizer , df )
41
+
42
+ labels = []
43
+ outputs = []
44
+ with torch.no_grad():
45
+
46
+ pred = model.forward(input_ids, input_masks , input_domains)
47
+ labels.append([])
48
+
49
+ for output in pred:
50
+ number = output.item()
51
+ label = int(1) if number >= 0.5 else int(0)
52
+ labels[-1].append(label)
53
+ outputs.append(pred)
54
+
55
+ discrimination_class = ["discriminative" if i == int(1) else "not discriminative" for i in labels[0]]
56
+
57
+
58
+ return { "domain_label" :domain ,
59
+ "domain_score":distance ,
60
+ "discrimination_label" : discrimination_class[-1],
61
+ "discrimination_score" : outputs[0][1:].item(),
62
+ }
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+