tobil committed
Commit e97d105 · verified · 1 Parent(s): 98a3f77

Upload evaluate_model.py with huggingface_hub
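For context, an upload like this is normally a single huggingface_hub call. A minimal sketch, assuming the target is the model repo the script evaluates (tobil/qmd-query-expansion-0.6B) and that a Hub token is already configured; neither detail is read from this commit:

from huggingface_hub import HfApi

api = HfApi()  # picks up a cached login or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="evaluate_model.py",       # local file to upload
    path_in_repo="evaluate_model.py",          # destination path inside the repo
    repo_id="tobil/qmd-query-expansion-0.6B",  # assumed target repo
    repo_type="model",
    commit_message="Upload evaluate_model.py with huggingface_hub",
)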

Files changed (1)
  1. evaluate_model.py +213 -0
evaluate_model.py ADDED
@@ -0,0 +1,213 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "transformers>=4.45.0",
+ #     "peft>=0.7.0",
+ #     "torch",
+ #     "huggingface_hub",
+ # ]
+ # ///
+ """
+ Evaluate QMD query expansion model quality.
+
+ Generates expansions for test queries and outputs results for review.
+ """
+
+ import json
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ # Test queries covering different QMD use cases
+ TEST_QUERIES = [
+     # Technical documentation
+     "how to configure authentication",
+     "typescript async await",
+     "docker compose networking",
+     "git rebase vs merge",
+     "react useEffect cleanup",
+
+     # Short/ambiguous queries
+     "auth",
+     "config",
+     "setup",
+     "api",
+
+     # Personal notes / journals style
+     "meeting notes project kickoff",
+     "ideas for new feature",
+     "todo list app architecture",
+
+     # Research / learning
+     "what is dependency injection",
+     "difference between sql and nosql",
+     "kubernetes vs docker swarm",
+
+     # Error/debugging
+     "connection timeout error",
+     "memory leak debugging",
+     "cors error fix",
+
+     # Complex queries
+     "how to implement caching with redis in nodejs",
+     "best practices for api rate limiting",
+     "setting up ci cd pipeline with github actions",
+ ]
+
+ PROMPT_TEMPLATE = """You are a search query optimization expert. Transform the query into retrieval-optimized outputs.
+
+ Query: {query}
+
+ Output format:
+ lex: {{keyword variation}}
+ vec: {{semantic reformulation}}
+ hyde: {{hypothetical document passage}}
+
+ Output:"""
+
+
+ def load_model(model_name: str, base_model: str = "Qwen/Qwen3-0.6B"):
+     """Load the finetuned model."""
+     print(f"Loading tokenizer from {base_model}...")
+     tokenizer = AutoTokenizer.from_pretrained(base_model)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     print(f"Loading base model...")
+     base = AutoModelForCausalLM.from_pretrained(
+         base_model,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+
+     print(f"Loading adapter from {model_name}...")
+     model = PeftModel.from_pretrained(base, model_name)
+     model.eval()
+
+     return model, tokenizer
+
+
+ def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
+     """Generate query expansion."""
+     prompt = PROMPT_TEMPLATE.format(query=query)
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=0.7,
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+     # Decode and extract just the generated part
+     full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Remove the prompt to get just the expansion
+     if "Output:" in full_output:
+         expansion = full_output.split("Output:")[-1].strip()
+     else:
+         expansion = full_output[len(prompt):].strip()
+
+     return expansion
+
+
+ def evaluate_expansion(query: str, expansion: str) -> dict:
+     """Basic automatic evaluation metrics."""
+     lines = expansion.strip().split("\n")
+
+     has_lex = any(l.strip().startswith("lex:") for l in lines)
+     has_vec = any(l.strip().startswith("vec:") for l in lines)
+     has_hyde = any(l.strip().startswith("hyde:") for l in lines)
+
+     # Count valid lines
+     valid_lines = sum(1 for l in lines if l.strip().startswith(("lex:", "vec:", "hyde:")))
+
+     # Check for repetition
+     contents = []
+     for l in lines:
+         if ":" in l:
+             contents.append(l.split(":", 1)[1].strip().lower())
+     unique_contents = len(set(contents))
+
+     return {
+         "has_lex": has_lex,
+         "has_vec": has_vec,
+         "has_hyde": has_hyde,
+         "valid_lines": valid_lines,
+         "total_lines": len(lines),
+         "unique_contents": unique_contents,
+         "format_score": (has_lex + has_vec + has_hyde) / 3,
+     }
+
+
+ def main():
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", default="tobil/qmd-query-expansion-0.6B",
+                         help="Model to evaluate")
+     parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
+                         help="Base model")
+     parser.add_argument("--output", default="evaluation_results.json",
+                         help="Output file for results")
+     parser.add_argument("--queries", type=str, help="Custom queries file (one per line)")
+     args = parser.parse_args()
+
+     # Load custom queries if provided
+     queries = TEST_QUERIES
+     if args.queries:
+         with open(args.queries) as f:
+             queries = [l.strip() for l in f if l.strip()]
+
+     # Load model
+     model, tokenizer = load_model(args.model, args.base_model)
+
+     # Run evaluation
+     results = []
+     print(f"\n{'='*70}")
+     print("EVALUATION RESULTS")
+     print(f"{'='*70}\n")
+
+     for i, query in enumerate(queries, 1):
+         print(f"[{i}/{len(queries)}] Query: {query}")
+         print("-" * 50)
+
+         expansion = generate_expansion(model, tokenizer, query)
+         metrics = evaluate_expansion(query, expansion)
+
+         print(expansion)
+         print(f"\n Format: {'✓' if metrics['format_score'] == 1.0 else '⚠'} "
+               f"(lex:{metrics['has_lex']}, vec:{metrics['has_vec']}, hyde:{metrics['has_hyde']})")
+         print(f" Lines: {metrics['valid_lines']}/{metrics['total_lines']} valid, "
+               f"{metrics['unique_contents']} unique")
+         print()
+
+         results.append({
+             "query": query,
+             "expansion": expansion,
+             "metrics": metrics,
+         })
+
+     # Summary
+     print(f"\n{'='*70}")
+     print("SUMMARY")
+     print(f"{'='*70}")
+
+     avg_format = sum(r["metrics"]["format_score"] for r in results) / len(results)
+     full_format = sum(1 for r in results if r["metrics"]["format_score"] == 1.0)
+
+     print(f" Total queries: {len(results)}")
+     print(f" Average format score: {avg_format:.2%}")
+     print(f" Full format compliance: {full_format}/{len(results)} ({full_format/len(results):.0%})")
+
+     # Save results
+     with open(args.output, "w") as f:
+         json.dump(results, f, indent=2)
+     print(f"\n Results saved to: {args.output}")
+
+
+ if __name__ == "__main__":
+     main()
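Because the file carries PEP 723 inline script metadata (the "# /// script" block), it can be run directly with a metadata-aware runner such as uv, which resolves the listed dependencies on the fly. The helpers can also be reused outside the CLI; a minimal sketch, assuming the file is importable as evaluate_model and the default adapter repo is reachable:

# Hypothetical interactive use of the functions defined above.
from evaluate_model import load_model, generate_expansion, evaluate_expansion

model, tokenizer = load_model("tobil/qmd-query-expansion-0.6B")  # LoRA adapter on Qwen3-0.6B
expansion = generate_expansion(model, tokenizer, "how to configure authentication")
print(expansion)  # expected to contain lex:/vec:/hyde: lines
print(evaluate_expansion("how to configure authentication", expansion))  # format metrics dict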