Update forgekit/ai_advisor.py

forgekit/ai_advisor.py · CHANGED · +80 −135 · one hunk: @@ -1,224 +1,169 @@

This commit swaps the AI Advisor backend from the Hugging Face Inference API to Groq's OpenAI-compatible chat completions API: structured chat messages and a Groq API key replace raw-prompt requests against per-model HF endpoints.
**Removed: the Hugging Face Inference API backend.** The old `_query_llm` helper POSTed to `f"{HF_INFERENCE_URL}/{model}"` with a 60-second timeout and HF-style generation flags (`"do_sample": True`, `"return_full_text": False`). It special-cased HTTP 503 with "⏳ The AI model is loading (this can take 1-2 minutes on first use). Please try again shortly.", scrubbed leftover chat-template tokens (`</s>`, `<|im_end|>`, `<|eot_id|>`, `[/INST]`) from the generated text, and fell back to "⚠️ No response generated. The model may be overloaded — try again." when nothing came back. Each public function took an HF API `token` argument and routed through `_query_llm`, with verbose Args/Returns docstrings that the new version condenses to one-liners; the old advisor prompt closed with "Be specific with numbers and keep it practical." rather than asking for a complete YAML config.

Because the HF endpoint consumed one flat prompt string instead of structured chat messages, the module also carried a hand-rolled template builder, deleted in this commit:

```python
def _format_chat(messages: list[dict], model: str) -> str:
    """Format messages into the model's expected chat template."""
    # Mistral Instruct format
    if "mistral" in model.lower() or "mixtral" in model.lower():
        parts = []
        for msg in messages:
            if msg["role"] == "system":
                parts.append(f"[INST] {msg['content']}\n")
            elif msg["role"] == "user":
                if parts:
                    parts.append(f"{msg['content']} [/INST]")
                else:
                    parts.append(f"[INST] {msg['content']} [/INST]")
        return "".join(parts)

    # Generic ChatML fallback
    parts = []
    for msg in messages:
        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>")
    parts.append("<|im_start|>assistant\n")
    return "\n".join(parts)
```
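For contrast with the new structured-messages payload, the old request body followed the HF Inference API's text-generation shape. Only the `do_sample` and `return_full_text` flags survive in the diff view; the surrounding keys and the `max_new_tokens` name below are the standard HF schema and are an assumption, not a restoration of the deleted code:

```python
# Hypothetical reconstruction of the removed HF-style request body.
# "inputs"/"parameters" follow the standard HF text-generation schema (assumed);
# only the two generation flags below are visible in the diff.
prompt_text = "[INST] You are ForgeKit AI.\nHello! [/INST]"  # shape of _format_chat output
payload = {
    "inputs": prompt_text,        # one flat prompt string, not chat messages
    "parameters": {
        "max_new_tokens": 1024,   # assumed parameter name; the original line was truncated
        "do_sample": True,
        "return_full_text": False,
    },
}
```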
**Added: a Groq-backed client.** Module header and the query helper:

```python
"""AI-powered merge advisor using Groq API (free, fast inference)."""

import os
import requests
from typing import Optional

GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
DEFAULT_MODEL = "llama-3.3-70b-versatile"


def _query_groq(
    prompt: str,
    system: str = "",
    model: str = DEFAULT_MODEL,
    api_key: Optional[str] = None,
    max_tokens: int = 1024,
) -> str:
    """Query Groq's OpenAI-compatible API.

    Args:
        prompt: User message
        system: System prompt
        model: Groq model ID
        api_key: Groq API key (free at console.groq.com)
        max_tokens: Max response length

    Returns:
        Generated text response
    """
    key = (api_key or "").strip() or os.environ.get("GROQ_API_KEY", "")
    if not key:
        return (
            "**Groq API Key required** — the AI Advisor uses Groq for fast, free inference.\n\n"
            "1. Go to [console.groq.com](https://console.groq.com) and sign up (free, no credit card)\n"
            "2. Create an API key\n"
            "3. Paste it in the field above\n\n"
            "Groq gives you thousands of free requests per day with Llama 3.3 70B!"
        )

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": 0.7,
    }

    try:
        resp = requests.post(GROQ_API_URL, headers=headers, json=payload, timeout=30)

        if resp.status_code == 429:
            return "Rate limited — Groq free tier allows ~30 requests/min. Wait a moment and try again."
        if resp.status_code == 401:
            return "Invalid Groq API key. Get a free one at [console.groq.com](https://console.groq.com)."
        if resp.status_code != 200:
            return f"Groq API error (status {resp.status_code}). Try again."

        data = resp.json()
        text = data["choices"][0]["message"]["content"]
        return text.strip()

    except requests.exceptions.Timeout:
        return "Request timed out — try again."
    except Exception as e:
        return f"Error: {str(e)}"
```
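The success path reads the standard OpenAI-compatible response shape. The round trip `_query_groq` wraps can be reproduced standalone; a minimal sketch, assuming `GROQ_API_KEY` is exported and the default model is available on the free tier:

```python
import os
import requests

# Same endpoint and default model the module uses.
resp = requests.post(
    "https://api.groq.com/openai/v1/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['GROQ_API_KEY']}"},
    json={
        "model": "llama-3.3-70b-versatile",
        "messages": [{"role": "user", "content": "Name one LLM merge method."}],
        "max_tokens": 64,
    },
    timeout=30,
)
resp.raise_for_status()
# OpenAI-compatible schema: the text lives at choices[0].message.content
print(resp.json()["choices"][0]["message"]["content"])
```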
The shared system prompt:

```python
# ===== SYSTEM PROMPT =====

ADVISOR_SYSTEM = """You are ForgeKit AI, an expert assistant for merging large language models using mergekit. You have deep knowledge of:

- Model architectures (LLaMA, Qwen, Mistral, Gemma, Phi)
- Merge methods: DARE-TIES, TIES, SLERP, Linear, Task Arithmetic, Passthrough (Frankenmerge)
- Optimal weight/density configurations for different use cases
- Common pitfalls and best practices

Be concise, practical, and specific. Always give concrete numbers for weights and densities.
Format responses with markdown headers and bullet points for readability."""
```
The three public features. First, the merge advisor:

```python
# ===== AI FEATURES =====

def merge_advisor(
    models_text: str,
    goal: str = "",
    api_key: Optional[str] = None,
) -> str:
    """AI recommends the best merge method, weights, and configuration."""
    models = [m.strip() for m in models_text.strip().split("\n") if m.strip()]
    if len(models) < 2:
        return "Add at least 2 models (one per line) to get a recommendation."

    models_str = "\n".join(f"- {m}" for m in models)
    goal_str = f"\n\nThe user's goal: {goal}" if goal.strip() else ""

    prompt = f"""I want to merge these models:
{models_str}
{goal_str}

Give me a specific recommendation:
1. **Best merge method** and why
2. **Exact weights** for each model
3. **Density values** (if applicable)
4. **Which model as base** and why
5. **Which tokenizer** to keep
6. **Warnings or tips** for these specific models
7. **The complete YAML config** ready for mergekit"""

    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
```
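Item 7 asks the model to return a complete mergekit config. For context, the kind of YAML being requested looks roughly like this — an illustrative sketch with placeholder model IDs and values, not advisor output:

```yaml
# Illustrative dare_ties merge config (placeholder model IDs and values).
models:
  - model: org/model-a
    parameters:
      weight: 0.6
      density: 0.5
  - model: org/model-b
    parameters:
      weight: 0.4
      density: 0.5
merge_method: dare_ties
base_model: org/base-model
dtype: bfloat16
```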
The capability predictor:

```python
def model_describer(
    models_text: str,
    method: str = "",
    weights_text: str = "",
    api_key: Optional[str] = None,
) -> str:
    """AI predicts what the merged model will be good at."""
    models = [m.strip() for m in models_text.strip().split("\n") if m.strip()]
    if not models:
        return "Add models first."

    models_str = "\n".join(f"- {m}" for m in models)
    method_str = f" using **{method}**" if method else ""
    weights_str = f"\nWeights: {weights_text}" if weights_text.strip() else ""

    prompt = f"""I'm merging these models{method_str}:
{models_str}{weights_str}

Predict:
1. **What it will excel at** — specific tasks and benchmarks
2. **What it might lose** compared to individual source models
3. **Ideal use cases** for this merge
4. **Quality estimate** vs each source model
5. **A creative name suggestion** for this merged model"""

    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
```
And the config explainer:

````python
def config_explainer(
    yaml_config: str,
    api_key: Optional[str] = None,
) -> str:
    """AI explains a YAML merge config in plain English."""
    if not yaml_config.strip() or yaml_config.startswith("# Add"):
        return "Generate or paste a YAML config first."

    prompt = f"""Explain this mergekit config in plain English for a beginner:

```yaml
{yaml_config}
```

Cover:
1. **What this does** in simple terms
2. **Why these settings** — explain each parameter
3. **What the output will be like**
4. **Potential issues** to watch for
5. **Resource requirements** (RAM, time, Colab tier)"""

    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
````
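How the three entry points might be exercised, e.g. from a UI callback — a hypothetical sketch assuming `GROQ_API_KEY` is set (so `api_key` can be omitted) and using placeholder model IDs:

```python
from forgekit.ai_advisor import config_explainer, merge_advisor, model_describer

models = "org/model-a\norg/model-b"  # one HF model ID per line (placeholders)

print(merge_advisor(models, goal="a concise coding assistant"))
print(model_describer(models, method="dare_ties", weights_text="0.6, 0.4"))
print(config_explainer("models:\n  - model: org/model-a\nmerge_method: linear"))
```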