# Page banner captured with this file: "Spaces: Runtime error" (not part of the module).
"""Tree of Thoughts reasoning implementation with advanced tree exploration."""
import heapq
import json
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, List, Optional, Set, Tuple

from .base import ReasoningStrategy
class NodeType(Enum):
    """Types of nodes in the thought tree.

    The string values are what the model is asked to emit on a ``Type:`` line
    and what ``NodeType(value)`` is called with when parsing child nodes, so
    they must stay lowercase. Member order is the order in which the types are
    listed in the child-generation prompt.
    """

    # The single starting node of a tree; never offered as a child type.
    ROOT = "root"
    HYPOTHESIS = "hypothesis"
    EVIDENCE = "evidence"
    ANALYSIS = "analysis"
    SYNTHESIS = "synthesis"
    EVALUATION = "evaluation"
    CONCLUSION = "conclusion"
@dataclass(eq=False)
class TreeNode:
    """Represents a node in the thought tree.

    Nodes are built with keyword arguments (``TreeNode(id=..., type=..., ...)``),
    which requires the generated dataclass ``__init__``.

    ``eq=False`` keeps identity-based equality and hashing: a node is an entity
    in a graph that links both down (``children``) and up (``parent``), so a
    field-by-field comparison would walk the whole tree.
    """

    # Identifier; the root uses "root", children use "<parent id>_<index>".
    id: str
    type: NodeType
    # Text of the thought (the query itself for the root node).
    content: str
    confidence: float
    # Child nodes that survived pruning; filled in while the tree is built.
    children: List['TreeNode']
    # Back-reference to the parent (None for the root). Left out of the repr:
    # the parent lists this node among its children, so including it would
    # recurse without end.
    parent: Optional['TreeNode'] = field(repr=False)
    metadata: Dict[str, Any]
    depth: int
    # Stays 0.0 until an evaluation score is applied to the node.
    evaluation_score: float = 0.0
class TreeOfThoughtsStrategy(ReasoningStrategy):
    """
    Advanced Tree of Thoughts reasoning implementation with:
    - Beam search for path exploration
    - Dynamic node evaluation
    - Pruning strategies
    - Path optimization
    - Meta-learning from tree patterns

    All model calls go through ``context["groq_api"].predict(prompt)``, which is
    awaited and expected to return a mapping with an ``"answer"`` string.
    """

    # Search limits, kept as class attributes so a subclass can tune them.
    MAX_DEPTH = 5          # number of expansion rounds in _build_tree
    BEAM_WIDTH = 3         # nodes carried from one round to the next
    MIN_CHILD_SCORE = 0.4  # children scoring at or below this are pruned
    MAX_BRANCHING = 3      # children followed per node when enumerating paths
    TOP_PATHS = 3          # paths handed to the conclusion step

    def __init__(self,
                 min_confidence: float = 0.7,
                 parallel_threshold: int = 3,
                 learning_rate: float = 0.1,
                 strategy_weights: Optional[Dict[str, float]] = None):
        """Store the configuration and start with empty history and patterns."""
        self.min_confidence = min_confidence
        self.parallel_threshold = parallel_threshold
        self.learning_rate = learning_rate
        self.strategy_weights = strategy_weights or {
            "LOCAL_LLM": 0.8,
            "CHAIN_OF_THOUGHT": 0.6,
            "TREE_OF_THOUGHTS": 0.5,
            "META_LEARNING": 0.4
        }
        # Every node seen so far, keyed by node id.
        self.node_history: Dict[str, TreeNode] = {}
        # Accumulated leaf scores per "type->type->..." path signature.
        self.path_patterns: Dict[str, float] = defaultdict(float)

    @staticmethod
    def _context_json(context: Dict[str, Any]) -> str:
        """Serialize ``context`` for embedding in a prompt.

        The ``groq_api`` entry is the client object used to call the model. It
        cannot be JSON-encoded (which made every prompt build raise
        ``TypeError``) and has no place in the prompt text, so it is left out.
        ``default=str`` is deliberate: the result is only read by the model, so
        any other non-JSON value is rendered as its string form instead of
        aborting the whole reasoning run.
        """
        prompt_context = {k: v for k, v in context.items() if k != "groq_api"}
        return json.dumps(prompt_context, default=str)

    @staticmethod
    def _split_csv(text: str) -> List[str]:
        """Split ``text`` on commas and strip each item."""
        return [item.strip() for item in text.split(',')]

    async def reason(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Main reasoning method implementing tree of thoughts.

        Args:
            query: The question to reason about.
            context: Must contain ``"groq_api"``; the remaining entries are
                passed to the model as JSON.

        Returns:
            On success, a dict with ``success=True`` plus ``answer``,
            ``confidence``, ``tree``, ``best_paths``, ``reasoning_trace`` and
            ``meta_insights``. On any exception, ``{"success": False,
            "error": <message>}``; the exception is logged, not raised.
        """
        try:
            # Initialize root node
            root = await self._create_root_node(query, context)
            # Build and explore tree
            tree = await self._build_tree(root, context)
            # Find best paths
            paths = await self._find_best_paths(tree, context)
            # Synthesize conclusion
            conclusion = await self._synthesize_conclusion(paths, context)
            # Update history and patterns
            self._update_history(tree)
            self._update_patterns(paths)
            return {
                "success": True,
                "answer": conclusion["answer"],
                "confidence": conclusion["confidence"],
                "tree": self._tree_to_dict(tree),
                "best_paths": [self._path_to_dict(p) for p in paths],
                "reasoning_trace": conclusion["trace"],
                "meta_insights": conclusion["meta_insights"]
            }
        except Exception as e:
            # Top-level boundary: report the failure to the caller, keep the
            # traceback in the log.
            logging.exception("Error in tree of thoughts reasoning: %s", e)
            return {"success": False, "error": str(e)}

    async def _create_root_node(self, query: str, context: Dict[str, Any]) -> TreeNode:
        """Create the root node of the thought tree."""
        prompt = f"""
        Initialize root thought node for query:
        Query: {query}
        Context: {self._context_json(context)}
        Provide:
        1. Initial problem decomposition
        2. Key aspects to explore
        3. Evaluation criteria
        4. Success metrics
        Format as:
        [Root]
        Decomposition: ...
        Aspects: ...
        Criteria: ...
        Metrics: ...
        """
        response = await context["groq_api"].predict(prompt)
        return self._parse_root_node(response["answer"], query)

    async def _build_tree(self, root: TreeNode, context: Dict[str, Any]) -> TreeNode:
        """Build and explore the thought tree with a beam search.

        For up to ``MAX_DEPTH`` rounds, each node in the beam is expanded, its
        children are scored, children above ``MIN_CHILD_SCORE`` are attached to
        the node, and the ``BEAM_WIDTH`` best of them form the next beam.
        Returns ``root`` with its ``children`` populated.
        """
        # Initialize beam with root
        beam = [(root.evaluation_score, root)]
        visited: Set[str] = set()
        for _ in range(self.MAX_DEPTH):
            next_beam = []
            for _, node in beam:
                if node.id in visited:
                    continue
                visited.add(node.id)
                # Generate child nodes
                children = await self._generate_children(node, context)
                if not children:
                    # Nothing to score; skip the evaluation request.
                    continue
                # Evaluate and filter children
                evaluated_children = await self._evaluate_nodes(children, context)
                # Add to beam
                for child in evaluated_children:
                    if child.evaluation_score > self.MIN_CHILD_SCORE:
                        next_beam.append((child.evaluation_score, child))
                        node.children.append(child)
            # Select best nodes for next iteration; the key keeps nodes
            # themselves from ever being compared.
            beam = heapq.nlargest(self.BEAM_WIDTH, next_beam, key=lambda item: item[0])
            if not beam:
                break
        return root

    async def _generate_children(self, parent: TreeNode, context: Dict[str, Any]) -> List[TreeNode]:
        """Generate child nodes for a given parent."""
        type_options = " | ".join(t.value for t in NodeType if t != NodeType.ROOT)
        prompt = f"""
        Generate child thoughts for node:
        Parent: {json.dumps(self._node_to_dict(parent))}
        Context: {self._context_json(context)}
        For each child provide:
        1. [Type]: {type_options}
        2. [Content]: Main thought
        3. [Confidence]: 0-1 score
        4. [Rationale]: Why this follows from parent
        5. [Potential]: Future exploration potential
        Format as:
        [C1]
        Type: ...
        Content: ...
        Confidence: ...
        Rationale: ...
        Potential: ...
        """
        response = await context["groq_api"].predict(prompt)
        return self._parse_child_nodes(response["answer"], parent)

    async def _evaluate_nodes(self, nodes: List[TreeNode], context: Dict[str, Any]) -> List[TreeNode]:
        """Evaluate a list of nodes; returns the same list with scores applied."""
        if not nodes:
            return nodes
        prompt = f"""
        Evaluate thought nodes:
        Nodes: {json.dumps([self._node_to_dict(n) for n in nodes])}
        Context: {self._context_json(context)}
        For each node evaluate:
        1. Logical coherence
        2. Evidence support
        3. Novelty value
        4. Exploration potential
        Format as:
        [N1]
        Coherence: 0-1
        Evidence: 0-1
        Novelty: 0-1
        Potential: 0-1
        Overall: 0-1
        """
        response = await context["groq_api"].predict(prompt)
        return self._apply_evaluations(nodes, response["answer"])

    async def _find_best_paths(self, root: TreeNode, context: Dict[str, Any]) -> List[List[TreeNode]]:
        """Find the best root-to-leaf paths through the tree.

        At each node only the ``MAX_BRANCHING`` highest-scoring children are
        followed. The collected paths are scored by ``_evaluate_paths`` and the
        ``TOP_PATHS`` with the largest summed node score are returned.
        """
        paths: List[List[TreeNode]] = []

        def dfs(node: TreeNode, path: List[TreeNode]) -> None:
            if not node.children:
                # Leaf: record a copy, since ``path`` keeps being mutated.
                paths.append(path[:])
                return
            # Sort children by score
            ranked = sorted(node.children, key=lambda c: c.evaluation_score, reverse=True)
            # Explore top paths
            for child in ranked[:self.MAX_BRANCHING]:
                path.append(child)
                dfs(child, path)
                path.pop()

        dfs(root, [root])
        # Evaluate complete paths
        evaluated_paths = await self._evaluate_paths(paths, context)
        # Return top paths
        return sorted(
            evaluated_paths,
            key=lambda p: sum(n.evaluation_score for n in p),
            reverse=True,
        )[:self.TOP_PATHS]

    async def _synthesize_conclusion(self, paths: List[List[TreeNode]], context: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize final conclusion from best paths."""
        prompt = f"""
        Synthesize conclusion from thought paths:
        Paths: {json.dumps([[self._node_to_dict(n) for n in path] for path in paths])}
        Context: {self._context_json(context)}
        Provide:
        1. Main conclusion
        2. Confidence level
        3. Reasoning trace
        4. Supporting evidence
        5. Alternative perspectives
        6. Meta-insights
        Format as:
        [Conclusion]
        Answer: ...
        Confidence: ...
        Trace: ...
        Evidence: ...
        Alternatives: ...
        [Meta]
        Insights: ...
        Patterns: ...
        """
        response = await context["groq_api"].predict(prompt)
        return self._parse_conclusion(response["answer"])

    def _parse_root_node(self, response: str, query: str) -> TreeNode:
        """Parse root node from response.

        The root's content is the query itself; the model's ``Decomposition:``,
        ``Aspects:``, ``Criteria:`` and ``Metrics:`` lines go into its metadata
        (the last three as comma-separated lists).
        """
        root = TreeNode(
            id="root",
            type=NodeType.ROOT,
            content=query,
            confidence=1.0,
            children=[],
            parent=None,
            metadata={},
            depth=0
        )
        list_fields = {
            'Aspects:': "aspects",
            'Criteria:': "criteria",
            'Metrics:': "metrics",
        }
        for raw in response.split('\n'):
            line = raw.strip()
            if line.startswith('Decomposition:'):
                root.metadata["decomposition"] = line[len('Decomposition:'):].strip()
                continue
            for label, key in list_fields.items():
                if line.startswith(label):
                    root.metadata[key] = self._split_csv(line[len(label):])
                    break
        return root

    def _parse_child_nodes(self, response: str, parent: TreeNode) -> List[TreeNode]:
        """Parse child nodes from response.

        A child starts at its ``Type:`` line; ``[C…]`` headers and further
        ``Type:`` lines close the child being read. Children whose type is not
        a ``NodeType`` value are skipped with a warning, and their remaining
        field lines are ignored.
        """
        children: List[TreeNode] = []
        current: Optional[TreeNode] = None
        for raw in response.split('\n'):
            line = raw.strip()
            if not line:
                continue
            if line.startswith('[C'):
                if current is not None:
                    children.append(current)
                current = None
            elif line.startswith('Type:'):
                # Close the previous child even if the model left out the
                # "[C…]" header, so it is not overwritten and lost.
                if current is not None:
                    children.append(current)
                    current = None
                type_str = line[len('Type:'):].strip()
                try:
                    node_type = NodeType(type_str.lower())
                except ValueError:
                    logging.warning(f"Invalid node type: {type_str}")
                    continue
                current = TreeNode(
                    id=f"{parent.id}_{len(children)}",
                    type=node_type,
                    content="",
                    confidence=0.0,
                    children=[],
                    parent=parent,
                    metadata={},
                    depth=parent.depth + 1
                )
            elif current is not None:
                if line.startswith('Content:'):
                    current.content = line[len('Content:'):].strip()
                elif line.startswith('Confidence:'):
                    try:
                        current.confidence = float(line[len('Confidence:'):].strip())
                    except ValueError:
                        # Non-numeric confidence: fall back to a neutral value.
                        current.confidence = 0.5
                elif line.startswith('Rationale:'):
                    current.metadata["rationale"] = line[len('Rationale:'):].strip()
                elif line.startswith('Potential:'):
                    current.metadata["potential"] = line[len('Potential:'):].strip()
        if current is not None:
            children.append(current)
        return children

    def _apply_evaluations(self, nodes: List[TreeNode], response: str) -> List[TreeNode]:
        """Apply evaluation scores to nodes.

        The k-th ``[N…]`` block in ``response`` belongs to ``nodes[k]``. Every
        ``Key: number`` line of a block is copied into the node's metadata and
        the block's ``Overall`` value (0.0 when absent) becomes its
        ``evaluation_score``. A block without any numeric line leaves its node
        untouched but still consumes its position, so later blocks stay
        aligned. Numeric lines before the first header are ignored, unless the
        response has no header at all, in which case they apply to the first
        node. Returns ``nodes`` (mutated in place).
        """
        def store(index: int, scores: Dict[str, float]) -> None:
            if scores and 0 <= index < len(nodes):
                nodes[index].evaluation_score = scores.get("Overall", 0.0)
                nodes[index].metadata.update(scores)

        block_index = -1  # position of the "[N…]" block currently being read
        pending: Dict[str, float] = {}
        for raw in response.split('\n'):
            line = raw.strip()
            if not line:
                continue
            if line.startswith('[N'):
                store(block_index, pending)
                block_index += 1
                pending = {}
            elif ':' in line:
                # partition() tolerates extra colons in the value; a plain
                # split(':') unpack would raise on them.
                key, _, value = line.partition(':')
                try:
                    pending[key.strip()] = float(value.strip())
                except ValueError:
                    continue  # not a numeric score line
        store(max(block_index, 0), pending)
        return nodes

    async def _evaluate_paths(self, paths: List[List[TreeNode]], context: Dict[str, Any]) -> List[List[TreeNode]]:
        """Evaluate complete reasoning paths.

        Asks the model for one ``Overall`` score per path and multiplies the
        ``evaluation_score`` of every node on path ``i`` by score ``i``. Paths
        without a returned score are left unchanged. Returns ``paths``.
        """
        prompt = f"""
        Evaluate complete reasoning paths:
        Paths: {json.dumps([[self._node_to_dict(n) for n in path] for path in paths])}
        Context: {self._context_json(context)}
        For each path evaluate:
        1. Coherence of progression
        2. Evidence support
        3. Conclusion strength
        4. Novel insights
        Format as:
        [P1]
        Coherence: 0-1
        Evidence: 0-1
        Conclusion: 0-1
        Insights: 0-1
        Overall: 0-1
        """
        response = await context["groq_api"].predict(prompt)
        scores = self._parse_path_scores(response["answer"])
        # Apply scores to paths
        # NOTE(review): nodes shared by several paths (common prefixes) are
        # rescaled once per path, so their scores compound. Confirm whether
        # that is intended before relying on the resulting ranking.
        for path, score in zip(paths, scores):
            for node in path:
                node.evaluation_score *= score
        return paths

    def _parse_path_scores(self, response: str) -> List[float]:
        """Parse path evaluation scores.

        Returns the ``Overall:`` value of each ``[P…]`` block in order; a
        non-numeric value counts as 0.5. Blocks without an ``Overall:`` line
        contribute nothing, so the list can be shorter than the number of
        blocks.
        """
        scores: List[float] = []
        current_score: Optional[float] = None
        for raw in response.split('\n'):
            line = raw.strip()
            if not line:
                continue
            if line.startswith('[P'):
                if current_score is not None:
                    scores.append(current_score)
                current_score = None
            elif line.startswith('Overall:'):
                try:
                    current_score = float(line[len('Overall:'):].strip())
                except ValueError:
                    current_score = 0.5
        if current_score is not None:
            scores.append(current_score)
        return scores

    def _parse_conclusion(self, response: str) -> Dict[str, Any]:
        """Parse final conclusion.

        Reads ``Answer:``, ``Confidence:``, ``Trace:``, ``Evidence:`` and
        ``Alternatives:`` from the ``[Conclusion]`` section and ``Insights:``
        from the ``[Meta]`` section. List fields are split on commas; a
        non-numeric confidence becomes 0.5. Missing fields keep their empty
        defaults.
        """
        conclusion: Dict[str, Any] = {
            "answer": "",
            "confidence": 0.0,
            "trace": [],
            "evidence": [],
            "alternatives": [],
            "meta_insights": []
        }
        conclusion_lists = {
            'Trace:': "trace",
            'Evidence:': "evidence",
            'Alternatives:': "alternatives",
        }
        section = None
        for raw in response.split('\n'):
            line = raw.strip()
            if not line:
                continue
            if line.startswith('[Conclusion]'):
                section = "conclusion"
            elif line.startswith('[Meta]'):
                section = "meta"
            elif section == "conclusion":
                if line.startswith('Answer:'):
                    conclusion["answer"] = line[len('Answer:'):].strip()
                elif line.startswith('Confidence:'):
                    try:
                        conclusion["confidence"] = float(line[len('Confidence:'):].strip())
                    except ValueError:
                        conclusion["confidence"] = 0.5
                else:
                    for label, key in conclusion_lists.items():
                        if line.startswith(label):
                            conclusion[key] = self._split_csv(line[len(label):])
                            break
            elif section == "meta":
                if line.startswith('Insights:'):
                    conclusion["meta_insights"].extend(
                        self._split_csv(line[len('Insights:'):])
                    )
        return conclusion

    def _node_to_dict(self, node: TreeNode) -> Dict[str, Any]:
        """Convert node to dictionary for serialization (no parent/children links)."""
        return {
            "id": node.id,
            "type": node.type.value,
            "content": node.content,
            "confidence": node.confidence,
            "evaluation_score": node.evaluation_score,
            "metadata": node.metadata,
            "depth": node.depth
        }

    def _tree_to_dict(self, root: TreeNode) -> Dict[str, Any]:
        """Convert entire tree to a nested dictionary with a ``children`` list per node."""
        def convert_node(node: TreeNode) -> Dict[str, Any]:
            node_dict = self._node_to_dict(node)
            node_dict["children"] = [convert_node(c) for c in node.children]
            return node_dict

        return convert_node(root)

    def _path_to_dict(self, path: List[TreeNode]) -> List[Dict[str, Any]]:
        """Convert path to a list of node dictionaries."""
        return [self._node_to_dict(n) for n in path]

    def _update_history(self, root: TreeNode) -> None:
        """Record every node of the tree in ``node_history``, keyed by id.

        An existing entry with the same id is replaced.
        """
        def add_to_history(node: TreeNode) -> None:
            self.node_history[node.id] = node
            for child in node.children:
                add_to_history(child)

        add_to_history(root)

    def _update_patterns(self, paths: List[List[TreeNode]]) -> None:
        """Add each path's final-node score to its type-sequence pattern."""
        for path in paths:
            if not path:
                continue  # nothing to key the pattern on
            pattern = "->".join(n.type.value for n in path)
            self.path_patterns[pattern] += path[-1].evaluation_score

    def get_node_history(self) -> Dict[str, Dict[str, Any]]:
        """Get history of all nodes as serializable dictionaries."""
        return {k: self._node_to_dict(v) for k, v in self.node_history.items()}

    def get_successful_patterns(self) -> Dict[str, float]:
        """Get reasoning patterns ordered by accumulated score, highest first."""
        return dict(sorted(self.path_patterns.items(), key=lambda x: x[1], reverse=True))

    def clear_history(self) -> None:
        """Clear node history and patterns."""
        self.node_history.clear()
        self.path_patterns.clear()