import sys
import time

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Hugging Face repo ID of the CodeBERT-based code embedding model
MODEL_NAME = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
| | print(f"Downloading model: {MODEL_NAME}...") |
| | MAX_RETRIES = 3 |
| | for attempt in range(MAX_RETRIES): |
| | try: |
| | print(f"Attempt {attempt+1}/{MAX_RETRIES}...") |
| | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
| | model = AutoModel.from_pretrained(MODEL_NAME) |
| | print("Model loaded successfully!") |
| | break |
| | except Exception as e: |
| | print(f"Attempt {attempt+1} failed: {e}") |
| | if attempt == MAX_RETRIES - 1: |
| | print("Failed to load model after multiple attempts.") |
| | print("Tip: Check internet connection or repo visibility.") |
| | exit(1) |
| | time.sleep(5) |
| |

# A natural-language query, a relevant code snippet, and an unrelated snippet
query = "How to create a state graph in langgraph?"
code = """
from langgraph.graph import StateGraph

def create_workflow():
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    return workflow.compile()
"""
irrelevant_code = "def fast_inverse_sqrt(number): return number ** -0.5"

def embed(text):
    """Return an L2-normalized, mean-pooled embedding for a single text."""
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean-pool the token embeddings, then L2-normalize so dot product == cosine similarity
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return F.normalize(embeddings, p=2, dim=1)
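
# Not part of the original example: a sketch of a batched variant of embed().
# It assumes the same `tokenizer` / `model` objects as above and uses the
# attention mask so that padding tokens do not skew the mean pooling.
def embed_batch(texts):
    inputs = tokenizer(
        texts, return_tensors="pt", max_length=512, truncation=True, padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)

    mask = inputs["attention_mask"].unsqueeze(-1).float()   # [batch, seq, 1]
    summed = (outputs.last_hidden_state * mask).sum(dim=1)  # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1e-9)                # token count per example
    embeddings = summed / counts                            # masked mean pooling
    return F.normalize(embeddings, p=2, dim=1)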

print("\nRunning Inference Test...")
query_emb = embed(query)
code_emb = embed(code)
irrelevant_emb = embed(irrelevant_code)

sim_positive = F.cosine_similarity(query_emb, code_emb).item()
sim_negative = F.cosine_similarity(query_emb, irrelevant_emb).item()

print(f"Query: '{query}'")
print(f"Similarity to Relevant Code: {sim_positive:.4f} (Should be high)")
print(f"Similarity to Irrelevant Code: {sim_negative:.4f} (Should be low)")

if sim_positive > sim_negative:
    print("\nSUCCESS: Model correctly ranks relevant code higher.")
else:
    print("\n⚠️ WARNING: Model performance might be poor.")
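
# Not part of the original example: a sketch of how the embeddings could be used
# for simple top-k code retrieval. The extra snippet below is a hypothetical placeholder.
candidates = [
    code,
    irrelevant_code,
    "workflow.add_edge('agent', END)",  # hypothetical extra snippet
]
candidate_embs = torch.cat([embed(c) for c in candidates], dim=0)  # [num_candidates, dim]

# Embeddings are L2-normalized, so the dot product equals cosine similarity.
scores = (query_emb @ candidate_embs.T).squeeze(0)
top = torch.topk(scores, k=min(3, len(candidates)))

print("\nTop matches for the query:")
for score, idx in zip(top.values.tolist(), top.indices.tolist()):
    snippet = candidates[idx].strip().splitlines()[0]
    print(f"  {score:.4f}  {snippet}")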