AIencoder committed
Commit 58a2c61 · verified · 1 Parent(s): 2cea58f

Update forgekit/ai_advisor.py

Files changed (1):
  1. forgekit/ai_advisor.py +80 -135
forgekit/ai_advisor.py CHANGED
@@ -1,224 +1,169 @@
-"""AI-powered merge advisor using HuggingFace Inference API."""
+"""AI-powered merge advisor using Groq API (free, fast inference)."""
 
-import json
+import os
 import requests
 from typing import Optional
 
-HF_INFERENCE_URL = "https://api-inference.huggingface.co/models"
-DEFAULT_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
+DEFAULT_MODEL = "llama-3.3-70b-versatile"
 
 
-def _query_llm(
+def _query_groq(
     prompt: str,
     system: str = "",
     model: str = DEFAULT_MODEL,
-    token: Optional[str] = None,
-    max_tokens: int = 800,
+    api_key: Optional[str] = None,
+    max_tokens: int = 1024,
 ) -> str:
-    """Query an LLM via HF Inference API.
+    """Query Groq's OpenAI-compatible API.
 
     Args:
         prompt: User message
         system: System prompt
-        model: HF model ID for inference
-        token: HF API token (recommended for higher rate limits)
+        model: Groq model ID
+        api_key: Groq API key (free at console.groq.com)
         max_tokens: Max response length
 
     Returns:
         Generated text response
     """
-    headers = {"Content-Type": "application/json"}
-    if token:
-        headers["Authorization"] = f"Bearer {token}"
+    key = (api_key or "").strip() or os.environ.get("GROQ_API_KEY", "")
+    if not key:
+        return (
+            "**Groq API Key required** — the AI Advisor uses Groq for fast, free inference.\n\n"
+            "1. Go to [console.groq.com](https://console.groq.com) and sign up (free, no credit card)\n"
+            "2. Create an API key\n"
+            "3. Paste it in the field above\n\n"
+            "Groq gives you thousands of free requests per day with Llama 3.3 70B!"
+        )
 
-    # Format as chat messages
     messages = []
     if system:
         messages.append({"role": "system", "content": system})
     messages.append({"role": "user", "content": prompt})
 
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
     payload = {
-        "inputs": _format_chat(messages, model),
-        "parameters": {
-            "max_new_tokens": max_tokens,
-            "temperature": 0.7,
-            "do_sample": True,
-            "return_full_text": False,
-        },
+        "model": model,
+        "messages": messages,
+        "max_tokens": max_tokens,
+        "temperature": 0.7,
     }
 
     try:
-        resp = requests.post(
-            f"{HF_INFERENCE_URL}/{model}",
-            headers=headers,
-            json=payload,
-            timeout=60,
-        )
-
-        if resp.status_code == 503:
-            # Model loading
-            return "⏳ The AI model is loading (this can take 1-2 minutes on first use). Please try again shortly."
+        resp = requests.post(GROQ_API_URL, headers=headers, json=payload, timeout=30)
 
         if resp.status_code == 429:
-            return "⚠️ Rate limited — please wait a moment and try again, or add your HF token for higher limits."
-
+            return "Rate limited — Groq free tier allows ~30 requests/min. Wait a moment and try again."
+        if resp.status_code == 401:
+            return "Invalid Groq API key. Get a free one at [console.groq.com](https://console.groq.com)."
         if resp.status_code != 200:
-            return f"⚠️ AI service returned status {resp.status_code}. Try again or add an HF token."
+            return f"Groq API error (status {resp.status_code}). Try again."
 
         data = resp.json()
-        if isinstance(data, list) and len(data) > 0:
-            text = data[0].get("generated_text", "")
-            # Clean up any leftover template tokens
-            for tag in ["</s>", "<|im_end|>", "<|eot_id|>", "[/INST]"]:
-                text = text.replace(tag, "")
-            return text.strip()
-
-        return "⚠️ No response generated. The model may be overloaded — try again."
+        text = data["choices"][0]["message"]["content"]
+        return text.strip()
 
     except requests.exceptions.Timeout:
-        return "⚠️ Request timed out. The model may be loading — try again in a minute."
+        return "Request timed out — try again."
     except Exception as e:
-        return f"⚠️ Error: {str(e)}"
-
-
-def _format_chat(messages: list[dict], model: str) -> str:
-    """Format messages into the model's expected chat template."""
-    # Mistral Instruct format
-    if "mistral" in model.lower() or "mixtral" in model.lower():
-        parts = []
-        for msg in messages:
-            if msg["role"] == "system":
-                parts.append(f"[INST] {msg['content']}\n")
-            elif msg["role"] == "user":
-                if parts:
-                    parts.append(f"{msg['content']} [/INST]")
-                else:
-                    parts.append(f"[INST] {msg['content']} [/INST]")
-        return "".join(parts)
-
-    # Generic ChatML fallback
-    parts = []
-    for msg in messages:
-        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>")
-    parts.append("<|im_start|>assistant\n")
-    return "\n".join(parts)
+        return f"Error: {str(e)}"
 
 
-# ===== AI FEATURES =====
+# ===== SYSTEM PROMPT =====
 
-ADVISOR_SYSTEM = """You are ForgeKit AI, an expert assistant for merging large language models. You have deep knowledge of mergekit, model architectures, merge methods (DARE-TIES, TIES, SLERP, Linear, Task Arithmetic, Passthrough), and best practices for creating high-quality merged models.
+ADVISOR_SYSTEM = """You are ForgeKit AI, an expert assistant for merging large language models using mergekit. You have deep knowledge of:
 
-Be concise, practical, and specific. Give actionable recommendations with concrete numbers (weights, densities). Format your response with clear sections using markdown."""
+- Model architectures (LLaMA, Qwen, Mistral, Gemma, Phi)
+- Merge methods: DARE-TIES, TIES, SLERP, Linear, Task Arithmetic, Passthrough (Frankenmerge)
+- Optimal weight/density configurations for different use cases
+- Common pitfalls and best practices
+
+Be concise, practical, and specific. Always give concrete numbers for weights and densities.
+Format responses with markdown headers and bullet points for readability."""
+
+
+# ===== AI FEATURES =====
 
 
 def merge_advisor(
     models_text: str,
     goal: str = "",
-    token: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> str:
-    """AI recommends the best merge method, weights, and configuration.
-
-    Args:
-        models_text: Newline-separated model IDs
-        goal: What the user wants the merged model to do
-        token: HF API token
-
-    Returns:
-        AI recommendation as markdown
-    """
+    """AI recommends the best merge method, weights, and configuration."""
     models = [m.strip() for m in models_text.strip().split("\n") if m.strip()]
     if len(models) < 2:
-        return "⚠️ Add at least 2 models to get a recommendation."
+        return "Add at least 2 models (one per line) to get a recommendation."
 
     models_str = "\n".join(f"- {m}" for m in models)
-    goal_str = f"\n\nUser's goal: {goal}" if goal.strip() else ""
+    goal_str = f"\n\nThe user's goal: {goal}" if goal.strip() else ""
 
     prompt = f"""I want to merge these models:
 {models_str}
 {goal_str}
 
-Recommend:
-1. **Best merge method** and why (DARE-TIES, SLERP, Linear, TIES, Task Arithmetic, or Passthrough)
-2. **Optimal weights** for each model (with reasoning)
-3. **Density values** if applicable
-4. **Which model to use as base** and why
+Give me a specific recommendation:
+1. **Best merge method** and why
+2. **Exact weights** for each model
+3. **Density values** (if applicable)
+4. **Which model as base** and why
 5. **Which tokenizer** to keep
-6. **Any warnings** or tips specific to these models
-
-Be specific with numbers and keep it practical."""
+6. **Warnings or tips** for these specific models
+7. **The complete YAML config** ready for mergekit"""
 
-    return _query_llm(prompt, system=ADVISOR_SYSTEM, token=token)
+    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
 
 
 def model_describer(
     models_text: str,
     method: str = "",
     weights_text: str = "",
-    token: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> str:
-    """AI explains what the merged model will be good at.
-
-    Args:
-        models_text: Newline-separated model IDs
-        method: Merge method being used
-        weights_text: Comma-separated weights
-        token: HF API token
-
-    Returns:
-        AI description of expected capabilities
-    """
+    """AI predicts what the merged model will be good at."""
     models = [m.strip() for m in models_text.strip().split("\n") if m.strip()]
     if not models:
-        return "⚠️ Add models first."
+        return "Add models first."
 
     models_str = "\n".join(f"- {m}" for m in models)
-    method_str = f" using {method}" if method else ""
+    method_str = f" using **{method}**" if method else ""
    weights_str = f"\nWeights: {weights_text}" if weights_text.strip() else ""
 
     prompt = f"""I'm merging these models{method_str}:
 {models_str}{weights_str}
 
-Based on what each source model is known for, describe:
-1. **What the merged model will excel at** (specific tasks/benchmarks)
-2. **What it might struggle with** compared to the source models
+Predict:
+1. **What it will excel at** (specific tasks and benchmarks)
+2. **What it might lose** compared to individual source models
 3. **Ideal use cases** for this merge
-4. **Expected quality** compared to each individual model
-5. **A creative name suggestion** for this merge
-
-Keep it concise and practical."""
+4. **Quality estimate** vs each source model
+5. **A creative name suggestion** for this merged model"""
 
-    return _query_llm(prompt, system=ADVISOR_SYSTEM, token=token)
+    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
 
 
 def config_explainer(
     yaml_config: str,
-    token: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> str:
-    """AI explains a YAML merge config in plain English.
-
-    Args:
-        yaml_config: The YAML configuration string
-        token: HF API token
-
-    Returns:
-        Plain English explanation
-    """
+    """AI explains a YAML merge config in plain English."""
     if not yaml_config.strip() or yaml_config.startswith("# Add"):
-        return "⚠️ Generate a YAML config first."
+        return "Generate or paste a YAML config first."
 
-    prompt = f"""Explain this mergekit YAML configuration in plain English. Break it down so someone new to model merging can understand exactly what will happen:
+    prompt = f"""Explain this mergekit config in plain English for a beginner:
 
 ```yaml
 {yaml_config}
 ```
 
-Explain:
-1. **What this config does** in simple terms
-2. **Why these specific settings** were chosen (method, weights, density)
-3. **What the output model will be like**
-4. **Any potential issues** to watch out for
-5. **Estimated resource requirements** (RAM, time)
-
-Be clear and beginner-friendly."""
+Cover:
+1. **What this does** in simple terms
+2. **Why these settings** (explain each parameter)
+3. **What the output will be like**
+4. **Potential issues** to watch for
+5. **Resource requirements** (RAM, time, Colab tier)"""
 
-    return _query_llm(prompt, system=ADVISOR_SYSTEM, token=token)
+    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
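
For anyone picking up this change, a minimal smoke test of the new Groq-backed path might look like the sketch below. It assumes `forgekit` is on `PYTHONPATH` so the module imports as `forgekit.ai_advisor`, that a key from console.groq.com is exported as `GROQ_API_KEY` (the fallback `_query_groq` reads when no key is passed in), and that the two model IDs are placeholders; this is illustrative usage, not part of the commit.

```python
# smoke_test_advisor.py: illustrative sketch, not part of this commit.
# Assumes forgekit is importable and GROQ_API_KEY is set in the environment
# (hypothetical setup; adjust paths and model IDs to your own).
import os

from forgekit.ai_advisor import merge_advisor

# _query_groq falls back to os.environ["GROQ_API_KEY"] when no api_key is
# passed, so the key does not need to be threaded through explicitly.
if not os.environ.get("GROQ_API_KEY"):
    raise SystemExit("export GROQ_API_KEY first (free at console.groq.com)")

# merge_advisor expects newline-separated model IDs; these are placeholders.
models = "mistralai/Mistral-7B-Instruct-v0.3\nHuggingFaceH4/zephyr-7b-beta"

print(merge_advisor(models, goal="general chat with stronger reasoning"))
```

Calling `merge_advisor` without an `api_key` exercises the environment-variable fallback added in this commit; a UI would instead pass the key from its input field.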