"""
Kirim-1-Math inference script.

Mathematical reasoning with tool-calling capabilities.
"""

import json
import re
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings('ignore')


class MathToolExecutor:
    """Execute mathematical tools called by the model.

    Every tool takes a dict of arguments and returns a human-readable string:
    either a result (e.g. ``"Result: 5.00000000000000"``) or an error message
    that can be fed back to the model. ``execute_tool`` never raises.

    SECURITY NOTE: the tools parse model-generated text with
    ``sympy.sympify``, which evaluates its input with ``eval``. Do not expose
    this executor to untrusted input without sandboxing it.
    """

    def __init__(self):
        try:
            import sympy as sp
            import numpy as np
            self.sp = sp
            self.np = np
        except ImportError:
            # Degrade gracefully: every tool then reports "SymPy not available".
            print("Warning: SymPy or NumPy not installed. Tool execution limited.")
            self.sp = None
            self.np = None

    def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Run ``tool_name`` with ``arguments`` and return its result text.

        Args:
            tool_name: One of "calculator", "symbolic_solver", "derivative",
                "integrate", "simplify", "latex_formatter".
            arguments: Tool arguments. A JSON-encoded string (as some models
                emit) is decoded first; None or "" means no arguments.

        Returns:
            The tool's result string, ``"Unknown tool: <name>"`` for an
            unrecognised name, or ``"Tool execution error: <msg>"`` if
            argument decoding or the tool itself raises.
        """
        handlers = {
            "calculator": self._calculator,
            "symbolic_solver": self._symbolic_solver,
            "derivative": self._derivative,
            "integrate": self._integrate,
            "simplify": self._simplify,
            "latex_formatter": self._latex_formatter,
        }
        # The name comes from model-generated JSON and may not even be hashable.
        handler = handlers.get(tool_name) if isinstance(tool_name, str) else None
        if handler is None:
            return f"Unknown tool: {tool_name}"

        try:
            return handler(self._coerce_arguments(arguments))
        except Exception as e:
            return f"Tool execution error: {str(e)}"

    @staticmethod
    def _coerce_arguments(arguments: Any) -> Dict[str, Any]:
        """Normalise tool arguments to a dict.

        Raises:
            ValueError: ``arguments`` is a string that is not valid JSON.
            TypeError: the decoded arguments are not a JSON object.
        """
        if arguments is None:
            return {}
        if isinstance(arguments, str):
            arguments = json.loads(arguments) if arguments.strip() else {}
        if not isinstance(arguments, dict):
            raise TypeError(
                f"arguments must be a JSON object, got {type(arguments).__name__}"
            )
        return arguments

    @staticmethod
    def _split_equation(equation: str) -> Optional[Tuple[str, str]]:
        """Split ``"lhs = rhs"`` (or ``"lhs == rhs"``) into its stripped sides.

        Returns None when the text is not a single equality, e.g. a bare
        expression (``"x**2 - 4"``), a chain (``"a = b = c"``) or an
        inequality (``"x <= 3"``, ``"x != 3"``).
        """
        parts = equation.replace("==", "=").split("=")
        if len(parts) != 2:
            return None
        lhs, rhs = parts[0].strip(), parts[1].strip()
        # "<=", ">=" and "!=" leave their first character at the end of lhs.
        if not lhs or not rhs or lhs.endswith(("<", ">", "!")):
            return None
        return lhs, rhs

    def _calculator(self, args: Dict) -> str:
        """Numerically evaluate ``args["expression"]`` to ``args["precision"]`` digits (default 15)."""
        expr = args.get("expression", "")
        precision = args.get("precision", 15)

        if not self.sp:
            return "SymPy not available"

        try:
            # int(): models often send numbers as strings ("20").
            result = self.sp.N(self.sp.sympify(expr), int(precision))
            return f"Result: {result}"
        except Exception as e:
            return f"Calculation error: {e}"

    def _symbolic_solver(self, args: Dict) -> str:
        """Solve ``args["equation"]`` for ``args["variable"]`` (default "x").

        The equation may be written ``"lhs = rhs"``; a bare expression is
        solved as ``expression = 0``.
        """
        equation = args.get("equation", "")
        variable = args.get("variable", "x")

        if not self.sp:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            # sympify() rejects "=", so turn "lhs = rhs" into Eq(lhs, rhs).
            sides = self._split_equation(equation) if isinstance(equation, str) else None
            if sides is None:
                eq = self.sp.sympify(equation)
            else:
                lhs, rhs = sides
                eq = self.sp.Eq(self.sp.sympify(lhs), self.sp.sympify(rhs))
            solutions = self.sp.solve(eq, var)
            return f"Solutions: {solutions}"
        except Exception as e:
            return f"Solver error: {e}"

    def _derivative(self, args: Dict) -> str:
        """Differentiate ``args["function"]`` ``args["order"]`` times (default 1) w.r.t. ``args["variable"]``."""
        function = args.get("function", "")
        variable = args.get("variable", "x")
        order = args.get("order", 1)

        if not self.sp:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)
            result = self.sp.diff(func, var, int(order))
            return f"Derivative: {result}"
        except Exception as e:
            return f"Derivative error: {e}"

    def _integrate(self, args: Dict) -> str:
        """Integrate ``args["function"]`` w.r.t. ``args["variable"]``.

        The integral is definite only when both ``lower_bound`` and
        ``upper_bound`` are given; otherwise the antiderivative is returned.
        """
        function = args.get("function", "")
        variable = args.get("variable", "x")
        lower = args.get("lower_bound")
        upper = args.get("upper_bound")

        if not self.sp:
            return "SymPy not available"

        try:
            var = self.sp.Symbol(variable)
            func = self.sp.sympify(function)

            if lower is not None and upper is not None:
                result = self.sp.integrate(func, (var, lower, upper))
            else:
                result = self.sp.integrate(func, var)

            return f"Integral: {result}"
        except Exception as e:
            return f"Integration error: {e}"

    def _simplify(self, args: Dict) -> str:
        """Simplify ``args["expression"]`` with ``sympy.simplify``."""
        expression = args.get("expression", "")

        if not self.sp:
            return "SymPy not available"

        try:
            expr = self.sp.sympify(expression)
            result = self.sp.simplify(expr)
            return f"Simplified: {result}"
        except Exception as e:
            return f"Simplification error: {e}"

    def _latex_formatter(self, args: Dict) -> str:
        """Render ``args["expression"]`` as LaTeX, inline (``$...$``) or display (``$$...$$``)."""
        expression = args.get("expression", "")
        inline = args.get("inline", False)

        if not self.sp:
            return "SymPy not available"

        try:
            expr = self.sp.sympify(expression)
            latex = self.sp.latex(expr)

            if inline:
                return f"${latex}$"
            return f"$$\n{latex}\n$$"
        except Exception as e:
            return f"LaTeX formatting error: {e}"


class KirimMath:
    """Kirim-1-Math inference with tool calling.

    Wraps a Hugging Face causal LM and a ``MathToolExecutor``. When a model
    reply contains ``<tool_call>{"name": ..., "arguments": {...}}</tool_call>``
    spans, the calls are executed, their results are appended to the
    conversation as ``tool`` messages, and the model is asked to continue.
    """

    # Prompt-length cap passed to the tokenizer (the script's original limit).
    # NOTE(review): presumably chosen to leave room for max_new_tokens inside
    # the model's context window -- confirm against the model config.
    MAX_PROMPT_TOKENS = 28672

    # Literal end-of-turn marker stripped from decoded replies, in addition to
    # the tokenizer's own EOS/pad token strings.
    _END_MARKERS = ("<|end_of_text|>",)

    # Captures the JSON payload of each <tool_call>...</tool_call> span.
    _TOOL_CALL_RE = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)

    def __init__(
        self,
        model_path: str = "Kirim-ai/Kirim-1-Math",
        device: str = "auto",
        load_in_8bit: bool = False,
        load_in_4bit: bool = False
    ):
        """Load the tokenizer, the model and the tool executor.

        Args:
            model_path: Hugging Face Hub id or local directory of the model.
            device: "auto" passes ``device_map="auto"`` to ``from_pretrained``;
                any other value is used with ``model.to(device)`` when the
                model is not quantized.
            load_in_8bit: Quantize the weights to 8-bit via bitsandbytes.
            load_in_4bit: Quantize to 4-bit; ignored if ``load_in_8bit`` is set.
        """
        print(f"Loading Kirim-1-Math from {model_path}...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=True
        )

        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "low_cpu_mem_usage": True,
        }

        if load_in_8bit or load_in_4bit:
            # Passing load_in_8bit / load_in_4bit straight to from_pretrained
            # is deprecated; transformers expects a BitsAndBytesConfig in
            # quantization_config instead.
            from transformers import BitsAndBytesConfig

            if load_in_8bit:
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
                print("Loading in 8-bit mode (30GB VRAM)")
            else:
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
                print("Loading in 4-bit mode (20GB VRAM)")
        else:
            print("Loading in full precision (80GB VRAM)")

        if device == "auto":
            model_kwargs["device_map"] = "auto"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **model_kwargs
        )

        # Quantized models are not moved manually.
        if device != "auto" and not (load_in_8bit or load_in_4bit):
            self.model = self.model.to(device)

        self.model.eval()

        self.tool_executor = MathToolExecutor()

        print("✓ Model loaded successfully!")
        print("✓ Tool calling enabled\n")

    def solve_problem(
        self,
        problem: str,
        show_work: bool = True,
        use_tools: bool = True,
        max_new_tokens: int = 4096,
        temperature: float = 0.1
    ) -> str:
        """
        Solve a mathematical problem.

        Args:
            problem: Math problem to solve.
            show_work: Ask the model to show a step-by-step solution.
            use_tools: Enable tool calling (the tool list is added to the
                system prompt and ``<tool_call>`` spans are executed).
            max_new_tokens: Maximum tokens generated per model call.
            temperature: Sampling temperature; 0 means greedy decoding.

        Returns:
            The model's final reply text.
        """
        system_prompt = "You are Kirim-1-Math, an advanced mathematical reasoning AI. "

        if show_work:
            system_prompt += "Show your work step-by-step. "

        if use_tools:
            system_prompt += "You can use tools for calculations. Available tools: calculator, symbolic_solver, derivative, integrate, simplify."

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": problem}
        ]

        response = self._generate(messages, max_new_tokens, temperature)

        if use_tools and "<tool_call>" in response:
            response = self._handle_tool_calls(response, messages, max_new_tokens, temperature)

        return response

    def _generate(self, messages: List[Dict], max_new_tokens: int, temperature: float) -> str:
        """Run one chat-formatted generation and return only the new reply text."""
        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # The chat template already emits any BOS/special tokens, so don't let
        # the tokenizer add them a second time.
        # NOTE(review): truncation removes tokens from the tokenizer's
        # truncation_side (right by default), which would cut off the
        # generation prompt on over-long conversations.
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_PROMPT_TOKENS,
            add_special_tokens=False
        )

        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = {k: v.to(device) for k, v in inputs.items()}

        do_sample = temperature > 0
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Many causal-LM tokenizers have no pad token; reuse EOS.
            pad_token_id = self.tokenizer.eos_token_id

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling knobs are ignored (and warned about) under greedy decoding.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.95

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + continuation; decode only the continuation.
        prompt_length = inputs["input_ids"].shape[-1]
        # Special tokens are kept so that tool-call tags survive decoding.
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)

        end_markers = (
            *self._END_MARKERS,
            getattr(self.tokenizer, "eos_token", None),
            getattr(self.tokenizer, "pad_token", None),
        )
        for marker in end_markers:
            if marker:
                response = response.replace(marker, "")
        return response.strip()

    @staticmethod
    def _parse_tool_call(raw: str) -> Optional[Tuple[Any, Any]]:
        """Decode one ``<tool_call>`` payload into ``(name, arguments)``.

        Returns None if the payload is not valid JSON or not a JSON object.
        """
        try:
            call = json.loads(raw.strip())
        except json.JSONDecodeError:
            return None
        if not isinstance(call, dict):
            return None
        return call.get("name", ""), call.get("arguments", {})

    def _handle_tool_calls(
        self,
        response: str,
        messages: List[Dict],
        max_new_tokens: int,
        temperature: float,
        max_rounds: int = 5
    ) -> str:
        """Execute tool calls and re-prompt until the model stops calling tools.

        Each round runs every ``<tool_call>`` in the current reply, appends
        that reply as one ``assistant`` message followed by one ``tool``
        message per result, then generates again.

        Args:
            response: Model reply that may contain tool calls.
            messages: Conversation so far; extended in place.
            max_new_tokens: Passed through to ``_generate``.
            temperature: Passed through to ``_generate``.
            max_rounds: Maximum execute-and-regenerate cycles. If the limit is
                reached, the last reply is returned even if it still contains
                tool calls.

        Returns:
            The latest model reply.
        """
        for _ in range(max_rounds):
            raw_calls = self._TOOL_CALL_RE.findall(response)
            if not raw_calls:
                break

            results = []
            for raw in raw_calls:
                parsed = self._parse_tool_call(raw)
                if parsed is None:
                    print(f"⚠️ Failed to parse tool call: {raw}")
                    continue

                tool_name, arguments = parsed
                print(f"\n🔧 Executing tool: {tool_name}")
                print(f" Arguments: {arguments}")

                result = self.tool_executor.execute_tool(tool_name, arguments)

                print(f" Result: {result}\n")
                results.append(result)

            if not results:
                # Nothing was executable: re-prompting with an unchanged
                # conversation would not help, so keep the current reply.
                break

            messages.append({"role": "assistant", "content": response})
            messages.extend(
                {"role": "tool", "content": f"<tool_result>{result}</tool_result>"}
                for result in results
            )

            response = self._generate(messages, max_new_tokens, temperature)

        return response

    @staticmethod
    def _parse_toggle(user_input: str) -> Optional[Tuple[str, bool]]:
        """Recognise ``tools on|off`` / ``work on|off`` (case-insensitive).

        Returns ``(setting, enabled)`` for a toggle command, else None, so
        problems such as "Work done by a force ..." are not swallowed.
        """
        words = user_input.lower().split()
        if len(words) == 2 and words[0] in ("tools", "work") and words[1] in ("on", "off"):
            return words[0], words[1] == "on"
        return None

    def interactive_math(self):
        """Run a read-solve-print loop on stdin until quit, Ctrl-C or EOF."""
        print("\n" + "="*60)
        print(" Kirim-1-Math - Interactive Mode")
        print(" First model with tool calling!")
        print("="*60)
        print("\nCommands:")
        print(" 'quit' or 'exit' - End session")
        print(" 'tools off/on' - Toggle tool calling")
        print(" 'work off/on' - Toggle showing work")
        print("\n" + "="*60 + "\n")

        use_tools = True
        show_work = True

        while True:
            try:
                user_input = input("Problem: ").strip()

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("\nGoodbye! Happy solving! 🧮\n")
                    break

                toggle = self._parse_toggle(user_input)
                if toggle is not None:
                    setting, enabled = toggle
                    if setting == "tools":
                        use_tools = enabled
                        print(f"✓ Tool calling: {'enabled' if use_tools else 'disabled'}\n")
                    else:
                        show_work = enabled
                        print(f"✓ Show work: {'enabled' if show_work else 'disabled'}\n")
                    continue

                if not user_input:
                    continue

                print("\n" + "-"*60)
                solution = self.solve_problem(
                    user_input,
                    show_work=show_work,
                    use_tools=use_tools
                )
                print(solution)
                print("-"*60 + "\n")

            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin closed (Ctrl-D / piped input). Without this,
                # input() would fail on every iteration and loop forever.
                print("\n\nGoodbye! 🧮\n")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}\n")


def build_arg_parser():
    """Create the command-line parser for the Kirim-1-Math script.

    ``--load_in_8bit`` and ``--load_in_4bit`` are mutually exclusive: passing
    both is rejected instead of silently using 8-bit.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Kirim-1-Math Inference")
    parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math",
                        help="Hugging Face Hub id or local path of the model")
    parser.add_argument("--device", type=str, default="auto",
                        help='"auto" for automatic placement, otherwise a device passed to model.to()')
    quantization = parser.add_mutually_exclusive_group()
    quantization.add_argument("--load_in_8bit", action="store_true",
                              help="Load the weights quantized to 8-bit")
    quantization.add_argument("--load_in_4bit", action="store_true",
                              help="Load the weights quantized to 4-bit")
    parser.add_argument("--interactive", action="store_true",
                        help="Start an interactive problem-solving session")
    parser.add_argument("--problem", type=str, help="Single problem to solve")
    return parser


def main(argv: Optional[List[str]] = None):
    """Command-line entry point.

    Loads the model, then runs interactive mode (``--interactive``), solves a
    single ``--problem``, or otherwise runs a fixed list of demo problems.

    Args:
        argv: Arguments to parse; None (the default) parses ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    kirim_math = KirimMath(
        model_path=args.model_path,
        device=args.device,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit
    )

    if args.interactive:
        kirim_math.interactive_math()
    elif args.problem:
        solution = kirim_math.solve_problem(args.problem)
        print(f"\nProblem: {args.problem}")
        print(f"\nSolution:\n{solution}\n")
    else:
        print("="*60)
        print(" Demo Examples")
        print("="*60 + "\n")

        demos = [
            "Solve: x² - 5x + 6 = 0",
            "Calculate the derivative of x³ + 2x² - x + 1",
            "解方程: 2x + 3y = 12, 4x - y = 5",  # Chinese: "Solve the equations: ..."
            "Integrate: ∫(x² + 1)dx"
        ]

        for problem in demos:
            print(f"\nProblem: {problem}")
            print("-" * 60)
            solution = kirim_math.solve_problem(problem)
            print(solution)
            print("=" * 60)


if __name__ == "__main__":
    main()