Commit: Add streamlit widget support to the pipeline
Files changed: pipeline.py (+31 −8)
@@ -6,7 +6,7 @@ from io import BytesIO
|
|
6 |
import numpy as np
|
7 |
import requests
|
8 |
import torch
|
9 |
- from tqdm import tqdm
|
10 |
from transformers import Pipeline
|
11 |
|
12 |
|
@@ -96,26 +96,38 @@ def download_and_extract(repos, headers=None):
|
|
96 |
|
97 |
|
98 |
class RepoEmbeddingPipeline(Pipeline):
|
99 |
- def __init__(self, github_token=None, *args, **kwargs):
|
100 |
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
101 |
self.API_HEADERS = {"Accept": "application/vnd.github+json"}
|
102 |
if not github_token:
|
103 |
-
|
104 |
-
"[
|
105 |
-
"For more info, see:"
|
106 |
"https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
|
107 |
)
|
|
|
|
|
|
|
108 |
else:
|
109 |
self.set_github_token(github_token)
|
110 |
|
111 |
def set_github_token(self, github_token):
|
112 |
self.API_HEADERS["Authorization"] = f"Bearer {github_token}"
|
113 |
-
|
|
|
|
|
|
|
114 |
|
115 |
def _sanitize_parameters(self, **kwargs):
|
116 |
_forward_kwargs = {}
|
117 |
if "max_length" in kwargs:
|
118 |
_forward_kwargs["max_length"] = kwargs["max_length"]
|
|
|
|
|
119 |
|
120 |
return {}, _forward_kwargs, {}
|
121 |
|
@@ -123,6 +135,8 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
123 |
if isinstance(inputs, str):
|
124 |
inputs = (inputs,)
|
125 |
|
|
|
|
|
126 |
extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
|
127 |
|
128 |
return extracted_infos
|
@@ -153,7 +167,7 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
153 |
|
154 |
return sentence_embeddings
|
155 |
|
156 |
- def _forward(self, extracted_infos, max_length=512):
|
157 |
repo_dataset = {}
|
158 |
num_texts = sum(
|
159 |
len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
|
@@ -163,14 +177,20 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
163 |
pbar.set_description(f"Processing {repo_name}")
|
164 |
entry = {"topics": repo_info.get("topics")}
|
165 |
|
166 |
-
|
|
|
|
|
|
|
167 |
|
168 |
code_embeddings = []
|
169 |
for func in repo_info["funcs"]:
|
170 |
code_embeddings.append(
|
171 |
[func, self.encode(func, max_length).squeeze().tolist()]
|
172 |
)
|
|
|
173 |
pbar.update(1)
|
|
|
|
|
174 |
|
175 |
entry["code_embeddings"] = code_embeddings
|
176 |
entry["mean_code_embedding"] = (
|
@@ -184,7 +204,10 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
184 |
doc_embeddings.append(
|
185 |
[doc, self.encode(doc, max_length).squeeze().tolist()]
|
186 |
)
|
|
|
187 |
pbar.update(1)
|
|
|
|
|
188 |
|
189 |
entry["doc_embeddings"] = doc_embeddings
|
190 |
entry["mean_doc_embedding"] = (
|
|
|
6 |
import numpy as np
|
7 |
import requests
|
8 |
import torch
|
9 |
+ from tqdm.auto import tqdm
|
10 |
from transformers import Pipeline
|
11 |
|
12 |
|
|
|
96 |
|
97 |
|
98 |
class RepoEmbeddingPipeline(Pipeline):
|
99 |
+ def __init__(self, github_token=None, st_messager=None, *args, **kwargs):
|
100 |
super().__init__(*args, **kwargs)
|
101 |
+
|
102 |
+
# Streamlit single element container created by st.empty()
|
103 |
+
self.st_messager = st_messager
|
104 |
+
|
105 |
self.API_HEADERS = {"Accept": "application/vnd.github+json"}
|
106 |
if not github_token:
|
107 |
+
message = (
|
108 |
+
"[*] Consider setting GitHub token to avoid hitting rate limits. \n"
|
109 |
+
"For more info, see: "
|
110 |
"https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
|
111 |
)
|
112 |
+
print(message)
|
113 |
+
if self.st_messager:
|
114 |
+
self.st_messager.info(message)
|
115 |
else:
|
116 |
self.set_github_token(github_token)
|
117 |
|
118 |
def set_github_token(self, github_token):
|
119 |
self.API_HEADERS["Authorization"] = f"Bearer {github_token}"
|
120 |
+
message = "[+] GitHub token set"
|
121 |
+
print(message)
|
122 |
+
if self.st_messager:
|
123 |
+
self.st_messager.success(message)
|
124 |
|
125 |
def _sanitize_parameters(self, **kwargs):
|
126 |
_forward_kwargs = {}
|
127 |
if "max_length" in kwargs:
|
128 |
_forward_kwargs["max_length"] = kwargs["max_length"]
|
129 |
+
if "st_progress" in kwargs:
|
130 |
+
_forward_kwargs["st_progress"] = kwargs["st_progress"]
|
131 |
|
132 |
return {}, _forward_kwargs, {}
|
133 |
|
|
|
135 |
if isinstance(inputs, str):
|
136 |
inputs = (inputs,)
|
137 |
|
138 |
+
if self.st_messager:
|
139 |
+
self.st_messager.info("[*] Downloading and extracting repos...")
|
140 |
extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
|
141 |
|
142 |
return extracted_infos
|
|
|
167 |
|
168 |
return sentence_embeddings
|
169 |
|
170 |
+ def _forward(self, extracted_infos, max_length=512, st_progress=None):
|
171 |
repo_dataset = {}
|
172 |
num_texts = sum(
|
173 |
len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
|
|
|
177 |
pbar.set_description(f"Processing {repo_name}")
|
178 |
entry = {"topics": repo_info.get("topics")}
|
179 |
|
180 |
+
message = f"[*] Generating embeddings for {repo_name}"
|
181 |
+
tqdm.write(message)
|
182 |
+
if self.st_messager:
|
183 |
+
self.st_messager.info(message)
|
184 |
|
185 |
code_embeddings = []
|
186 |
for func in repo_info["funcs"]:
|
187 |
code_embeddings.append(
|
188 |
[func, self.encode(func, max_length).squeeze().tolist()]
|
189 |
)
|
190 |
+
|
191 |
pbar.update(1)
|
192 |
+
if st_progress:
|
193 |
+
st_progress.progress(pbar.n / pbar.total)
|
194 |
|
195 |
entry["code_embeddings"] = code_embeddings
|
196 |
entry["mean_code_embedding"] = (
|
|
|
204 |
doc_embeddings.append(
|
205 |
[doc, self.encode(doc, max_length).squeeze().tolist()]
|
206 |
)
|
207 |
+
|
208 |
pbar.update(1)
|
209 |
+
if st_progress:
|
210 |
+
st_progress.progress(pbar.n / pbar.total)
|
211 |
|
212 |
entry["doc_embeddings"] = doc_embeddings
|
213 |
entry["mean_doc_embedding"] = (
|