sugiv committed on
Commit
e2ca773
·
1 Parent(s): 53a7e0c

Adding a simple monkey search for Leetcode - Darn LeetMonkey

Browse files
Files changed (2) hide show
  1. app.py +94 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pinecone import Pinecone, ServerlessSpec
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import torch
5
+ from pinecone_text.sparse import SpladeEncoder
6
+ from sentence_transformers import SentenceTransformer
7
+ import transformers
8
+ transformers.logging.set_verbosity_error()
9
+
10
+ # Initialize Pinecone
11
+ PINECONE_API_KEY = "your_pinecone_api_key"
12
+ pc = Pinecone(api_key=PINECONE_API_KEY)
13
+
14
+ index_name = "leetmonkey-sparse-dense"
15
+ index = pc.Index(index_name)
16
+
17
+ # Initialize models
18
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
+ splade = SpladeEncoder(device=device)
20
+ dense_model = SentenceTransformer('sentence-transformers/all-Mpnet-base-v2', device=device)
21
+
22
+ # Load the quantized Llama 2 model and tokenizer
23
+ model_name = "TheBloke/Llama-2-7B-Chat-GPTQ"
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
26
+
27
def search_problems(query, top_k=5):
    """Run a hybrid (dense + sparse) Pinecone search for LeetCode problems.

    Args:
        query: free-text user query.
        top_k: number of matches to return.

    Returns:
        The ``matches`` list from the Pinecone query response.
    """
    dense_vec = dense_model.encode([query])[0].tolist()
    sparse_vec = splade.encode_documents([query])[0]

    hybrid_query = {
        'vector': dense_vec,
        'sparse_vector': {
            'indices': sparse_vec['indices'],
            'values': sparse_vec['values'],
        },
        'top_k': top_k,
        'include_metadata': True,
        'namespace': 'leetcode-problems',
    }
    return index.query(**hybrid_query)['matches']
43
+
44
def generate_few_shot_prompt(search_results):
    """Format retrieved matches into a few-shot example prompt.

    Args:
        search_results: Pinecone match objects; each must carry a
            ``metadata`` mapping with ``title``, ``topicTags`` (list of
            strings) and ``difficulty``.

    Returns:
        A single prompt string: a header followed by one title/topics/
        difficulty stanza per match.
    """
    parts = ["Here are some example LeetCode problems:\n\n"]
    for match in search_results:
        meta = match['metadata']
        parts.append(f"Title: {meta['title']}\n")
        parts.append(f"Topics: {', '.join(meta['topicTags'])}\n")
        parts.append(f"Difficulty: {meta['difficulty']}\n\n")
    return "".join(parts)
52
+
53
def generate_response(user_query, top_k=5):
    """Retrieve similar problems and ask the LLM for recommendations.

    Args:
        user_query: the user's free-text question.
        top_k: number of retrieved examples used as few-shot context.

    Returns:
        The model's recommendation text — everything generated after the
        trailing "Recommendations:" marker of the prompt.
    """
    search_results = search_problems(user_query, top_k)
    few_shot_prompt = generate_few_shot_prompt(search_results)

    system_prompt = """You are an AI assistant specialized in providing information about LeetCode problems.
    Your task is to recommend relevant problems based on the user's query and the provided examples.
    Focus on problem titles, difficulty levels, topic tags, and companies that have asked these problems.
    Do not provide specific problem solutions or content."""

    user_prompt = f"Based on the following query, recommend relevant LeetCode problems:\n{user_query}"
    full_prompt = f"{system_prompt}\n\n{few_shot_prompt}\n{user_prompt}\n\nRecommendations:"

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(model.device)
    # No padding is applied to this single sequence, so the mask is all ones.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=250,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The decoded output contains the prompt followed by the continuation.
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    # BUG FIX: split at the FIRST "Recommendations:" only (maxsplit=1).
    # The previous unbounded split truncated the answer whenever the model
    # repeated the "Recommendations:" marker inside its own output.
    recommendations = response.split("Recommendations:", 1)[1].strip()
    return recommendations
83
+
84
# Create the Gradio interface.
# Uses the top-level component API (gr.Textbox): the gr.inputs.* namespace
# is deprecated in Gradio 3.x and removed entirely in 4.x.
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query about LeetCode problems..."),
    outputs="text",
    title="LeetCode Problem Assistant",
    description="Ask about LeetCode problems and get structured responses based on titles, topics, and difficulty levels."
)

# Launch the app only when run as a script, so importing this module
# (e.g. for testing) does not start the server.
if __name__ == "__main__":
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.50.2
2
+ pinecone-client==3.0.0  # app.py uses the v3 `Pinecone` class API; 2.2.4 only provides pinecone.init()
3
+ transformers==4.36.2
4
+ torch==2.1.2
5
+ sentence-transformers==2.2.2
6
+ pinecone-text==0.7.0
7
+ numpy==1.26.2
8
+ pandas==2.1.4
9
+ networkx==3.2.1
10
+ matplotlib==3.8.2
11
+ torch-geometric==2.4.0