# RepoSim4Py / RepoPipeline.py
# Last commit 7b66718 by HenryStephen: "Update pipeline progress bar"
import ast
import re
import tarfile
from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
from typing import Dict, Any, List

import numpy as np
import requests
import torch
from tqdm.auto import tqdm
from transformers import Pipeline
def extract_code_and_docs(text: str):
    """
    Split a Python source file into function code and docstrings.

    :param text: source text of a Python file.
    :return: a ``(codes, docs)`` pair of sets. ``docs`` holds the docstring of
        the module and of every class and (async) function; ``codes`` holds the
        ``ast.unparse`` source of every (async) function with its own leading
        docstring statement removed. Docstrings of definitions nested inside a
        function are still present in that function's code, because the
        enclosing node is visited (and unparsed) before its nested nodes.
    :raises SyntaxError: if ``text`` cannot be parsed.
    """
    documented = (Module, ClassDef, FunctionDef, AsyncFunctionDef)
    functions = (FunctionDef, AsyncFunctionDef)
    codes, docs = set(), set()
    for node in ast.walk(ast.parse(text)):
        if not isinstance(node, documented):
            continue
        docstring = ast.get_docstring(node)
        if docstring is not None:
            docs.add(docstring)
            # Drop the leading docstring statement in place so it is not unparsed below;
            # ast.walk has already queued this node's children, so traversal is unaffected.
            del node.body[0]
        if isinstance(node, functions):
            codes.add(ast.unparse(node))
    return codes, docs
def extract_readmes(file_content):
    """
    The method for extracting readme lines.
    :param file_content: the text of a README file.
    :return: the set of distinct non-empty lines, each stripped of surrounding whitespace.
    """
    stripped_lines = (line.strip() for line in file_content.splitlines())
    # Blank lines carry no content; keeping them would add an "" entry to every repository's set.
    return {line for line in stripped_lines if line}
# pip treats "#" as starting a comment only at the start of a line or after whitespace,
# so URL fragments such as "...#egg=name" are left intact.
_REQUIREMENT_COMMENT_RE = re.compile(r"(^|\s)#.*$")
# A requirement's project name ends at the first whitespace, extras bracket,
# version operator, environment-marker separator, or direct-reference "@".
_REQUIREMENT_NAME_END_RE = re.compile(r"[\s\[\]<>=!~;@]")


def extract_requirements(file_content):
    """
    The method for extracting requirements.
    :param file_content: the text of a requirements.txt file.
    :return: the set of requirement names, e.g. "numpy" for "numpy>=1.20" or "numpy == 1.20".
    """
    requirements_set = set()
    for raw_line in file_content.splitlines():
        line = _REQUIREMENT_COMMENT_RE.sub("", raw_line).strip()
        # Skip blank/comment-only lines and pip options such as "-r base.txt" or "--index-url ...".
        if not line or line.startswith("-"):
            continue
        name = _REQUIREMENT_NAME_END_RE.split(line, maxsplit=1)[0]
        if name:
            requirements_set.add(name)
    return requirements_set
def get_metadata(repo_name, headers=None, timeout=30):
    """
    The method for getting metadata of repository from github_api.
    :param repo_name: repository name as used in the "/repos/{repo_name}" API path ("owner/repo").
    :param headers: request headers passed through to requests.
    :param timeout: seconds to wait for the connection and for each read; None waits forever.
    :return: the decoded JSON object, or an empty dict if the request, the HTTP status,
        or JSON decoding fails (or the body is not a JSON object).
    """
    api_url = f"https://api.github.com/repos/{repo_name}"
    tqdm.write(f"[+] Getting metadata for {repo_name}")
    try:
        response = requests.get(api_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        metadata = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # RequestException covers HTTP errors as well as connection failures and timeouts;
        # ValueError covers a body that is not valid JSON.
        tqdm.write(f"[-] Failed to retrieve metadata from {repo_name}: {e}")
        return {}
    # Callers use dict methods on the result, so anything other than a JSON object is treated as missing.
    return metadata if isinstance(metadata, dict) else {}
def _collect_member(tar, member, repo_info):
    """
    Parse one member of a repository tarball into repo_info, if it is relevant.

    Regular files ending in ".py" feed "codes"/"docs", files ending in
    "README.md"/"README.rst" feed "readmes", and files ending in
    "requirements.txt" feed "requirements"; anything else is ignored.
    Files that cannot be decoded as UTF-8 or parsed are logged and skipped.
    :param tar: the open tarfile the member belongs to.
    :param member: the TarInfo to inspect.
    :param repo_info: the per-repository dict whose sets are updated in place.
    """
    if not member.isfile():
        return
    name = member.name
    is_python = name.endswith(".py")
    is_readme = name.endswith(("README.md", "README.rst"))
    is_requirements = name.endswith("requirements.txt")
    if not (is_python or is_readme or is_requirements):
        return
    try:
        # The caller opens the archive in "r|gz" (sequential) mode, so the member's
        # bytes are consumed here before the loop advances to the next member.
        file_content = tar.extractfile(member).read().decode("utf-8")
        if is_python:
            code_set, docs_set = extract_code_and_docs(file_content)
            repo_info["codes"].update(code_set)
            repo_info["docs"].update(docs_set)
        elif is_readme:
            repo_info["readmes"].update(extract_readmes(file_content))
        else:
            repo_info["requirements"].update(extract_requirements(file_content))
    except (UnicodeDecodeError, SyntaxError, ValueError, RecursionError) as e:
        # ValueError covers e.g. null bytes rejected by ast.parse on older Pythons;
        # RecursionError covers pathologically deep code. Skip the file, keep the repo.
        tqdm.write(f"[-] {type(e).__name__} in {name}, skipping: \n{e}")


def extract_information(repos, headers=None, timeout=60):
    """
    The method for extracting repositories information.

    For each repository: fetch GitHub metadata, stream the repository tarball,
    and collect function code, docstrings, README lines and requirement names.
    Repositories whose tarball cannot be downloaded or read are logged and
    left out of the result.
    :param repos: repository names ("owner/repo"); must support len().
    :param headers: request headers passed to the GitHub API.
    :param timeout: seconds the tarball download waits to connect and between
        received bytes; None disables the timeout.
    :return: a list with one info dict per successfully processed repository.
    """
    extracted_infos = []
    for repo_name in tqdm(repos, disable=len(repos) <= 1):
        # 1. Extracting metadata.
        metadata = get_metadata(repo_name, headers=headers)
        repo_info = {
            "name": repo_name,
            "codes": set(),
            "docs": set(),
            "requirements": set(),
            "readmes": set(),
            "topics": [],
            "license": "",
            "stars": metadata.get("stargazers_count"),
        }
        if metadata.get("topics"):
            repo_info["topics"] = metadata["topics"]
        if metadata.get("license"):
            repo_info["license"] = metadata["license"].get("spdx_id")
        # 2. Download the repository tarball as a stream.
        download_url = f"https://api.github.com/repos/{repo_name}/tarball"
        tqdm.write(f"[+] Downloading {repo_name}")
        try:
            response = requests.get(download_url, headers=headers, stream=True, timeout=timeout)
        except requests.exceptions.RequestException as e:
            tqdm.write(f"[-] Failed to download {repo_name}: {e}")
            continue
        # stream=True keeps the connection open until the response is closed; close it on every path.
        with response:
            try:
                response.raise_for_status()
            except requests.exceptions.HTTPError as e:
                tqdm.write(f"[-] Failed to download {repo_name}: {e}")
                continue
            # 3. Extract codes, docs, readmes and requirements while streaming.
            tqdm.write(f"[+] Extracting {repo_name} info")
            try:
                with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
                    for member in tar:
                        _collect_member(tar, member, repo_info)
            except tarfile.TarError as e:
                tqdm.write(f"[-] Failed to extract {repo_name}: {e}")
                continue
        extracted_infos.append(repo_info)
    return extracted_infos
class RepoPipeline(Pipeline):
    """
    A custom pipeline for generating series of embeddings of a repository.

    preprocess() turns GitHub repository names into extracted text sets (via
    extract_information), _forward() embeds every text with the wrapped model
    and averages the embeddings per kind, and postprocess() returns the result unchanged.
    """

    # (key of a text set in the extracted repo info, label used in output keys and log messages).
    _TEXT_KINDS = (
        ("codes", "code"),
        ("docs", "doc"),
        ("requirements", "requirement"),
        ("readmes", "readme"),
    )
    # Number of special tokens encode() wraps around every text.
    _NUM_SPECIAL_TOKENS = 4
    # Exclusive upper bound on max_length (presumably the model's position limit — confirm).
    _MAX_LENGTH_LIMIT = 1024
    # Embedding width for empty text sets when the model config exposes no hidden_size.
    _DEFAULT_HIDDEN_SIZE = 768

    def __init__(self, github_token=None, *args, **kwargs):
        """
        The initial method for pipeline.
        :param github_token: default GitHub token for API requests; a per-call
            ``github_token`` passed to the pipeline takes precedence.
        :param args: positional arguments forwarded to transformers.Pipeline.
        :param kwargs: keyword arguments forwarded to transformers.Pipeline.
        """
        super().__init__(*args, **kwargs)
        # Getting github token
        self.github_token = github_token
        if self.github_token:
            print("[+] GitHub token set!")
        else:
            print(
                "[*] Please set GitHub token to avoid unexpected errors. \n"
                "For more info, see: "
                "https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
            )

    def _sanitize_parameters(self, **pipeline_parameters):
        """
        Route pipeline keyword arguments to the stage that consumes them.
        :param pipeline_parameters: keyword arguments given to the pipeline.
        :return: (preprocess, forward, postprocess) parameter dicts.
        """
        # The parameters of "preprocess" period.
        preprocess_parameters = {}
        if "github_token" in pipeline_parameters:
            preprocess_parameters["github_token"] = pipeline_parameters["github_token"]
        # The parameters of "forward" period. transformers only passes on what is returned
        # here, so st_progress must be listed or it can never reach _forward via __call__.
        forward_parameters = {}
        for name in ("max_length", "st_progress"):
            if name in pipeline_parameters:
                forward_parameters[name] = pipeline_parameters[name]
        # The parameters of "postprocess" period.
        postprocess_parameters = {}
        return preprocess_parameters, forward_parameters, postprocess_parameters

    def preprocess(self, input_: Any, github_token=None) -> List:
        """
        The method for "preprocess" period.
        :param input_: one repository name (str) or a list of repository names.
        :param github_token: token for this call; falls back to the token given at construction.
        :return: a list of repository information dicts from extract_information.
        """
        # Making input to list format.
        if isinstance(input_, str):
            input_ = [input_]
        # Building headers.
        headers = {"Accept": "application/vnd.github+json"}
        token = github_token or self.github_token
        if token:
            headers["Authorization"] = f"Bearer {token}"
        # input_ is a series of repositories (can be only one repository).
        return extract_information(input_, headers=headers)

    @classmethod
    def _check_max_length(cls, max_length):
        """Raise ValueError unless the special tokens fit and max_length is below the limit."""
        if not cls._NUM_SPECIAL_TOKENS <= max_length < cls._MAX_LENGTH_LIMIT:
            raise ValueError(
                f"max_length must be in [{cls._NUM_SPECIAL_TOKENS}, {cls._MAX_LENGTH_LIMIT}), got {max_length}"
            )

    def encode(self, text, max_length):
        """
        Embed one text by mean-pooling the model's token embeddings.

        The token sequence is [CLS] <encoder-only> [SEP] text-tokens [SEP]
        (UniXcoder encoder-only layout), with the text tokens truncated so the
        whole sequence is at most max_length tokens.
        :param text: text.
        :param max_length: total token budget including the 4 special tokens.
        :return: the embedding of text; presumably shaped (1, hidden) — confirm against the model.
        :raises ValueError: if max_length is not in [4, 1024).
        """
        self._check_max_length(max_length)
        tokenizer = self.tokenizer
        tokens = (
            [tokenizer.cls_token, "<encoder-only>", tokenizer.sep_token]
            + tokenizer.tokenize(text)[: max_length - self._NUM_SPECIAL_TOKENS]
            + [tokenizer.sep_token]
        )
        source_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)]).to(self.device)
        # Assumes output[0] is the per-token hidden states, (batch=1, seq_len, hidden) — confirm.
        token_embeddings = self.model(source_ids)[0]
        return token_embeddings.mean(dim=1)

    def generate_embeddings(self, text_sets, max_length):
        """
        Embed every text in text_sets and stack the embeddings row-wise.
        :param text_sets: an iterable of texts (a set of str in this pipeline).
        :param max_length: token budget per text, see encode().
        :return: one row per text, or a single all-zero row when text_sets is empty.
        :raises ValueError: if max_length is not in [4, 1024).
        """
        self._check_max_length(max_length)
        if not text_sets:
            config = getattr(self.model, "config", None)
            hidden_size = getattr(config, "hidden_size", self._DEFAULT_HIDDEN_SIZE)
            return torch.zeros((1, hidden_size), device=self.device)
        return torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)

    @staticmethod
    def _report_progress(progress_bar, st_progress):
        """Mirror the tqdm bar's completed fraction onto an optional external progress widget."""
        if st_progress:
            total = progress_bar.total
            # With no texts at all the total is 0; report completion instead of dividing by zero.
            st_progress.progress(progress_bar.n / total if total else 1.0)

    def _forward(self, extracted_infos: List, max_length=512, st_progress=None) -> List:
        """
        The method for "forward" period.
        :param extracted_infos: the information of repositories.
        :param max_length: max token length per text, see encode().
        :param st_progress: optional object with a ``progress(fraction)`` method
            (presumably a Streamlit progress bar) updated alongside the tqdm bar.
        :return: one dict per repository with per-kind embeddings, their means,
            the concatenated repo-level mean embedding, and all of their shapes.
        """
        model_outputs = []
        # Total number of texts to embed across all repositories.
        num_texts = sum(
            len(repo_info[key]) for repo_info in extracted_infos for key, _ in self._TEXT_KINDS
        )
        with tqdm(total=num_texts) as progress_bar:
            # For each repository
            for repo_info in extracted_infos:
                repo_name = repo_info["name"]
                info = {
                    "name": repo_name,
                    "topics": repo_info["topics"],
                    "license": repo_info["license"],
                    "stars": repo_info["stars"],
                }
                progress_bar.set_description(f"Processing {repo_name}")
                for key, label in self._TEXT_KINDS:
                    tqdm.write(f"[*] Generating {label} embeddings for {repo_name}")
                    embeddings = self.generate_embeddings(repo_info[key], max_length)
                    info[f"{label}_embeddings"] = embeddings.cpu().numpy()
                    info[f"mean_{label}_embedding"] = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
                    progress_bar.update(len(repo_info[key]))
                    self._report_progress(progress_bar, st_progress)
                # Repo-level mean embedding: the four per-kind means stacked, then flattened to one row.
                info["mean_repo_embedding"] = np.concatenate(
                    [info[f"mean_{label}_embedding"] for _, label in self._TEXT_KINDS], axis=0
                ).reshape(1, -1)
                for _, label in self._TEXT_KINDS:
                    info[f"{label}_embeddings_shape"] = info[f"{label}_embeddings"].shape
                    info[f"mean_{label}_embedding_shape"] = info[f"mean_{label}_embedding"].shape
                info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
                model_outputs.append(info)
        return model_outputs

    def postprocess(self, model_outputs: List, **postprocess_parameters: Dict) -> List:
        """
        The method for "postprocess" period.
        :param model_outputs: the output of this pipeline.
        :param postprocess_parameters: the parameters of "postprocess" period (unused).
        :return: model_outputs, unchanged.
        """
        return model_outputs