rogermt commited on
Commit
b9fb03c
·
verified ·
1 Parent(s): d979316

Add Lightning.ai multi-GPU notebook with Ollama + Qwen2.5-Coder (auto-shards across GPUs)

Browse files
Files changed (1) hide show
  1. notebooks/pemf_llm_lightning.ipynb +303 -0
notebooks/pemf_llm_lightning.ipynb ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# PEMF ARC-AGI — LLM Solver (Lightning.ai / Multi-GPU)\n",
8
+ "\n",
9
+ "Runs Ollama with auto multi-GPU sharding for local inference.\n",
10
+ "\n",
11
+ "| GPU Config | Model | VRAM | Quality |\n",
12
+ "|---|---|---|---|\n",
13
+ "| 2xA10G (48GB) | qwen2.5-coder:32b | ~20GB q4 | Best |\n",
14
+ "| 2xL4 (48GB) | qwen2.5-coder:32b | ~20GB q4 | Best |\n",
15
+ "| 2xT4 (32GB) | qwen2.5-coder:14b | ~10GB q4 | Good |\n",
16
+ "| 1xA10G (24GB) | qwen2.5-coder:14b | ~10GB | Good |\n",
17
+ "| 4xA10G (96GB) | qwen2.5-coder:32b fp16 | ~65GB | Best+fast |"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "# ============ CONFIGURATION ============\n",
27
+ "MODEL = 'qwen2.5-coder:32b'\n",
28
+ "# MODEL = 'qwen2.5-coder:14b' # fallback for less VRAM\n",
29
+ "N_CANDIDATES = 8"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import subprocess, os, time, json, re, glob\n",
39
+ "import numpy as np, urllib.request\n",
40
+ "from collections import Counter\n",
41
+ "\n",
42
+ "# Check GPUs\n",
43
+ "!nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader\n",
44
+ "gpu_count = len(subprocess.run(['nvidia-smi','-L'], capture_output=True, text=True).stdout.strip().split('\\n'))\n",
45
+ "print(f'GPUs: {gpu_count}')"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "# Install Ollama\n",
55
+ "try:\n",
56
+ " subprocess.run(['ollama','--version'], capture_output=True, check=True)\n",
57
+ " print('Ollama installed')\n",
58
+ "except: \n",
59
+ " !curl -fsSL https://ollama.com/install.sh | sh\n",
60
+ "\n",
61
+ "# Start server (auto-detects all GPUs)\n",
62
+ "subprocess.run(['pkill','-f','ollama'], capture_output=True)\n",
63
+ "time.sleep(2)\n",
64
+ "env = os.environ.copy()\n",
65
+ "env['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in range(gpu_count))\n",
66
+ "server = subprocess.Popen(['ollama','serve'],\n",
67
+ " stdout=open('/tmp/ollama.log','w'), stderr=subprocess.STDOUT, env=env)\n",
68
+ "time.sleep(5)\n",
69
+ "print(f'Server PID {server.pid}, GPUs: {env[\"CUDA_VISIBLE_DEVICES\"]}')\n",
70
+ "\n",
71
+ "# Pull model\n",
72
+ "print(f'Pulling {MODEL}...')\n",
73
+ "r = subprocess.run(['ollama','pull',MODEL], capture_output=True, text=True, timeout=3600)\n",
74
+ "if r.returncode != 0:\n",
75
+ " print(f'Failed, trying 14b...'); MODEL='qwen2.5-coder:14b'\n",
76
+ " subprocess.run(['ollama','pull',MODEL], capture_output=True, text=True, timeout=3600)\n",
77
+ "print(f'{MODEL} ready')\n",
78
+ "\n",
79
+ "# Test\n",
80
+ "r = subprocess.run(['ollama','run',MODEL,'Say hello'], capture_output=True, text=True, timeout=60)\n",
81
+ "print(f'Test: {r.stdout.strip()[:80]}')\n",
82
+ "!nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "# Download ARC data\n",
92
+ "if not os.path.exists('arc_data/training'):\n",
93
+ " !git clone --depth 1 https://github.com/fchollet/ARC-AGI.git /tmp/arc\n",
94
+ " os.makedirs('arc_data', exist_ok=True)\n",
95
+ " !cp -r /tmp/arc/data/training arc_data/training\n",
96
+ "print(f'Tasks: {len(glob.glob(\"arc_data/training/*.json\"))}')\n",
97
+ "\n",
98
+ "ALREADY_SOLVED = {\n",
99
+ " '007bbfb7','00d62c1b','0d3d703e','1190e5a7','1cf80156','1e0a9b12','1f85a75f',\n",
100
+ " '2013d3e2','22168020','22eb0ac0','239be575','23b5c85d','28bf18c6','2dee498d',\n",
101
+ " '3618c87e','3906de3d','3aa6fb7a','3af2c5a8','3c9b0459','42a50994','4347f46a',\n",
102
+ " '50cb2852','6150a2bd','62c24649','67385a82','67a3c6ac','67e8384a','68b16354',\n",
103
+ " '6d0aefbc','6f8cd79b','6fa7a44f','746b3537','74dd1130','7b7f7511','7e0986d6',\n",
104
+ " '7f4411dc','868de0fa','8be77c9e','8d5021e8','91714a58','9172f3a0','9565186b',\n",
105
+ " '9dfd6313','a416b8f3','a5313dff','a699fb00','aabf363d','aedd82e4','b1948b0a',\n",
106
+ " 'b6afb2da','ba97ae07','bb43febb','bda2d7a6','be94b721','c0f76784','c59eb873',\n",
107
+ " 'c8f0f002','c9e6f938','d10ecb37','d23f8c26','d511f180','d631b094','d90796e8',\n",
108
+ " 'd9fac9be','de1cd16c','ded97339','e26a3af2','eb5a1d5d','ed36ccf7','f76d97a5',\n",
109
+ "}\n",
110
+ "task_files = sorted(glob.glob('arc_data/training/*.json'))\n",
111
+ "unsolved = [(os.path.basename(f).replace('.json',''),f) for f in task_files\n",
112
+ " if os.path.basename(f).replace('.json','') not in ALREADY_SOLVED]\n",
113
+ "print(f'Symbolic: {len(ALREADY_SOLVED)}, LLM to try: {len(unsolved)}')"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "# LLM Engine\n",
123
+ "def call_ollama(prompt, model, temperature=0.7):\n",
124
+ " payload = {'model':model,'prompt':prompt,'stream':False,\n",
125
+ " 'options':{'temperature':temperature,'num_predict':2048}}\n",
126
+ " req = urllib.request.Request('http://localhost:11434/api/generate',\n",
127
+ " data=json.dumps(payload).encode(), headers={'Content-Type':'application/json'}, method='POST')\n",
128
+ " try:\n",
129
+ " with urllib.request.urlopen(req, timeout=180) as resp:\n",
130
+ " return json.loads(resp.read().decode()).get('response','')\n",
131
+ " except Exception as e: return f'ERROR: {e}'\n",
132
+ "\n",
133
+ "def build_prompt(task):\n",
134
+ " pairs = task.get('train',[])\n",
135
+ " ex = '\\n'.join(f\"Example {i+1}:\\n Input: {json.dumps(p['input'])}\\n Output: {json.dumps(p['output'])}\"\n",
136
+ " for i,p in enumerate(pairs))\n",
137
+ " inps = [np.array(p['input']) for p in pairs]\n",
138
+ " outs = [np.array(p['output']) for p in pairs]\n",
139
+ " same = all(i.shape==o.shape for i,o in zip(inps,outs))\n",
140
+ " ic = sorted(set(c for i in inps for c in np.unique(i).tolist()))\n",
141
+ " oc = sorted(set(c for o in outs for c in np.unique(o).tolist()))\n",
142
+ " a = f\" Same shape: {same}\\n Colors in: {ic}, out: {oc}\\n\"\n",
143
+ " if not same: a += f\" Shape: {inps[0].shape} -> {outs[0].shape}\\n\"\n",
144
+ " return f\"\"\"Solve this ARC-AGI puzzle. Write ONLY a Python function, no explanations.\n",
145
+ "\n",
146
+ "{ex}\n",
147
+ "\n",
148
+ "Analysis:\n",
149
+ "{a}\n",
150
+ "```python\n",
151
+ "import numpy as np\n",
152
+ "from collections import Counter, deque\n",
153
+ "\n",
154
+ "def transform(grid: list[list[int]]) -> list[list[int]]:\n",
155
+ " grid = np.array(grid)\n",
156
+ "\"\"\"\n",
157
+ "\n",
158
+ "def extract_code(resp):\n",
159
+ " for pat in [r'```python\\s*(.*?)```', r'```\\s*(.*?)```']:\n",
160
+ " for m in re.findall(pat, resp, re.DOTALL):\n",
161
+ " if 'def transform' in m: return m.strip()\n",
162
+ " idx = resp.find('def transform')\n",
163
+ " if idx >= 0:\n",
164
+ " before = resp[:idx]\n",
165
+ " s = max(before.rfind('import '), before.rfind('from '))\n",
166
+ " code = resp[s if s>=0 else idx:]\n",
167
+ " end = code.find('```')\n",
168
+ " if end>0: code=code[:end]\n",
169
+ " return code.strip()\n",
170
+ " s = resp.strip()\n",
171
+ " if s.startswith(('import','def transform','from')): return s\n",
172
+ " return None\n",
173
+ "\n",
174
+ "def verify(code, pairs):\n",
175
+ " ns = {'np':np,'numpy':np,'Counter':Counter,'deque':__import__('collections').deque}\n",
176
+ " try:\n",
177
+ " import scipy.ndimage; ns['scipy']=__import__('scipy')\n",
178
+ " except: pass\n",
179
+ " try: exec(code, ns)\n",
180
+ " except: return False\n",
181
+ " if 'transform' not in ns: return False\n",
182
+ " fn = ns['transform']\n",
183
+ " for p in pairs:\n",
184
+ " try:\n",
185
+ " r = np.array(fn([row[:] for row in p['input']]), dtype=int)\n",
186
+ " e = np.array(p['output'], dtype=int)\n",
187
+ " if r.shape!=e.shape or not np.array_equal(r,e): return False\n",
188
+ " except: return False\n",
189
+ " return True\n",
190
+ "\n",
191
+ "def apply_prog(code, inp):\n",
192
+ " ns = {'np':np,'numpy':np,'Counter':Counter,'deque':__import__('collections').deque}\n",
193
+ " try:\n",
194
+ " import scipy.ndimage; ns['scipy']=__import__('scipy')\n",
195
+ " except: pass\n",
196
+ " try:\n",
197
+ " exec(code, ns)\n",
198
+ " r = ns['transform']([row[:] for row in inp])\n",
199
+ " if r is not None: return np.array(r,dtype=int).tolist()\n",
200
+ " except: pass\n",
201
+ " return None\n",
202
+ "\n",
203
+ "print('Engine ready')"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "# Quick test\n",
213
+ "with open(f'arc_data/training/{unsolved[0][0]}.json') as f: t=json.load(f)\n",
214
+ "print(f'Test on {unsolved[0][0]}...')\n",
215
+ "s=time.time(); r=call_ollama(build_prompt(t),MODEL,0.1); e=time.time()-s\n",
216
+ "code=extract_code(r)\n",
217
+ "if code: print(f'{e:.1f}s, {len(code)}ch, verified: {\"Y\" if verify(code,t[\"train\"]) else \"N\"}')\n",
218
+ "else: print(f'{e:.1f}s, no code')\n",
219
+ "est = e*N_CANDIDATES*len(unsolved)/3600\n",
220
+ "print(f'Est total: {est:.1f}h for {len(unsolved)} tasks x {N_CANDIDATES} candidates')"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {},
227
+ "outputs": [],
228
+ "source": [
229
+ "# === MAIN LOOP (crash-safe, resumable) ===\n",
230
+ "results = {}\n",
231
+ "solved = 0\n",
232
+ "total_time = 0\n",
233
+ "\n",
234
+ "if os.path.exists('llm_results.json'):\n",
235
+ " with open('llm_results.json') as f: prev=json.load(f)\n",
236
+ " results=prev.get('results',{})\n",
237
+ " solved=sum(1 for r in results.values() if r['status']=='solved')\n",
238
+ " total_time=prev.get('total_time_s',0)\n",
239
+ " print(f'Resuming: {solved} LLM-solved, {len(results)} attempted')\n",
240
+ "\n",
241
+ "for idx,(tid,tf) in enumerate(unsolved):\n",
242
+ " if tid in results: continue\n",
243
+ " with open(tf) as f: task=json.load(f)\n",
244
+ " print(f'[{idx+1:3d}/{len(unsolved)}] {tid}:',end=' ',flush=True)\n",
245
+ " s=time.time(); prompt=build_prompt(task); ok=False\n",
246
+ " for i in range(N_CANDIDATES):\n",
247
+ " temp=0.1 if i==0 else min(0.4+0.15*i,1.2)\n",
248
+ " resp=call_ollama(prompt,MODEL,temp)\n",
249
+ " if resp.startswith('ERROR:'): continue\n",
250
+ " code=extract_code(resp)\n",
251
+ " if code and verify(code,task['train']):\n",
252
+ " e=time.time()-s; total_time+=e; solved+=1\n",
253
+ " to=[apply_prog(code,t['input']) for t in task.get('test',[])]\n",
254
+ " results[tid]={'status':'solved','rule':f'llm_c{i+1}','code':code,\n",
255
+ " 'test_outputs':to,'time_s':round(e,2)}\n",
256
+ " print(f'✅ c{i+1} ({e:.1f}s) [{len(ALREADY_SOLVED)+solved}/{len(task_files)}]')\n",
257
+ " ok=True; break\n",
258
+ " if not ok:\n",
259
+ " e=time.time()-s; total_time+=e\n",
260
+ " results[tid]={'status':'failed','time_s':round(e,2)}\n",
261
+ " print(f'❌ ({e:.1f}s)')\n",
262
+ " if (idx+1)%5==0 or ok:\n",
263
+ " with open('llm_results.json','w') as f:\n",
264
+ " json.dump({'model':MODEL,'n_candidates':N_CANDIDATES,'llm_solved':solved,\n",
265
+ " 'attempted':len(results),'symbolic_solved':len(ALREADY_SOLVED),\n",
266
+ " 'total_solved':len(ALREADY_SOLVED)+solved,'total_tasks':len(task_files),\n",
267
+ " 'solve_rate':round(100*(len(ALREADY_SOLVED)+solved)/len(task_files),2),\n",
268
+ " 'total_time_s':round(total_time,1),'results':results},f,indent=2)"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": null,
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": [
277
+ "# Final save + summary\n",
278
+ "with open('llm_results.json','w') as f:\n",
279
+ " json.dump({'model':MODEL,'n_candidates':N_CANDIDATES,'llm_solved':solved,\n",
280
+ " 'attempted':len(results),'symbolic_solved':len(ALREADY_SOLVED),\n",
281
+ " 'total_solved':len(ALREADY_SOLVED)+solved,'total_tasks':len(task_files),\n",
282
+ " 'solve_rate':round(100*(len(ALREADY_SOLVED)+solved)/len(task_files),2),\n",
283
+ " 'total_time_s':round(total_time,1),'results':results},f,indent=2)\n",
284
+ "\n",
285
+ "print(f'\\n{\"=\"*60}')\n",
286
+ "print(f'LLM solved: {solved}')\n",
287
+ "print(f'Symbolic: {len(ALREADY_SOLVED)}')\n",
288
+ "print(f'TOTAL: {len(ALREADY_SOLVED)+solved}/{len(task_files)} ({100*(len(ALREADY_SOLVED)+solved)/len(task_files):.1f}%)')\n",
289
+ "print(f'Time: {total_time/3600:.1f}h')\n",
290
+ "print(f'\\nDownload llm_results.json, then run:')\n",
291
+ "print(f' python scripts/merge_results.py arc_results/summary_v4.json llm_results.json')\n",
292
+ "\n",
293
+ "subprocess.run(['pkill','-f','ollama'], capture_output=True)"
294
+ ]
295
+ }
296
+ ],
297
+ "metadata": {
298
+ "kernelspec": {"display_name":"Python 3","language":"python","name":"python3"},
299
+ "language_info": {"name":"python","version":"3.10.0"}
300
+ },
301
+ "nbformat": 4,
302
+ "nbformat_minor": 4
303
+ }