File size: 5,570 Bytes
9844436
 
 
9ae222c
 
9844436
 
 
 
 
 
 
 
 
 
9ae222c
 
 
 
 
 
 
 
 
 
 
2ac5e53
9ae222c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9844436
 
2ac5e53
9ae222c
 
 
 
 
 
2ac5e53
9ae222c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9844436
9ae222c
 
 
9844436
9ae222c
9844436
9ae222c
9844436
 
9ae222c
 
9844436
9ae222c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9844436
 
 
 
9ae222c
 
9844436
9ae222c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9844436
 
9ae222c
9844436
 
 
 
 
9ae222c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from app.core.config import settings
import logging
import os
import zipfile
import hashlib

logger = logging.getLogger(__name__)

class LegalBertService:
    """Verdict classifier backed by a fine-tuned LegalBERT model.

    Loads the model lazily at construction time from a local zip/directory.
    When the model files or torch/transformers are unavailable, the service
    degrades to deterministic MD5-hash placeholder outputs so callers always
    receive an answer instead of an exception.
    """

    def __init__(self):
        # Starts as the string "cpu"; replaced by a torch.device once the
        # real model loads (get_device() stringifies either form).
        self.device = "cpu"
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _extract_model_from_zip(self, zipPath: str, extractPath: str) -> bool:
        """Extract the LegalBERT model archive into ``extractPath``.

        Returns True when model files are available afterwards (an earlier
        extraction counts as success), False on any failure. Never raises.
        """
        try:
            if not os.path.exists(zipPath):
                logger.warning(f"Model zip file not found: {zipPath}")
                return False

            if not os.path.exists(extractPath):
                # exist_ok guards against a race between the check and mkdir
                os.makedirs(extractPath, exist_ok=True)
                logger.info(f"Created model directory: {extractPath}")

            # config.json is the marker HuggingFace writes for a saved model;
            # its presence means a previous extraction already completed.
            if os.path.exists(os.path.join(extractPath, "config.json")):
                logger.info("Model already extracted")
                return True

            logger.info(f"Extracting model from {zipPath} to {extractPath}")
            # NOTE(review): extractall trusts member paths inside the archive.
            # The zip is a locally deployed artifact, so path traversal is
            # accepted here — do not point this at untrusted archives.
            with zipfile.ZipFile(zipPath, 'r') as zip_ref:
                zip_ref.extractall(extractPath)

            logger.info("Model extraction completed")
            return True

        except Exception as e:
            logger.error(f"Failed to extract model: {str(e)}")
            return False

    def _load_model(self):
        """Locate, extract, and load the model; best-effort, never raises.

        On any failure the service stays in placeholder mode
        (``self.model``/``self.tokenizer`` remain None).
        """
        try:
            zip_path = os.path.join("./models", "legalbert_epoch4.zip")

            if os.path.exists(zip_path):
                if self._extract_model_from_zip(zip_path, settings.legal_bert_model_path):
                    logger.info("Model zip file found and extracted")

            model_dir = settings.legal_bert_model_path
            if os.path.exists(model_dir) and os.path.exists(os.path.join(model_dir, "config.json")):
                try:
                    # Imported lazily so the service still works (in
                    # placeholder mode) when torch/transformers are absent.
                    import torch
                    from transformers import AutoTokenizer, AutoModelForSequenceClassification

                    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    logger.info(f"Loading LegalBERT model from {model_dir}")

                    self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
                    self.model = AutoModelForSequenceClassification.from_pretrained(
                        model_dir
                    ).to(self.device)

                    logger.info(f"LegalBERT model loaded successfully on {self.device}")

                except ImportError:
                    logger.warning("torch/transformers not installed - using placeholder mode")
                except Exception as e:
                    logger.error(f"Failed to load actual model: {str(e)}")
            else:
                logger.warning(f"LegalBERT model files not found in: {model_dir}")
                logger.info("Place your legalbert_epoch4.zip in ./models/ or model files directly in ./models/legalbert_model/")

        except Exception as e:
            logger.error(f"Failed to initialize LegalBERT service: {str(e)}")

    def _classify(self, inputText: str):
        """Tokenize ``inputText``, run the model, and return the softmax
        probability tensor (shape ``[1, num_labels]``).

        Shared by predictVerdict/getConfidence to keep the inference
        pipeline in one place. Assumes ``is_model_loaded()`` is True;
        may raise — callers handle exceptions.
        """
        import torch
        import torch.nn.functional as F

        inputs = self.tokenizer(
            inputText,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits
            return F.softmax(logits, dim=1)

    def predictVerdict(self, inputText: str) -> str:
        """Return ``"guilty"`` or ``"not guilty"`` for the given case text.

        Falls back to a deterministic MD5-parity placeholder when the model
        is not loaded, and to ``"not guilty"`` if inference raises.
        """
        if not self.is_model_loaded():
            logger.info("Using placeholder verdict prediction")
            text_hash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
            return "guilty" if text_hash % 2 == 1 else "not guilty"

        try:
            import torch

            probabilities = self._classify(inputText)
            predicted_label = torch.argmax(probabilities, dim=1).item()
            # Label convention: index 1 == "guilty" — TODO confirm against
            # the label mapping used when the model was fine-tuned.
            return "guilty" if predicted_label == 1 else "not guilty"

        except Exception as e:
            logger.error(f"Error predicting verdict: {str(e)}")
            return "not guilty"

    def getConfidence(self, inputText: str) -> float:
        """Return the model's top-class probability for ``inputText``.

        Placeholder mode returns a deterministic value in [0.5, 0.995);
        inference failures return 0.5.
        """
        if not self.is_model_loaded():
            logger.info("Using placeholder confidence score")
            text_hash = int(hashlib.md5(inputText.encode()).hexdigest(), 16)
            # hash % 100 in [0, 99] -> confidence in [0.5, 0.995)
            return 0.5 + (text_hash % 100) / 200.0

        try:
            import torch

            probabilities = self._classify(inputText)
            return float(torch.max(probabilities).item())

        except Exception as e:
            logger.error(f"Error getting confidence: {str(e)}")
            return 0.5

    def is_model_loaded(self) -> bool:
        """True when both tokenizer and model were loaded successfully."""
        return self.model is not None and self.tokenizer is not None

    def get_device(self) -> str:
        """Device the model runs on, as a string (e.g. "cpu", "cuda:0")."""
        return str(self.device)

    def is_healthy(self) -> bool:
        """Liveness check; always True since placeholder mode is valid."""
        return True