rshakked committed on
Commit
cf43376
·
1 Parent(s): 87fc4e6

fix: include missing predict_pipeline.py in repo with run_prediction_pipeline()

Browse files
Files changed (1) hide show
  1. predict_pipeline.py +55 -0
predict_pipeline.py CHANGED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import zipfile
2
+ import tempfile
3
+ from pathlib import Path
4
+ import torch
5
+ from transformers import DebertaV2Tokenizer, AutoModelForSequenceClassification
6
+
7
+ from train_abuse_model import (
8
+ MODEL_DIR,
9
+ device,
10
+ load_saved_model_and_tokenizer,
11
+ map_to_3_classes,
12
+ convert_to_label_strings
13
+ )
14
+
15
def _read_chat_texts(zip_path):
    """Return the text of every ``.txt`` member of a zip archive, newline-joined.

    Members are read straight from the archive instead of being extracted to
    disk, so nothing from the upload is written to the filesystem. Files in
    sub-folders are included, and members are visited in name order so the
    result is deterministic.

    Args:
        zip_path: Path (or file-like object) accepted by ``zipfile.ZipFile``.

    Returns:
        The decoded contents of all ``.txt`` members joined with ``"\\n"``;
        an empty string when the archive holds no ``.txt`` files.

    Raises:
        OSError, zipfile.BadZipFile: If the archive cannot be opened or read.
    """
    with zipfile.ZipFile(zip_path, "r") as archive:
        text_members = sorted(
            (
                info
                for info in archive.infolist()
                if not info.is_dir() and info.filename.lower().endswith(".txt")
            ),
            key=lambda info: info.filename,
        )
        # Undecodable bytes are dropped rather than failing the whole upload.
        return "\n".join(
            archive.read(info).decode("utf-8", errors="ignore")
            for info in text_members
        )


def _summarize_chat(chat_text):
    """Return an English summary of ``chat_text`` (placeholder implementation).

    TODO: replace with real summarization + translation. For now the input is
    ignored and a fixed mock string is returned.
    """
    # MOCK summarization would produce: "[Mock summary of Hebrew WhatsApp chat...]"
    # MOCK translation of that summary:
    return "[Translated summary in English]"


def run_prediction_pipeline(desc_input, chat_zip):
    """Classify a free-text description, optionally enriched with a chat export.

    Args:
        desc_input: Description text. ``None`` is treated as an empty string.
        chat_zip: Optional uploaded zip of ``.txt`` chat files. May be either
            an object exposing the file path as ``.name`` or a plain path
            string; falsy values mean "no upload".

    Returns:
        A 2-tuple. On success: ``(merged_input, labels)`` where
        ``merged_input`` is the text fed to the classifier and ``labels`` is
        the comma-separated output of ``convert_to_label_strings``. On any
        failure: ``("❌ Prediction failed: <error>", "")``. The function never
        raises ``Exception`` subclasses to its caller.
    """
    try:
        # Start with the base input.
        description = (desc_input or "").strip()
        merged_input = description

        # If a chat zip was uploaded, append its (currently mocked) summary.
        if chat_zip:
            # Upload widgets may hand over a temp-file wrapper or a bare path.
            zip_path = getattr(chat_zip, "name", chat_zip)
            full_chat = _read_chat_texts(zip_path)
            translated_summary = _summarize_chat(full_chat)
            merged_input = f"{description}\n\n[Summary]: {translated_summary}"

        # Load classifier.
        tokenizer, model = load_saved_model_and_tokenizer()
        inputs = tokenizer(
            merged_input,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs).logits
            probs = torch.sigmoid(outputs).cpu().numpy()

        # Static threshold values (or load from config later).
        best_low, best_high = 0.35, 0.65
        pred_soft = map_to_3_classes(probs, best_low, best_high)
        pred_str = convert_to_label_strings(pred_soft)

        return merged_input, ", ".join(pred_str)

    except Exception as e:
        # UI boundary: surface the error as text instead of crashing the app.
        return f"❌ Prediction failed: {e}", ""