Princess3 commited on
Commit
0f463a5
·
verified ·
1 Parent(s): 392bc58

Delete m5.py

Browse files
Files changed (1) hide show
  1. m5.py +0 -229
m5.py DELETED
@@ -1,229 +0,0 @@
1
- import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F, numpy as np, logging, requests
2
- from typing import List, Dict, Any, Optional
3
- from collections import defaultdict
4
- from accelerate import Accelerator
5
- from transformers import AutoTokenizer, AutoModel
6
- from sklearn.metrics.pairwise import cosine_similarity
7
- import termcolor
8
-
9
- # Set the cache directory path
10
- cache_dir = '/app/cache'
11
-
12
- # Create the directory if it doesn't exist
13
- if not os.path.exists(cache_dir):
14
- os.makedirs(cache_dir)
15
-
16
- # Set the environment variable
17
- os.environ['TRANSFORMERS_CACHE'] = cache_dir
18
-
19
- # Verify the environment variable is set
20
- print(f"TRANSFORMERS_CACHE is set to: {os.environ['TRANSFORMERS_CACHE']}")
21
-
22
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
-
24
- class DM(nn.Module):
25
- def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
26
- super(DM, self).__init__()
27
- self.s = nn.ModuleDict()
28
- if not s: s = {'default': [{'input_size': 128, 'output_size': 256, 'activation': 'relu', 'batch_norm': True, 'dropout': 0.1}]}
29
- for sn, l in s.items():
30
- self.s[sn] = nn.ModuleList()
31
- for lp in l:
32
- logging.info(f"Creating layer in section '{sn}' with params: {lp}")
33
- self.s[sn].append(self.cl(lp))
34
-
35
- def cl(self, lp: Dict[str, Any]) -> nn.Module:
36
- l = [nn.Linear(lp['input_size'], lp['output_size'])]
37
- if lp.get('batch_norm', True): l.append(nn.BatchNorm1d(lp['output_size']))
38
- a = lp.get('activation', 'relu')
39
- if a == 'relu': l.append(nn.ReLU(inplace=True))
40
- elif a == 'tanh': l.append(nn.Tanh())
41
- elif a == 'sigmoid': l.append(nn.Sigmoid())
42
- elif a == 'leaky_relu': l.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
43
- elif a == 'elu': l.append(nn.ELU(alpha=1.0, inplace=True))
44
- elif a is not None: raise ValueError(f"Unsupported activation function: {a}")
45
- if dr := lp.get('dropout', 0.0): l.append(nn.Dropout(p=dr))
46
- if hl := lp.get('hidden_layers', []):
47
- for hlp in hl: l.append(self.cl(hlp))
48
- if lp.get('memory_augmentation', True): l.append(MAL(lp['output_size']))
49
- if lp.get('hybrid_attention', True): l.append(HAL(lp['output_size']))
50
- if lp.get('dynamic_flash_attention', True): l.append(DFAL(lp['output_size']))
51
- return nn.Sequential(*l)
52
-
53
- def forward(self, x: torch.Tensor, sn: Optional[str] = None) -> torch.Tensor:
54
- if sn is not None:
55
- if sn not in self.s: raise KeyError(f"Section '{sn}' not found in model")
56
- for l in self.s[sn]: x = l(x)
57
- else:
58
- for sn, l in self.s.items():
59
- for l in l: x = l(x)
60
- return x
61
-
62
- class MAL(nn.Module):
63
- def __init__(self, s: int):
64
- super(MAL, self).__init__()
65
- self.m = nn.Parameter(torch.randn(s))
66
-
67
- def forward(self, x: torch.Tensor) -> torch.Tensor:
68
- return x + self.m
69
-
70
- class HAL(nn.Module):
71
- def __init__(self, s: int):
72
- super(HAL, self).__init__()
73
- self.a = nn.MultiheadAttention(s, num_heads=8)
74
-
75
- def forward(self, x: torch.Tensor) -> torch.Tensor:
76
- x = x.unsqueeze(1)
77
- ao, _ = self.a(x, x, x)
78
- return ao.squeeze(1)
79
-
80
- class DFAL(nn.Module):
81
- def __init__(self, s: int):
82
- super(DFAL, self).__init__()
83
- self.a = nn.MultiheadAttention(s, num_heads=8)
84
-
85
- def forward(self, x: torch.Tensor) -> torch.Tensor:
86
- x = x.unsqueeze(1)
87
- ao, _ = self.a(x, x, x)
88
- return ao.squeeze(1)
89
-
90
- def px(file_path: str) -> List[Dict[str, Any]]:
91
- t = ET.parse(file_path)
92
- r = t.getroot()
93
- l = []
94
- for ly in r.findall('.//layer'):
95
- lp = {'input_size': int(ly.get('input_size', 128)), 'output_size': int(ly.get('output_size', 256)), 'activation': ly.get('activation', 'relu').lower()}
96
- if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'none']: raise ValueError(f"Unsupported activation function: {lp['activation']}")
97
- if lp['input_size'] <= 0 or lp['output_size'] <= 0: raise ValueError("Layer dimensions must be positive integers")
98
- l.append(lp)
99
- if not l: l.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
100
- return l
101
-
102
- def cmf(folder_path: str) -> DM:
103
- s = defaultdict(list)
104
- if not os.path.exists(folder_path):
105
- logging.warning(f"Folder {folder_path} does not exist. Creating model with default configuration.")
106
- return DM({})
107
- xf = True
108
- for r, d, f in os.walk(folder_path):
109
- for file in f:
110
- if file.endswith('.xml'):
111
- xf = True
112
- fp = os.path.join(r, file)
113
- try:
114
- l = px(fp)
115
- sn = os.path.basename(r).replace('.', '_')
116
- s[sn].extend(l)
117
- except Exception as e:
118
- logging.error(f"Error processing {fp}: {str(e)}")
119
- if not xf:
120
- logging.warning("No XML files found. Creating model with default configuration.")
121
- return DM({})
122
- return DM(dict(s))
123
-
124
- def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
125
- t = AutoTokenizer.from_pretrained(model_name)
126
- m = AutoModel.from_pretrained(model_name)
127
- embeddings = []
128
- ds = []
129
- for r, d, f in os.walk(folder_path):
130
- for file in f:
131
- if file.endswith('.xml'):
132
- fp = os.path.join(r, file)
133
- try:
134
- tree = ET.parse(fp)
135
- root = tree.getroot()
136
- for e in root.iter():
137
- if e.text:
138
- text = e.text.strip()
139
- i = t(text, return_tensors="pt", truncation=True, padding=True)
140
- with torch.no_grad():
141
- emb = m(**i).last_hidden_state.mean(dim=1).numpy()
142
- embeddings.append(emb)
143
- ds.append(text)
144
- except Exception as e:
145
- logging.error(f"Error processing {fp}: {str(e)}")
146
- embeddings = np.vstack(embeddings)
147
- return embeddings, ds
148
-
149
- def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
150
- t = AutoTokenizer.from_pretrained(model_name)
151
- m = AutoModel.from_pretrained(model_name)
152
- i = t(query, return_tensors="pt", truncation=True, padding=True)
153
- with torch.no_grad():
154
- qe = m(**i).last_hidden_state.mean(dim=1).numpy()
155
- similarities = cosine_similarity(qe, embeddings)
156
- top_k_indices = similarities[0].argsort()[-5:][::-1]
157
- return [ds[i] for i in top_k_indices]
158
-
159
- def fetch_courtlistener_data(query: str) -> List[Dict[str, Any]]:
160
- base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
161
- params = {
162
- "method": "auto",
163
- "query": query,
164
- "meta": "/nz",
165
- "mask_path": "",
166
- "results": "50",
167
- "format": "json"
168
- }
169
- try:
170
- response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
171
- response.raise_for_status()
172
- results = response.json().get("results", [])
173
- processed_results = []
174
- for result in results:
175
- processed_results.append({
176
- "title": result.get("title", ""),
177
- "citation": result.get("citation", ""),
178
- "date": result.get("date", ""),
179
- "court": result.get("court", ""),
180
- "summary": result.get("summary", ""),
181
- "url": result.get("url", "")
182
- })
183
- return processed_results
184
- except requests.exceptions.RequestException as e:
185
- logging.error(f"Failed to fetch data from NZLII API: {str(e)}")
186
- return []
187
- except ValueError as e:
188
- logging.error(f"Failed to parse NZLII API response: {str(e)}")
189
- return []
190
-
191
- def main():
192
- fp = 'data'
193
- m = cmf(fp)
194
- logging.info(f"Created dynamic PyTorch model with sections: {list(m.s.keys())}")
195
- fs = next(iter(m.s.keys()))
196
- fl = m.s[fs][0]
197
- ife = fl[0].in_features
198
- si = torch.randn(1, ife)
199
- o = m(si)
200
- logging.info(f"Sample output shape: {o.shape}")
201
- embeddings, ds = ceas(fp)
202
- a = Accelerator()
203
- o = torch.optim.Adam(m.parameters(), lr=0.001)
204
- c = nn.CrossEntropyLoss()
205
- ne = 10
206
- d = torch.utils.data.TensorDataset(torch.randn(100, ife), torch.randint(0, 2, (100,)))
207
- td = torch.utils.data.DataLoader(d, batch_size=16, shuffle=True)
208
- m, o, td = a.prepare(m, o, td)
209
- for e in range(ne):
210
- m.train()
211
- tl = 0
212
- for bi, (i, l) in enumerate(td):
213
- o.zero_grad()
214
- o = m(i)
215
- l = c(o, l)
216
- a.backward(l)
217
- o.step()
218
- tl += l.item()
219
- al = tl / len(td)
220
- logging.info(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}")
221
- uq = "example query text"
222
- r = qvs(uq, embeddings, ds)
223
- logging.info(f"Query results: {r}")
224
-
225
- cl_data = fetch_courtlistener_data(uq)
226
- logging.info(f"CourtListener API results: {cl_data}")
227
-
228
- if __name__ == "__main__":
229
- main()