Prateek0515 commited on
Commit
7cc7f9f
·
verified ·
1 Parent(s): 2465222

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from transformers import AutoTokenizer, AutoModel
6
+ from torchcrf import CRF
7
+ from huggingface_hub import hf_hub_download
8
+ import PyPDF2
9
+ from docx import Document
10
+
11
+ class PositionalEncoding(nn.Module):
12
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
13
+ super().__init__()
14
+ self.dropout = nn.Dropout(p=dropout)
15
+ pe = torch.zeros(max_len, d_model)
16
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
17
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
18
+ pe[:, 0::2] = torch.sin(position * div_term)
19
+ pe[:, 1::2] = torch.cos(position * div_term)
20
+ self.register_buffer('pe', pe.unsqueeze(0))
21
+
22
+ def forward(self, x):
23
+ return x + self.pe[:, :x.size(1)]
24
+
25
+ class VanillaTransformer(nn.Module):
26
+ def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
27
+ super().__init__()
28
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
29
+ encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation='gelu', batch_first=True)
30
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
31
+
32
+ def forward(self, src, src_key_padding_mask=None):
33
+ src = self.pos_encoder(src)
34
+ return self.transformer(src, src_key_padding_mask=src_key_padding_mask)
35
+
36
+ class HierarchicalLegalSegModel(nn.Module):
37
+ def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
38
+ super().__init__()
39
+ self.longformer = longformer_model
40
+ self.hidden_dim = hidden_dim
41
+ self.vanilla_transformer = VanillaTransformer(d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers, dim_feedforward=hidden_dim*4, dropout=dropout)
42
+ self.classifier = nn.Linear(hidden_dim, num_labels)
43
+ self.crf = CRF(num_labels, batch_first=True)
44
+ self.dropout = nn.Dropout(dropout)
45
+
46
+ def encode_sentences(self, input_ids, attention_mask):
47
+ batch_size, num_sentences, max_seq_len = input_ids.shape
48
+ input_ids_flat = input_ids.view(-1, max_seq_len)
49
+ attention_mask_flat = attention_mask.view(-1, max_seq_len)
50
+ outputs = self.longformer(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
51
+ cls_embeddings = outputs.last_hidden_state[:, 0, :]
52
+ return cls_embeddings.view(batch_size, num_sentences, self.hidden_dim)
53
+
54
+ def forward(self, input_ids, attention_mask, sentence_mask=None):
55
+ embeddings = self.encode_sentences(input_ids, attention_mask)
56
+ embeddings = self.dropout(embeddings)
57
+ output = self.vanilla_transformer(embeddings, src_key_padding_mask=~sentence_mask if sentence_mask is not None else None)
58
+ emissions = self.classifier(output)
59
+ return self.crf.decode(emissions, mask=sentence_mask)
60
+
61
+ device = torch.device("cpu")
62
+ tokenizer = AutoTokenizer.from_pretrained("lexlms/legal-longformer-base")
63
+ longformer = AutoModel.from_pretrained("lexlms/legal-longformer-base").to(device)
64
+ for param in longformer.parameters():
65
+ param.requires_grad = False
66
+
67
+ model = HierarchicalLegalSegModel(longformer, 7).to(device)
68
+ model_path = hf_hub_download(repo_id="Prateek0515/legal-document-segmentation", filename="model.pth")
69
+ model.load_state_dict(torch.load(model_path, map_location=device))
70
+ model.eval()
71
+
72
+ id2label = {0: "Arguments of Petitioner", 1: "Arguments of Respondent", 2: "Decision", 3: "Facts", 4: "Issue", 5: "None", 6: "Reasoning"}
73
+
74
+ def extract_text_from_pdf(file):
75
+ reader = PyPDF2.PdfReader(file)
76
+ text = ""
77
+ for page in reader.pages:
78
+ text += page.extract_text()
79
+ return text.strip()
80
+
81
+ def extract_text_from_docx(file):
82
+ doc = Document(file)
83
+ return "\n".join([para.text for para in doc.paragraphs]).strip()
84
+
85
+ def predict(text_input, file_input):
86
+ try:
87
+ if file_input is not None:
88
+ if file_input.name.endswith('.pdf'):
89
+ text = extract_text_from_pdf(file_input.name)
90
+ elif file_input.name.endswith('.docx'):
91
+ text = extract_text_from_docx(file_input.name)
92
+ elif file_input.name.endswith('.txt'):
93
+ with open(file_input.name, 'r') as f:
94
+ text = f.read()
95
+ else:
96
+ return "❌ Unsupported file type"
97
+ else:
98
+ text = text_input
99
+
100
+ if not text:
101
+ return "⚠️ Please provide text"
102
+
103
+ encoded = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
104
+ input_ids = encoded["input_ids"].unsqueeze(1).to(device)
105
+ attention_mask = encoded["attention_mask"].unsqueeze(1).to(device)
106
+ sentence_mask = torch.ones(1, 1, dtype=torch.bool).to(device)
107
+
108
+ with torch.no_grad():
109
+ predictions = model(input_ids, attention_mask, sentence_mask=sentence_mask)
110
+
111
+ label = id2label[predictions[0][0]]
112
+ return f"✅ **Label:** {label}\n\n📄 **Text:** {text[:300]}..."
113
+ except Exception as e:
114
+ return f"❌ Error: {str(e)}"
115
+
116
+ demo = gr.Interface(fn=predict, inputs=[gr.Textbox(label="Enter Legal Text", lines=5), gr.File(label="Or Upload (PDF/DOCX/TXT)")], outputs=gr.Textbox(label="Result", lines=5), title="⚖️ Legal Document Segmentation", api_name="predict")
117
+
118
+ if __name__ == "__main__":
119
+ demo.launch()