Spaces:
Runtime error
Runtime error
Upload m3.py
Browse files
m3.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
-
import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F,
|
2 |
from typing import List, Dict, Any, Optional
|
3 |
from collections import defaultdict
|
4 |
from accelerate import Accelerator
|
5 |
from transformers import AutoTokenizer, AutoModel
|
6 |
from termcolor import colored
|
|
|
7 |
|
8 |
class DM(nn.Module):
|
9 |
def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
|
@@ -108,7 +109,7 @@ def cmf(folder_path: str) -> DM:
|
|
108 |
def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
|
109 |
t = AutoTokenizer.from_pretrained(model_name)
|
110 |
m = AutoModel.from_pretrained(model_name)
|
111 |
-
|
112 |
ds = []
|
113 |
for r, d, f in os.walk(folder_path):
|
114 |
for file in f:
|
@@ -123,20 +124,22 @@ def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L
|
|
123 |
i = t(text, return_tensors="pt", truncation=True, padding=True)
|
124 |
with torch.no_grad():
|
125 |
emb = m(**i).last_hidden_state.mean(dim=1).numpy()
|
126 |
-
|
127 |
ds.append(text)
|
128 |
except Exception as e:
|
129 |
print(colored(f"Error processing {fp}: {str(e)}", 'red'))
|
130 |
-
|
|
|
131 |
|
132 |
-
def qvs(query: str,
|
133 |
t = AutoTokenizer.from_pretrained(model_name)
|
134 |
m = AutoModel.from_pretrained(model_name)
|
135 |
i = t(query, return_tensors="pt", truncation=True, padding=True)
|
136 |
with torch.no_grad():
|
137 |
qe = m(**i).last_hidden_state.mean(dim=1).numpy()
|
138 |
-
|
139 |
-
|
|
|
140 |
|
141 |
def main():
|
142 |
fp = 'data'
|
@@ -148,7 +151,7 @@ def main():
|
|
148 |
si = torch.randn(1, ife)
|
149 |
o = m(si)
|
150 |
print(colored(f"Sample output shape: {o.shape}", 'green'))
|
151 |
-
|
152 |
a = Accelerator()
|
153 |
o = torch.optim.Adam(m.parameters(), lr=0.001)
|
154 |
c = nn.CrossEntropyLoss()
|
@@ -169,7 +172,7 @@ def main():
|
|
169 |
al = tl / len(td)
|
170 |
print(colored(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}", 'blue'))
|
171 |
uq = "example query text"
|
172 |
-
r = qvs(uq,
|
173 |
print(colored(f"Query results: {r}", 'magenta'))
|
174 |
|
175 |
if __name__ == "__main__":
|
|
|
1 |
+
import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F, numpy as np
|
2 |
from typing import List, Dict, Any, Optional
|
3 |
from collections import defaultdict
|
4 |
from accelerate import Accelerator
|
5 |
from transformers import AutoTokenizer, AutoModel
|
6 |
from termcolor import colored
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
|
9 |
class DM(nn.Module):
|
10 |
def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
|
|
|
109 |
def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
|
110 |
t = AutoTokenizer.from_pretrained(model_name)
|
111 |
m = AutoModel.from_pretrained(model_name)
|
112 |
+
embeddings = []
|
113 |
ds = []
|
114 |
for r, d, f in os.walk(folder_path):
|
115 |
for file in f:
|
|
|
124 |
i = t(text, return_tensors="pt", truncation=True, padding=True)
|
125 |
with torch.no_grad():
|
126 |
emb = m(**i).last_hidden_state.mean(dim=1).numpy()
|
127 |
+
embeddings.append(emb)
|
128 |
ds.append(text)
|
129 |
except Exception as e:
|
130 |
print(colored(f"Error processing {fp}: {str(e)}", 'red'))
|
131 |
+
embeddings = np.vstack(embeddings)
|
132 |
+
return embeddings, ds
|
133 |
|
134 |
+
def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
|
135 |
t = AutoTokenizer.from_pretrained(model_name)
|
136 |
m = AutoModel.from_pretrained(model_name)
|
137 |
i = t(query, return_tensors="pt", truncation=True, padding=True)
|
138 |
with torch.no_grad():
|
139 |
qe = m(**i).last_hidden_state.mean(dim=1).numpy()
|
140 |
+
similarities = cosine_similarity(qe, embeddings)
|
141 |
+
top_k_indices = similarities[0].argsort()[-5:][::-1]
|
142 |
+
return [ds[i] for i in top_k_indices]
|
143 |
|
144 |
def main():
|
145 |
fp = 'data'
|
|
|
151 |
si = torch.randn(1, ife)
|
152 |
o = m(si)
|
153 |
print(colored(f"Sample output shape: {o.shape}", 'green'))
|
154 |
+
embeddings, ds = ceas(fp)
|
155 |
a = Accelerator()
|
156 |
o = torch.optim.Adam(m.parameters(), lr=0.001)
|
157 |
c = nn.CrossEntropyLoss()
|
|
|
172 |
al = tl / len(td)
|
173 |
print(colored(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}", 'blue'))
|
174 |
uq = "example query text"
|
175 |
+
r = qvs(uq, embeddings, ds)
|
176 |
print(colored(f"Query results: {r}", 'magenta'))
|
177 |
|
178 |
if __name__ == "__main__":
|