import hashlib

import torch


class ParameterMemoryBank:
    """
    Parameter Memory Bank (PMB) for effectively unbounded, queryable memory.

    This implementation uses a two-level hashing scheme for constant-time
    direct access and also supports semantic similarity search.

    - Level 1: a list of 'blocks'.
    - Level 2: each block is a fixed-size list of slots, indexed by a
      salted hash of the item ID.

    For simplicity, we use plain Python lists and dictionaries. A production
    system would use a more optimized backend (e.g., Redis or a custom
    memory store).
    """
    def __init__(self, num_blocks=1024, slots_per_block=4096, embedding_dim=None):
        self.num_blocks = num_blocks
        self.slots_per_block = slots_per_block
        self.embedding_dim = embedding_dim
        # The PMB is a list of blocks, where each block is a list of slots.
        # Each slot holds a tuple: (id, key_embedding, value).
        self.pmb = [[None] * slots_per_block for _ in range(num_blocks)]
        # For semantic search, we keep a flat copy of all keys. This trades
        # extra memory for an efficient batched similarity computation.
        self.all_keys = []
        self.key_locations = []  # (block_idx, slot_idx) for each key
        self.id_to_key_idx = {}  # item_id -> index into all_keys, for updates
    def _hash_fn(self, s, salt=""):
        """A simple, salted hash function based on SHA-256."""
        return int(hashlib.sha256((str(s) + salt).encode()).hexdigest(), 16)
    def _get_hash_indices(self, item_id):
        """
        Calculates the block and slot indices for a given item ID using
        the two-level hashing scheme.
        """
        block_hash = self._hash_fn(item_id, salt="block")
        block_idx = block_hash % self.num_blocks
        slot_hash = self._hash_fn(item_id, salt=f"slot_{block_idx}")
        slot_idx = slot_hash % self.slots_per_block
        return block_idx, slot_idx
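    # Illustrative property (a sketch, not part of the stored state): the
    # mapping is deterministic, so the same ID always resolves to the same
    # (block, slot) pair, while the per-block salt decorrelates the two levels:
    #
    #   pmb = ParameterMemoryBank(num_blocks=4, slots_per_block=8)
    #   assert pmb._get_hash_indices("doc_42") == pmb._get_hash_indices("doc_42")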
    def store(self, item_id, key_embedding, value):
        """
        Stores a key-value pair in the PMB using its ID.

        Args:
            item_id (str or int): A unique identifier for the data.
            key_embedding (torch.Tensor): The embedding vector (k_i,j).
            value (any): The data to store (v_i,j), e.g., text, metadata.
        """
        if not isinstance(key_embedding, torch.Tensor):
            raise TypeError("key_embedding must be a torch.Tensor")
        block_idx, slot_idx = self._get_hash_indices(item_id)
        key = key_embedding.detach().cpu()
        val = value.detach().cpu() if isinstance(value, torch.Tensor) else value
        # Note: this simple implementation does not resolve hash collisions
        # between *different* IDs; a production system would need a collision
        # resolution strategy (e.g., cuckoo hashing, chaining). Re-storing the
        # *same* ID, however, updates the existing entry in place.
        existing = self.pmb[block_idx][slot_idx]
        self.pmb[block_idx][slot_idx] = (item_id, key, val)
        if existing is not None and existing[0] == item_id:
            # Update: replace the old key rather than appending a duplicate.
            self.all_keys[self.id_to_key_idx[item_id]] = key
        else:
            self.id_to_key_idx[item_id] = len(self.all_keys)
            self.all_keys.append(key)
            self.key_locations.append((block_idx, slot_idx))
    def retrieve_direct(self, item_id):
        """
        Retrieves a value directly by its ID in O(1) time.

        Args:
            item_id (str or int): The unique identifier of the item.

        Returns:
            The stored value, or None if not found.
        """
        block_idx, slot_idx = self._get_hash_indices(item_id)
        item = self.pmb[block_idx][slot_idx]
        # Verify the stored ID matches, since a colliding ID may have
        # overwritten this slot.
        if item and item[0] == item_id:
            return item[2]  # Return the value
        return None
    def retrieve_by_indices(self, indices):
        """
        Retrieves items by their indices into the `all_keys` list.

        Args:
            indices (list or torch.Tensor): A list of indices.

        Returns:
            A list of retrieved values as tensors. Missing or non-tensor
            entries are padded with zero vectors, so the caller always gets
            len(indices) results.
        """
        def _zero():
            # Zero vector of the best-known size, used as padding.
            if self.embedding_dim:
                return torch.zeros(self.embedding_dim)
            if self.all_keys:
                return torch.zeros_like(self.all_keys[0])
            return torch.zeros(1)

        results = []
        for idx in indices:
            if idx >= len(self.key_locations):
                results.append(_zero())  # Index out of range
                continue
            block_idx, slot_idx = self.key_locations[idx]
            item = self.pmb[block_idx][slot_idx]
            if item is None:
                results.append(_zero())  # Slot empty or overwritten
                continue
            value = item[2]
            if isinstance(value, torch.Tensor):
                results.append(value)
            elif self.embedding_dim:
                # Non-tensor value: pad with zeros of the configured size.
                results.append(torch.zeros(self.embedding_dim))
            else:
                # Fallback: use the key embedding in place of the value.
                results.append(item[1])
        return results
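    # Illustrative behavior: out-of-range indices and emptied slots are padded
    # with zero vectors, so the caller can always stack the returned list.
    # Example under an assumed embedding_dim of 4:
    #
    #   pmb = ParameterMemoryBank(embedding_dim=4)
    #   pmb.retrieve_by_indices([999])  # -> [tensor([0., 0., 0., 0.])]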
    def retrieve_semantic(self, query_embeddings, top_k=1):
        """
        Retrieves the top_k most semantically similar items for a batch of
        query embeddings.

        Args:
            query_embeddings (torch.Tensor): Query vectors of shape
                (batch_size, embedding_dim) or (batch_size, seq_len, embedding_dim).
            top_k (int): The number of similar items to retrieve per query.

        Returns:
            A tensor of the aggregated retrieved values with the same shape
            as query_embeddings.
        """
        if not isinstance(query_embeddings, torch.Tensor):
            raise TypeError("query_embeddings must be a torch.Tensor")
        if not self.all_keys or top_k == 0:
            return torch.zeros_like(query_embeddings)
        original_shape = query_embeddings.shape
        device = query_embeddings.device
        # Flatten (batch, seq, dim) queries to 2D for the similarity matmul.
        if query_embeddings.dim() > 2:
            query_flat = query_embeddings.reshape(-1, original_shape[-1])
        else:
            query_flat = query_embeddings
        try:
            all_keys_tensor = torch.stack(self.all_keys, dim=0).to(device)
            # Cosine similarity: normalize both sides, then take the dot
            # product, giving a (batch_size, num_keys) score matrix.
            query_norm = torch.nn.functional.normalize(query_flat, p=2, dim=-1)
            keys_norm = torch.nn.functional.normalize(all_keys_tensor, p=2, dim=-1)
            similarities = torch.mm(query_norm, keys_norm.T)
            # top_k may exceed the number of stored keys, so clamp it.
            k = min(top_k, len(self.all_keys))
            _, top_k_indices = torch.topk(similarities, k=k, dim=1)
            batch_results = []
            for i in range(query_flat.size(0)):
                retrieved = self.retrieve_by_indices(top_k_indices[i].cpu().tolist())
                if retrieved:
                    # Average the top_k retrieved values into a single vector.
                    stacked = torch.stack(retrieved, dim=0).to(device)
                    batch_results.append(torch.mean(stacked, dim=0))
                else:
                    # No valid retrievals: fall back to a zero vector.
                    batch_results.append(torch.zeros(original_shape[-1], device=device))
            if not batch_results:
                return torch.zeros_like(query_embeddings)
            # Reshape back to the original (possibly 3D) query shape.
            return torch.stack(batch_results, dim=0).view(original_shape)
        except Exception as e:
            print(f"Error in PMB retrieve_semantic: {e}")
            return torch.zeros_like(query_embeddings)
    def __len__(self):
        return len(self.all_keys)
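

if __name__ == "__main__":
    # Minimal usage sketch. The dimensions, IDs, and random data below are
    # illustrative assumptions, not part of the PMB specification.
    torch.manual_seed(0)
    bank = ParameterMemoryBank(num_blocks=64, slots_per_block=64, embedding_dim=8)

    # Store a few items keyed by random embeddings.
    for i in range(4):
        bank.store(f"item_{i}", torch.randn(8), torch.randn(8))

    # O(1) direct lookup by ID.
    v = bank.retrieve_direct("item_2")
    print(None if v is None else v.shape)  # torch.Size([8]), absent a collision
    print(bank.retrieve_direct("missing"))  # None

    # Semantic lookup: a batch of 2 queries, averaging the top-2 matches each.
    queries = torch.randn(2, 8)
    out = bank.retrieve_semantic(queries, top_k=2)
    print(out.shape)  # torch.Size([2, 8])
    print(len(bank))  # 4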