Princess3 committed
Commit 2457015
1 Parent(s): 3cb7033

Upload m3.py

Files changed (1)
  1. m3.py +12 -9
m3.py CHANGED
@@ -1,9 +1,10 @@
-import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F, faiss, numpy as np
+import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F, numpy as np
 from typing import List, Dict, Any, Optional
 from collections import defaultdict
 from accelerate import Accelerator
 from transformers import AutoTokenizer, AutoModel
 from termcolor import colored
+from sklearn.metrics.pairwise import cosine_similarity
 
 class DM(nn.Module):
     def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
@@ -108,7 +109,7 @@ def cmf(folder_path: str) -> DM:
 def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
     t = AutoTokenizer.from_pretrained(model_name)
     m = AutoModel.from_pretrained(model_name)
-    vs = faiss.IndexFlatL2(384)
+    embeddings = []
     ds = []
     for r, d, f in os.walk(folder_path):
         for file in f:
@@ -123,20 +124,22 @@ def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
                 i = t(text, return_tensors="pt", truncation=True, padding=True)
                 with torch.no_grad():
                     emb = m(**i).last_hidden_state.mean(dim=1).numpy()
-                vs.add(emb)
+                embeddings.append(emb)
                 ds.append(text)
             except Exception as e:
                 print(colored(f"Error processing {fp}: {str(e)}", 'red'))
-    return vs, ds
+    embeddings = np.vstack(embeddings)
+    return embeddings, ds
 
-def qvs(query: str, vs, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
     t = AutoTokenizer.from_pretrained(model_name)
     m = AutoModel.from_pretrained(model_name)
     i = t(query, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
         qe = m(**i).last_hidden_state.mean(dim=1).numpy()
-    D, I = vs.search(qe, k=5)
-    return [ds[i] for i in I[0]]
+    similarities = cosine_similarity(qe, embeddings)
+    top_k_indices = similarities[0].argsort()[-5:][::-1]
+    return [ds[i] for i in top_k_indices]
 
 def main():
     fp = 'data'
@@ -148,7 +151,7 @@ def main():
     si = torch.randn(1, ife)
     o = m(si)
     print(colored(f"Sample output shape: {o.shape}", 'green'))
-    vs, ds = ceas(fp)
+    embeddings, ds = ceas(fp)
     a = Accelerator()
     o = torch.optim.Adam(m.parameters(), lr=0.001)
     c = nn.CrossEntropyLoss()
@@ -169,7 +172,7 @@ def main():
         al = tl / len(td)
         print(colored(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}", 'blue'))
     uq = "example query text"
-    r = qvs(uq, vs, ds)
+    r = qvs(uq, embeddings, ds)
     print(colored(f"Query results: {r}", 'magenta'))
 
 if __name__ == "__main__":
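
For reference, a minimal standalone sketch of the retrieval path this commit switches to: mean-pooled transformer embeddings stacked with np.vstack and scored with sklearn's cosine_similarity, mirroring the new ceas/qvs logic. The embed helper, the sample documents, and the query string below are illustrative assumptions, not part of the repository.

# Sketch of the new retrieval path (assumed helper names and sample texts).
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel

model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def embed(text: str) -> np.ndarray:
    # Mean-pool the last hidden state, as ceas and qvs both do.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        return model(**inputs).last_hidden_state.mean(dim=1).numpy()

# Build the corpus matrix the way ceas now does: collect per-document
# embeddings in a list, then stack them into a single (n_docs, dim) array.
docs = ["first sample document", "second sample document", "third sample document"]
embeddings = np.vstack([embed(d) for d in docs])

# Query the way qvs now does: cosine similarity, then argsort for the top 5.
query_emb = embed("sample query")                     # shape (1, dim)
similarities = cosine_similarity(query_emb, embeddings)  # shape (1, n_docs)
top_k_indices = similarities[0].argsort()[-5:][::-1]     # highest similarity first
print([docs[i] for i in top_k_indices])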