Henry65 commited on
Commit
058bda7
·
1 Parent(s): a4a0af8

Update RepoPipeline.py

Browse files
Files changed (1) hide show
  1. RepoPipeline.py +2 -2
RepoPipeline.py CHANGED
@@ -309,12 +309,12 @@ class RepoPipeline(Pipeline):
309
  info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
310
 
311
  # Repo-level mean embedding
312
- info["mean_repo_embedding"] = torch.cat([
313
  info["mean_code_embedding"],
314
  info["mean_doc_embedding"],
315
  info["mean_requirement_embedding"],
316
  info["mean_readme_embedding"]
317
- ], dim=1)
318
 
319
  # TODO Remove test
320
  info["code_embeddings_shape"] = info["code_embeddings"].shape
 
309
  info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
310
 
311
  # Repo-level mean embedding
312
+ info["mean_repo_embedding"] = np.concatenate([
313
  info["mean_code_embedding"],
314
  info["mean_doc_embedding"],
315
  info["mean_requirement_embedding"],
316
  info["mean_readme_embedding"]
317
+ ], axis=1)
318
 
319
  # TODO Remove test
320
  info["code_embeddings_shape"] = info["code_embeddings"].shape