# Vector-database helpers for the restaurant reservation assistant.
import chromadb
import sqlite3
import hashlib
import pandas as pd
from sentence_transformers import SentenceTransformer
# --- Natural-language descriptions of the SQLite schema, used for retrieval ---
# Maps SQLite table name -> human-readable column description.
# SchemaVectorDB embeds these values so a user query can be matched to the
# table(s) it most likely concerns; keys must match the real table names.
SCHEMA_DESCRIPTIONS = {
"restaurants": """Table restaurants contains restaurant details:
- id: unique identifier
- name: restaurant name
- cuisine: type of cuisine
- location: area or neighborhood
- seating_capacity: total seats
- rating: average rating
- address: full address
- contact: phone or email
- price_range: price category
- special_features: amenities or highlights""",
"tables": """Table tables contains table details:
- id: unique identifier
- restaurant_id: links to restaurants.id
- capacity: number of seats (default 4)""",
"slots": """Table slots contains reservation time slots:
- id: unique identifier
- table_id: links to tables.id
- date: reservation date
- hour: reservation hour
- is_reserved: 0=available, 1=booked"""
}
class SchemaVectorDB:
    """In-memory vector index over SCHEMA_DESCRIPTIONS.

    Embeds each schema description once at construction time so that a
    natural-language query can be mapped to the most relevant table names.
    """

    def __init__(self):
        # Ephemeral client: the index is rebuilt on every process start.
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection("schema")
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        names = list(SCHEMA_DESCRIPTIONS.keys())
        docs = list(SCHEMA_DESCRIPTIONS.values())
        # Embed with the same model used at query time; letting Chroma's
        # default embedding function embed the documents (as the old code
        # did by omitting `embeddings`) makes add and query inconsistent.
        embeddings = self.model.encode(docs).tolist()
        # upsert (not add) so re-initialization against an already-populated
        # collection is idempotent; the API takes parallel lists.
        self.collection.upsert(
            ids=[str(i) for i in range(len(names))],
            documents=docs,
            embeddings=embeddings,
            metadatas=[{"name": n} for n in names],
        )

    def get_relevant_schema(self, query, k=2):
        """Return the names of the k schema tables most relevant to `query`."""
        query_embedding = self.model.encode(query).tolist()
        results = self.collection.query(query_embeddings=[query_embedding], n_results=k)
        # results['metadatas'] is a list of lists, one inner list per query;
        # we issue a single query, so take the first.
        metadatas = results['metadatas'][0] if results['metadatas'] else []
        return [m['name'] for m in metadatas if m and 'name' in m]
class FullVectorDB:
    """Persistent vector index mirroring the SQLite reservation database.

    Each SQLite row of `restaurants`, `tables` and `slots` is embedded and
    stored in a per-table Chroma collection.  A `chroma_changelog` table in
    SQLite keeps a content hash per record so unchanged rows are not
    re-embedded on subsequent runs.
    """

    def __init__(self, chroma_path="db/chroma",
                 sqlite_path="db/restaurant_reservation.db"):
        # Paths are parameters (with the original hard-coded values as
        # defaults) so tests/deployments can point elsewhere.
        self.sqlite_path = sqlite_path
        self.client = chromadb.PersistentClient(path=chroma_path)
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        # Get existing collections, creating them on first run.
        self.restaurants_col = self.client.get_or_create_collection("restaurants")
        self.tables_col = self.client.get_or_create_collection("tables")
        self.slots_col = self.client.get_or_create_collection("slots")
        # Populate only when the persistent store is empty.
        if len(self.restaurants_col.get()['ids']) == 0:
            self._initialize_collections()

    def _row_to_text(self, row):
        """Concatenate a row's non-null values into one text document."""
        return ' '.join(str(v) for v in row.values if pd.notnull(v))

    def _row_hash(self, row):
        """Stable content hash of a row, used for change detection.

        NOTE: kept byte-identical to the original implementation so hashes
        already stored in chroma_changelog remain valid.
        """
        return hashlib.sha256(str(row.values).encode()).hexdigest()

    def _initialize_collections(self):
        """Sync every SQLite table into its corresponding Chroma collection."""
        conn = sqlite3.connect(self.sqlite_path)
        try:
            # External changelog: (table, record) -> hash of the content
            # that was last embedded into Chroma.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS chroma_changelog (
                    id INTEGER PRIMARY KEY,
                    table_name TEXT,
                    record_id INTEGER,
                    content_hash TEXT,
                    UNIQUE(table_name, record_id)
                )
            """)
            conn.commit()
            self._process_table(conn, "restaurants", self.restaurants_col)
            self._process_table(conn, "tables", self.tables_col)
            self._process_table(conn, "slots", self.slots_col)
        finally:
            # Close the connection even if embedding fails part-way through.
            conn.close()

    def _process_table(self, conn, table_name, collection):
        """Upsert new or changed rows of `table_name` into `collection`."""
        df = pd.read_sql(f"SELECT * FROM {table_name}", conn)
        # Load this table's whole changelog once instead of issuing one
        # SELECT per row (the original N+1 query pattern).
        known_hashes = dict(conn.execute(
            "SELECT record_id, content_hash FROM chroma_changelog"
            " WHERE table_name = ?",
            (table_name,),
        ).fetchall())
        for _, row in df.iterrows():
            record_id = row['id']
            current_hash = self._row_hash(row)
            # Skip rows whose content has not changed since the last sync.
            if known_hashes.get(record_id) == current_hash:
                continue
            embedding = self.model.encode(self._row_to_text(row))
            collection.upsert(
                ids=[f"{table_name}_{record_id}"],
                embeddings=[embedding.tolist()],
                metadatas=[row.to_dict()],
            )
            # Record the hash; commit per row so the changelog never gets
            # ahead of (or behind) what was actually upserted into Chroma.
            conn.execute("""
                INSERT OR REPLACE INTO chroma_changelog
                (table_name, record_id, content_hash)
                VALUES (?, ?, ?)
            """, (table_name, record_id, current_hash))
            conn.commit()

    def semantic_search(self, query, collection_name, k=5):
        """Return metadata dicts of the k nearest rows in `collection_name`.

        `collection_name` must be one of 'restaurants', 'tables', 'slots'.
        """
        query_embedding = self.model.encode(query).tolist()
        collection = getattr(self, f"{collection_name}_col")
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=k,
            include=["metadatas"],
        )
        # Single query -> first (and only) inner result list.
        return results['metadatas'][0]