Princess3 committed on
Commit c593750 · verified · 1 Parent(s): 1ddbace

Upload 2 files

Files changed (2)
  1. m5.py +215 -0
  2. requirements.txt +12 -7
m5.py ADDED
@@ -0,0 +1,215 @@
+ import os
+ import logging
+ import xml.etree.ElementTree as ET
+ from collections import defaultdict
+ from typing import List, Dict, Any, Optional
+
+ import numpy as np
+ import requests
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from accelerate import Accelerator
+ from transformers import AutoTokenizer, AutoModel
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
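+ # DM: a dynamic model whose layers are built per named section from
+ # configuration dictionaries (typically parsed from XML; see px/cmf below).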
+ class DM(nn.Module):
+     def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
+         super(DM, self).__init__()
+         self.s = nn.ModuleDict()
+         if not s: s = {'default': [{'input_size': 128, 'output_size': 256, 'activation': 'relu', 'batch_norm': True, 'dropout': 0.1}]}
+         for sn, l in s.items():
+             self.s[sn] = nn.ModuleList()
+             for lp in l:
+                 logging.info(f"Creating layer in section '{sn}' with params: {lp}")
+                 self.s[sn].append(self.cl(lp))
+
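+     # cl: build one block from a layer config dict -- a Linear layer optionally
+     # followed by batch norm, an activation, dropout, nested hidden layers,
+     # and the MAL/HAL/DFAL augmentation layers (all enabled by default).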
+     def cl(self, lp: Dict[str, Any]) -> nn.Module:
+         l = [nn.Linear(lp['input_size'], lp['output_size'])]
+         if lp.get('batch_norm', True): l.append(nn.BatchNorm1d(lp['output_size']))
+         a = lp.get('activation', 'relu')
+         if a == 'relu': l.append(nn.ReLU(inplace=True))
+         elif a == 'tanh': l.append(nn.Tanh())
+         elif a == 'sigmoid': l.append(nn.Sigmoid())
+         elif a == 'leaky_relu': l.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
+         elif a == 'elu': l.append(nn.ELU(alpha=1.0, inplace=True))
+         elif a not in (None, 'none'): raise ValueError(f"Unsupported activation function: {a}")
+         if dr := lp.get('dropout', 0.0): l.append(nn.Dropout(p=dr))
+         if hl := lp.get('hidden_layers', []):
+             for hlp in hl: l.append(self.cl(hlp))
+         if lp.get('memory_augmentation', True): l.append(MAL(lp['output_size']))
+         if lp.get('hybrid_attention', True): l.append(HAL(lp['output_size']))
+         if lp.get('dynamic_flash_attention', True): l.append(DFAL(lp['output_size']))
+         return nn.Sequential(*l)
+
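+     # forward: run the input through one named section, or through every
+     # section in insertion order when no section name is given.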
+     def forward(self, x: torch.Tensor, sn: Optional[str] = None) -> torch.Tensor:
+         if sn is not None:
+             if sn not in self.s: raise KeyError(f"Section '{sn}' not found in model")
+             for layer in self.s[sn]: x = layer(x)
+         else:
+             for layers in self.s.values():
+                 for layer in layers: x = layer(x)
+         return x
+
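+ # MAL adds a learned memory vector to its input; HAL and DFAL wrap
+ # nn.MultiheadAttention (8 heads) over a singleton sequence dimension.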
+ class MAL(nn.Module):
+     def __init__(self, s: int):
+         super(MAL, self).__init__()
+         self.m = nn.Parameter(torch.randn(s))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x + self.m
+
+ class HAL(nn.Module):
+     def __init__(self, s: int):
+         super(HAL, self).__init__()
+         self.a = nn.MultiheadAttention(s, num_heads=8)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x.unsqueeze(1)
+         ao, _ = self.a(x, x, x)
+         return ao.squeeze(1)
+
+ class DFAL(nn.Module):
+     def __init__(self, s: int):
+         super(DFAL, self).__init__()
+         self.a = nn.MultiheadAttention(s, num_heads=8)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x.unsqueeze(1)
+         ao, _ = self.a(x, x, x)
+         return ao.squeeze(1)
+
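+ # px: parse <layer> elements from an XML file into layer config dicts,
+ # falling back to a single default layer when none are found.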
+ def px(file_path: str) -> List[Dict[str, Any]]:
+     t = ET.parse(file_path)
+     r = t.getroot()
+     l = []
+     for ly in r.findall('.//layer'):
+         lp = {'input_size': int(ly.get('input_size', 128)), 'output_size': int(ly.get('output_size', 256)), 'activation': ly.get('activation', 'relu').lower()}
+         if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none']: raise ValueError(f"Unsupported activation function: {lp['activation']}")
+         if lp['input_size'] <= 0 or lp['output_size'] <= 0: raise ValueError("Layer dimensions must be positive integers")
+         l.append(lp)
+     if not l: l.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
+     return l
+
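+ # cmf: walk a folder tree, parse every *.xml file, group layer configs by
+ # the containing directory name, and build a DM from the result.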
+ def cmf(folder_path: str) -> DM:
+     s = defaultdict(list)
+     if not os.path.exists(folder_path):
+         logging.warning(f"Folder {folder_path} does not exist. Creating model with default configuration.")
+         return DM({})
+     xf = False
+     for r, d, f in os.walk(folder_path):
+         for file in f:
+             if file.endswith('.xml'):
+                 xf = True
+                 fp = os.path.join(r, file)
+                 try:
+                     l = px(fp)
+                     sn = os.path.basename(r).replace('.', '_')
+                     s[sn].extend(l)
+                 except Exception as e:
+                     logging.error(f"Error processing {fp}: {str(e)}")
+     if not xf:
+         logging.warning("No XML files found. Creating model with default configuration.")
+         return DM({})
+     return DM(dict(s))
+
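+ # ceas: embed the text content of every XML element in the folder with a
+ # sentence-transformers model (mean-pooled last hidden state).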
+ def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     t = AutoTokenizer.from_pretrained(model_name)
+     m = AutoModel.from_pretrained(model_name)
+     embeddings = []
+     ds = []
+     for r, d, f in os.walk(folder_path):
+         for file in f:
+             if file.endswith('.xml'):
+                 fp = os.path.join(r, file)
+                 try:
+                     tree = ET.parse(fp)
+                     root = tree.getroot()
+                     for elem in root.iter():
+                         if elem.text and elem.text.strip():
+                             text = elem.text.strip()
+                             i = t(text, return_tensors="pt", truncation=True, padding=True)
+                             with torch.no_grad():
+                                 emb = m(**i).last_hidden_state.mean(dim=1).numpy()
+                             embeddings.append(emb)
+                             ds.append(text)
+                 except Exception as e:
+                     logging.error(f"Error processing {fp}: {str(e)}")
+     if not embeddings:
+         return np.empty((0, m.config.hidden_size)), ds
+     embeddings = np.vstack(embeddings)
+     return embeddings, ds
+
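+ # qvs: embed the query the same way and return the 5 most similar stored
+ # texts by cosine similarity.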
+ def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     t = AutoTokenizer.from_pretrained(model_name)
+     m = AutoModel.from_pretrained(model_name)
+     i = t(query, return_tensors="pt", truncation=True, padding=True)
+     with torch.no_grad():
+         qe = m(**i).last_hidden_state.mean(dim=1).numpy()
+     similarities = cosine_similarity(qe, embeddings)
+     top_k_indices = similarities[0].argsort()[-5:][::-1]
+     return [ds[idx] for idx in top_k_indices]
+
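+ # fetch_nzlii_data: query the NZLII search endpoint and normalise the
+ # results into a list of dicts; returns [] on request or parse errors.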
+ def fetch_nzlii_data(query: str) -> List[Dict[str, Any]]:
+     base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
+     params = {
+         "method": "auto",
+         "query": query,
+         "meta": "/nz",
+         "mask_path": "",
+         "results": "50",
+         "format": "json"
+     }
+     try:
+         response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
+         response.raise_for_status()
+         results = response.json().get("results", [])
+         processed_results = []
+         for result in results:
+             processed_results.append({
+                 "title": result.get("title", ""),
+                 "citation": result.get("citation", ""),
+                 "date": result.get("date", ""),
+                 "court": result.get("court", ""),
+                 "summary": result.get("summary", ""),
+                 "url": result.get("url", "")
+             })
+         return processed_results
+     except requests.exceptions.RequestException as e:
+         logging.error(f"Failed to fetch data from NZLII API: {str(e)}")
+         return []
+     except ValueError as e:
+         logging.error(f"Failed to parse NZLII API response: {str(e)}")
+         return []
+
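+ # main: build the model from the 'data' folder, run a sample forward pass,
+ # train for a few epochs on synthetic data with Accelerate, then run a
+ # similarity query and an NZLII lookup.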
+ def main():
+     fp = 'data'
+     m = cmf(fp)
+     logging.info(f"Created dynamic PyTorch model with sections: {list(m.s.keys())}")
+     fs = next(iter(m.s.keys()))
+     fl = m.s[fs][0]
+     ife = fl[0].in_features
+     si = torch.randn(1, ife)
+     m.eval()  # BatchNorm1d cannot run in training mode on a single-sample batch
+     with torch.no_grad():
+         so = m(si)
+     logging.info(f"Sample output shape: {so.shape}")
+     embeddings, ds = ceas(fp)
+     a = Accelerator()
+     opt = torch.optim.Adam(m.parameters(), lr=0.001)
+     c = nn.CrossEntropyLoss()
+     ne = 10
+     d = torch.utils.data.TensorDataset(torch.randn(100, ife), torch.randint(0, 2, (100,)))
+     td = torch.utils.data.DataLoader(d, batch_size=16, shuffle=True)
+     m, opt, td = a.prepare(m, opt, td)
+     for e in range(ne):
+         m.train()
+         tl = 0
+         for bi, (inputs, labels) in enumerate(td):
+             opt.zero_grad()
+             out = m(inputs)
+             loss = c(out, labels)
+             a.backward(loss)
+             opt.step()
+             tl += loss.item()
+         al = tl / len(td)
+         logging.info(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}")
+     uq = "example query text"
+     r = qvs(uq, embeddings, ds)
+     logging.info(f"Query results: {r}")
+
+     nz_data = fetch_nzlii_data(uq)
+     logging.info(f"NZLII API results: {nz_data}")
+
+ if __name__ == "__main__":
+     main()
requirements.txt CHANGED
@@ -1,7 +1,9 @@
- fastapi
- uvicorn
- requests
- llama-cpp-python
- psutil
- transformers
- torch
+ torch==1.9.0
+ torchvision==0.10.0
+ transformers==4.9.2
+ accelerate==0.5.1
+ numpy==1.21.2
+ scikit-learn==0.24.2
+ termcolor==1.1.0
+ requests==2.26.0
+ psutil