|
import { getRepositoryFactory } from '../db/index.js'; |
|
import { VectorEmbeddingRepository } from '../db/repositories/index.js'; |
|
import { ToolInfo } from '../types/index.js'; |
|
import { getAppDataSource, initializeDatabase } from '../db/connection.js'; |
|
import { getSmartRoutingConfig } from '../utils/smartRouting.js'; |
|
import OpenAI from 'openai'; |
|
|
|
|
|
/**
 * Projects the smart-routing settings onto the three values this module needs
 * for OpenAI calls: API key, base URL and embedding model name.
 *
 * The configuration is re-read on every call rather than cached.
 */
const getOpenAIConfig = () => {
  const { openaiApiKey, openaiApiBaseUrl, openaiApiEmbeddingModel } = getSmartRoutingConfig();

  return {
    apiKey: openaiApiKey,
    baseURL: openaiApiBaseUrl,
    embeddingModel: openaiApiEmbeddingModel,
  };
};
|
|
|
|
|
/** Vector width returned for `text-embedding-3*` models and for any unrecognised model. */
const EMBEDDING_DIMENSIONS = 1536;

/** Vector width returned for model names containing `bge-m3`. */
const BGE_DIMENSIONS = 1024;

/** Vector width of the local hash-based fallback embedding. */
const FALLBACK_DIMENSIONS = 100;

/**
 * Maps an embedding model name to the vector width this module expects for it.
 *
 * @param model - Model identifier as configured (matched by substring, or by
 *   exact name for the fallback models).
 * @returns The number of dimensions; unknown models get `EMBEDDING_DIMENSIONS`.
 */
const getDimensionsForModel = (model: string): number => {
  // Order matters: the bge-m3 check wins over every later rule.
  if (model.includes('bge-m3')) {
    return BGE_DIMENSIONS;
  }

  if (model.includes('text-embedding-3')) {
    return EMBEDDING_DIMENSIONS;
  }

  const isLocalFallback = model === 'fallback' || model === 'simple-hash';
  return isLocalFallback ? FALLBACK_DIMENSIONS : EMBEDDING_DIMENSIONS;
};
|
|
|
|
|
/**
 * Builds a fresh OpenAI SDK client from the current smart-routing settings.
 *
 * A new client is constructed on every call, so configuration changes are
 * picked up without restarting.
 */
const getOpenAIClient = () => {
  const { apiKey, baseURL } = getOpenAIConfig();

  return new OpenAI({ apiKey, baseURL });
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Produces an embedding vector for `text`.
 *
 * Uses the configured OpenAI embedding model when a client with an API key is
 * available. If the key is missing, or anything in the remote path throws, the
 * local `generateFallbackEmbedding` result is returned instead, so this
 * function does not reject.
 *
 * @param text - Text to embed; only its first 8000 characters are sent to the API.
 * @returns The embedding values (remote model output or local fallback vector).
 */
async function generateEmbedding(text: string): Promise<number[]> {
  try {
    const { embeddingModel } = getOpenAIConfig();
    const client = getOpenAIClient();

    if (!client.apiKey) {
      console.warn('OpenAI API key is not configured. Using fallback embedding method.');
      return generateFallbackEmbedding(text);
    }

    // slice() is a no-op for shorter strings, so this caps the input at 8000 chars.
    const input = text.slice(0, 8000);

    const { data } = await client.embeddings.create({
      model: embeddingModel,
      input,
    });

    return data[0].embedding;
  } catch (error) {
    console.error('Error generating embedding:', error);
    console.warn('Falling back to simple embedding method');
    return generateFallbackEmbedding(text);
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Keyword list used by the local fallback embedding. A word's position in this
 * list is the vector slot that counts its occurrences.
 */
const FALLBACK_VOCABULARY: readonly string[] = [
  'search',
  'find',
  'get',
  'fetch',
  'retrieve',
  'query',
  'map',
  'location',
  'weather',
  'file',
  'directory',
  'email',
  'message',
  'send',
  'create',
  'update',
  'delete',
  'browser',
  'web',
  'page',
  'click',
  'navigate',
  'screenshot',
  'automation',
  'database',
  'table',
  'record',
  'insert',
  'select',
  'schema',
  'data',
  'image',
  'photo',
  'video',
  'media',
  'upload',
  'download',
  'convert',
  'text',
  'document',
  'pdf',
  'excel',
  'word',
  'format',
  'parse',
  'api',
  'rest',
  'http',
  'request',
  'response',
  'json',
  'xml',
  'time',
  'date',
  'calendar',
  'schedule',
  'reminder',
  'clock',
  'math',
  'calculate',
  'number',
  'sum',
  'average',
  'statistics',
  'user',
  'account',
  'login',
  'auth',
  'permission',
  'role',
];

/** Word -> slot lookup, built once instead of scanning the list for every token. */
const FALLBACK_VOCABULARY_INDEX = new Map<string, number>(
  FALLBACK_VOCABULARY.map((word, index): [string, number] => [word, index]),
);

/**
 * Builds a deterministic bag-of-words style vector without calling any
 * external service.
 *
 * Each whitespace-separated, lower-cased token adds 1 to its vocabulary slot
 * (when it is a known keyword) and 0.1 to a slot chosen by the sum of its
 * UTF-16 code units modulo the vector length. The result is L2-normalised.
 *
 * Empty tokens (from empty input or leading/trailing whitespace) are skipped,
 * so blank text yields an all-zero vector rather than a spurious hit on slot 0.
 *
 * @param text - Text to embed.
 * @returns A vector of `FALLBACK_DIMENSIONS` numbers; all zeros when `text`
 *   contains no tokens.
 */
function generateFallbackEmbedding(text: string): number[] {
  const vector = new Array<number>(FALLBACK_DIMENSIONS).fill(0);

  for (const word of text.toLowerCase().split(/\s+/)) {
    if (word === '') {
      continue;
    }

    const slot = FALLBACK_VOCABULARY_INDEX.get(word);
    if (slot !== undefined && slot < vector.length) {
      vector[slot] += 1;
    }

    // Same hash as before (sum of code units) so stored fallback vectors stay comparable.
    let hash = 0;
    for (let i = 0; i < word.length; i++) {
      hash += word.charCodeAt(i);
    }
    vector[hash % vector.length] += 0.1;
  }

  const magnitude = Math.sqrt(vector.reduce((sum, value) => sum + value * value, 0));
  if (magnitude === 0) {
    return vector;
  }

  return vector.map((value) => value / magnitude);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Embeds every tool of a server and stores the vectors through the
 * `vectorEmbeddings` repository.
 *
 * For each tool the searchable text is the tool name, its description, the
 * top-level keys of `inputSchema` (except `type` and `properties`) and the
 * names of its schema properties, joined by spaces. A failure on one tool is
 * logged and does not stop the remaining tools; this function itself never
 * rejects.
 *
 * @param serverName - Server the tools belong to; also the prefix of the stored content id.
 * @param tools - Tools to embed and persist.
 */
export const saveToolsAsVectorEmbeddings = async (
  serverName: string,
  tools: ToolInfo[],
): Promise<void> => {
  try {
    const config = getOpenAIConfig();
    const vectorRepository = getRepositoryFactory(
      'vectorEmbeddings',
    )() as VectorEmbeddingRepository;

    // Count real successes so the summary log does not claim tools that failed.
    let savedCount = 0;

    for (const tool of tools) {
      const schema = tool.inputSchema;

      const schemaKeys =
        schema && typeof schema === 'object'
          ? Object.keys(schema).filter((key) => key !== 'type' && key !== 'properties')
          : [];

      const propertyNames =
        schema && schema.properties && typeof schema.properties === 'object'
          ? Object.keys(schema.properties)
          : [];

      const searchableText = [tool.name, tool.description, ...schemaKeys, ...propertyNames]
        .filter(Boolean)
        .join(' ');

      try {
        const embedding = await generateEmbedding(searchableText);

        // The column width must match this vector before it can be inserted.
        await checkDatabaseVectorDimensions(embedding.length);

        // NOTE(review): generateEmbedding can silently return the local fallback
        // vector; the row is still labelled with config.embeddingModel in that case.
        await vectorRepository.saveEmbedding(
          'tool',
          `${serverName}:${tool.name}`,
          searchableText,
          embedding,
          {
            serverName,
            toolName: tool.name,
            description: tool.description,
            inputSchema: tool.inputSchema,
          },
          config.embeddingModel,
        );

        savedCount += 1;
      } catch (toolError) {
        console.error(`Error processing tool ${tool.name} for server ${serverName}:`, toolError);
      }
    }

    console.log(`Saved ${savedCount} tool embeddings for server: ${serverName}`);

    const failedCount = tools.length - savedCount;
    if (failedCount > 0) {
      console.warn(
        `Failed to save ${failedCount} of ${tools.length} tool embeddings for server: ${serverName}`,
      );
    }
  } catch (error) {
    console.error(`Error saving tool embeddings for server ${serverName}:`, error);
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Normalises the metadata attached to a search hit into a plain object.
 *
 * Accepts either a JSON string or an already-deserialized object (which form
 * the repository returns is not visible from this module). Returns `null` for
 * anything else, including JSON that does not decode to an object.
 */
const parseSearchResultMetadata = (metadata: unknown): Record<string, unknown> | null => {
  let candidate: unknown = metadata;

  if (typeof candidate === 'string') {
    try {
      candidate = JSON.parse(candidate);
    } catch (error) {
      console.error('Error parsing metadata string:', error);
      return null;
    }
  }

  if (candidate !== null && typeof candidate === 'object' && !Array.isArray(candidate)) {
    return candidate as Record<string, unknown>;
  }

  return null;
};

/**
 * Finds stored tools whose embedding is similar to `query`.
 *
 * The query is embedded with `generateEmbedding` by the repository's
 * `searchByText`, restricted to content type `tool`. When `serverNames` is
 * non-empty, hits whose metadata does not name one of those servers are
 * dropped. The filter runs on the rows the repository returned, so presumably
 * fewer than `limit` results can come back when it is used.
 *
 * Hits without usable `serverName`/`toolName` metadata fall back to parsing the
 * stored text: first token = tool name, text before its first `_` = server name
 * (or `unknown`), remainder = description.
 *
 * @param query - Free-text search query.
 * @param limit - Maximum number of rows requested from the repository.
 * @param threshold - Similarity threshold passed to the repository.
 * @param serverNames - Optional allow-list of server names.
 * @returns Matching tools with similarity scores; an empty array on any error.
 */
export const searchToolsByVector = async (
  query: string,
  limit: number = 10,
  threshold: number = 0.7,
  serverNames?: string[],
): Promise<
  Array<{
    serverName: string;
    toolName: string;
    description: string;
    inputSchema: any;
    similarity: number;
    searchableText: string;
  }>
> => {
  try {
    const vectorRepository = getRepositoryFactory(
      'vectorEmbeddings',
    )() as VectorEmbeddingRepository;

    const results = await vectorRepository.searchByText(
      query,
      generateEmbedding,
      limit,
      threshold,
      ['tool'],
    );

    const allowedServers = serverNames && serverNames.length > 0 ? serverNames : null;

    return results
      // Parse each row's metadata once and carry it through filter and mapping.
      .map((result) => ({
        result,
        metadata: parseSearchResultMetadata(result.embedding?.metadata),
      }))
      .filter(({ metadata }) => {
        if (!allowedServers) {
          return true;
        }
        const owner = metadata?.serverName;
        return typeof owner === 'string' && allowedServers.includes(owner);
      })
      .map(({ result, metadata }) => {
        const storedServer = metadata?.serverName;
        const storedTool = metadata?.toolName;

        if (
          typeof storedServer === 'string' &&
          storedServer !== '' &&
          typeof storedTool === 'string' &&
          storedTool !== ''
        ) {
          const storedDescription = metadata?.description;
          return {
            serverName: storedServer,
            toolName: storedTool,
            description: typeof storedDescription === 'string' ? storedDescription : '',
            inputSchema: metadata?.inputSchema || {},
            similarity: result.similarity,
            searchableText: result.embedding.text_content,
          };
        }

        // Metadata is missing or incomplete: reconstruct what we can from the stored text.
        const textContent: string = result.embedding?.text_content || '';

        const toolNameMatch = textContent.match(/^(\S+)/);
        const toolName = toolNameMatch ? toolNameMatch[1] : '';

        const serverNameMatch = toolName.match(/^([^_]+)_/);
        const serverName = serverNameMatch ? serverNameMatch[1] : 'unknown';

        const description = textContent.replace(/^\S+\s*/, '').trim();

        return {
          serverName,
          toolName,
          description,
          inputSchema: {},
          similarity: result.similarity,
          searchableText: textContent,
        };
      });
  } catch (error) {
    console.error('Error searching tools by vector:', error);
    return [];
  }
};
|
|
|
|
|
|
|
|
|
|
|
/**
 * Normalises the metadata of a stored tool embedding into a plain object.
 *
 * Accepts a JSON string or an already-deserialized object; returns `null` for
 * anything else, including JSON that does not decode to an object.
 */
const parseVectorizedToolMetadata = (metadata: unknown): Record<string, unknown> | null => {
  let candidate: unknown = metadata;

  if (typeof candidate === 'string') {
    try {
      candidate = JSON.parse(candidate);
    } catch (error) {
      console.error('Error parsing metadata string:', error);
      return null;
    }
  }

  if (candidate !== null && typeof candidate === 'object' && !Array.isArray(candidate)) {
    return candidate as Record<string, unknown>;
  }

  return null;
};

/**
 * Lists the tools that currently have a stored embedding.
 *
 * Implemented as a similarity search with an all-zero probe vector, a limit of
 * 1000 and a threshold of -1 — presumably so that every `tool` row qualifies
 * (TODO confirm against the repository's `searchSimilar`). The probe width is
 * taken from the `embedding` column's type modifier when the database reports
 * a positive one, otherwise from the configured model.
 *
 * @param serverNames - Optional allow-list; when non-empty, rows whose metadata
 *   does not name one of these servers are dropped.
 * @returns One entry per stored tool. Rows with unreadable metadata are
 *   reported with `unknown` names (unless filtered out); an empty array is
 *   returned on any error.
 */
export const getAllVectorizedTools = async (
  serverNames?: string[],
): Promise<
  Array<{
    serverName: string;
    toolName: string;
    description: string;
    inputSchema: any;
  }>
> => {
  try {
    const config = getOpenAIConfig();
    const vectorRepository = getRepositoryFactory(
      'vectorEmbeddings',
    )() as VectorEmbeddingRepository;

    // Default to the width implied by the configured model.
    let dimensionsToUse = getDimensionsForModel(config.embeddingModel);

    try {
      const rows = await getAppDataSource().query(`
        SELECT atttypmod as dimensions
        FROM pg_attribute
        WHERE attrelid = 'vector_embeddings'::regclass
        AND attname = 'embedding'
      `);

      // Only a positive integer is a usable width; -1 (no modifier) or an
      // unexpected value keeps the model-based default.
      const reported = Number(rows?.[0]?.dimensions);
      if (Number.isInteger(reported) && reported > 0) {
        dimensionsToUse = reported;
      }
    } catch (error: unknown) {
      const reason = (error as { message?: string } | null | undefined)?.message;
      console.warn('Could not determine vector dimensions from database:', reason);
    }

    const results = await vectorRepository.searchSimilar(
      new Array(dimensionsToUse).fill(0),
      1000,
      -1,
      ['tool'],
    );

    const allowedServers = serverNames && serverNames.length > 0 ? serverNames : null;

    return results
      // Parse each row's metadata once and reuse it for filtering and mapping.
      .map((result) => parseVectorizedToolMetadata(result.embedding?.metadata))
      .filter((metadata) => {
        if (!allowedServers) {
          return true;
        }
        const owner = metadata?.serverName;
        return typeof owner === 'string' && allowedServers.includes(owner);
      })
      .map((metadata) => {
        if (!metadata) {
          return {
            serverName: 'unknown',
            toolName: 'unknown',
            description: '',
            inputSchema: {},
          };
        }

        const { serverName, toolName, description, inputSchema } = metadata;
        return {
          serverName: typeof serverName === 'string' ? serverName : 'unknown',
          toolName: typeof toolName === 'string' ? toolName : 'unknown',
          description: typeof description === 'string' ? description : '',
          inputSchema: inputSchema || {},
        };
      });
  } catch (error) {
    console.error('Error getting all vectorized tools:', error);
    return [];
  }
};
|
|
|
|
|
|
|
|
|
|
|
/**
 * Placeholder for deleting all tool embeddings of a server.
 *
 * Currently this only resolves the `vectorEmbeddings` repository and logs a
 * TODO line; no rows are removed. Errors are logged, never thrown.
 *
 * @param serverName - Server whose tool embeddings should eventually be removed.
 */
export const removeServerToolEmbeddings = async (serverName: string): Promise<void> => {
  try {
    const repository = getRepositoryFactory('vectorEmbeddings')() as VectorEmbeddingRepository;
    // Resolved for the future implementation; intentionally unused for now.
    void repository;

    console.log(`TODO: Remove tool embeddings for server: ${serverName}`);
  } catch (error) {
    console.error(`Error removing tool embeddings for server ${serverName}:`, error);
  }
};
|
|
|
|
|
|
|
|
|
|
|
/**
 * Re-embeds the tools of every connected server.
 *
 * Servers that are not `connected`, or that are connected but expose no tools,
 * are skipped with a log line. A failure while syncing one server is logged
 * and the loop continues; only failures outside the per-server step (for
 * example loading `mcpService`) are rethrown.
 *
 * @throws Whatever error occurred outside the per-server sync, after logging it.
 */
export const syncAllServerToolsEmbeddings = async (): Promise<void> => {
  try {
    console.log('Starting synchronization of all server tools embeddings...');

    // Loaded lazily at call time — presumably to avoid a circular import; confirm.
    const mcpService = await import('./mcpService.js');
    const servers = mcpService.getServersInfo();

    let totalToolsSynced = 0;
    let serversSynced = 0;

    for (const server of servers) {
      if (server.status !== 'connected') {
        console.log(`Skipping server ${server.name} (status: ${server.status})`);
        continue;
      }

      if (!server.tools || server.tools.length === 0) {
        console.log(`Server ${server.name} is connected but has no tools to sync`);
        continue;
      }

      try {
        console.log(`Syncing tools for server: ${server.name} (${server.tools.length} tools)`);
        await saveToolsAsVectorEmbeddings(server.name, server.tools);
        totalToolsSynced += server.tools.length;
        serversSynced += 1;
      } catch (error) {
        console.error(`Failed to sync tools for server ${server.name}:`, error);
      }
    }

    console.log(
      `Smart routing tools sync completed: synced ${totalToolsSynced} tools from ${serversSynced} servers`,
    );
  } catch (error) {
    console.error('Error during smart routing tools synchronization:', error);
    throw error;
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Reads the width declared on `vector_embeddings.embedding` from the Postgres
 * catalog. Returns 0 when the column has no explicit width.
 */
async function readDeclaredVectorDimensions(): Promise<number> {
  const rows = await getAppDataSource().query(`
    SELECT
      atttypmod,
      format_type(atttypid, atttypmod) as formatted_type
    FROM pg_attribute
    WHERE attrelid = 'vector_embeddings'::regclass
    AND attname = 'embedding'
  `);

  const column = rows?.[0];
  if (!column) {
    return 0;
  }

  // Prefer the human-readable type, e.g. "vector(1536)".
  const declared = /vector\((\d+)\)/.exec(String(column.formatted_type ?? ''));
  if (declared) {
    return Number.parseInt(declared[1], 10);
  }

  // Otherwise use the raw type modifier; -1 means "no width declared".
  const typmod = Number(column.atttypmod);
  return Number.isInteger(typmod) && typmod > 0 ? typmod : 0;
}

/**
 * Looks at stored rows and returns the `dimensions` value of the most common
 * (dimensions, model) group, or 0 when there are no rows or the query fails.
 */
async function readMostCommonStoredDimensions(): Promise<number> {
  try {
    const groups = await getAppDataSource().query(`
      SELECT dimensions, model, COUNT(*) as count
      FROM vector_embeddings
      GROUP BY dimensions, model
      ORDER BY count DESC
      LIMIT 5
    `);

    const stored = Number(groups?.[0]?.dimensions);
    return Number.isInteger(stored) && stored > 0 ? stored : 0;
  } catch (error) {
    console.warn('Could not check dimensions from actual records:', error);
    return 0;
  }
}

/**
 * Makes sure the `embedding` column can hold vectors of `dimensionsNeeded`
 * entries, resizing it when necessary.
 *
 * The current width is taken from the column definition, or, when the column
 * has no declared width, from the rows already stored. If it differs from
 * `dimensionsNeeded`: rows of other widths are cleared (skipped on first-time
 * setup), the ivfflat index is dropped, the column type is altered and the
 * index is recreated. Failure to recreate the index is logged but not fatal.
 *
 * @param dimensionsNeeded - Length of the vectors about to be stored; must be a
 *   positive integer because it is interpolated into DDL.
 * @throws Error prefixed with "Vector dimension check failed:" on any failure.
 */
async function checkDatabaseVectorDimensions(dimensionsNeeded: number): Promise<void> {
  try {
    // DDL cannot take bind parameters, so validate before interpolating below.
    if (!Number.isInteger(dimensionsNeeded) || dimensionsNeeded <= 0) {
      throw new Error(`Invalid vector dimensions requested: ${dimensionsNeeded}`);
    }

    if (!getAppDataSource().isInitialized) {
      console.info('Database not initialized, initializing...');
      await initializeDatabase();
    }

    let currentDimensions = await readDeclaredVectorDimensions();
    if (currentDimensions === 0) {
      currentDimensions = await readMostCommonStoredDimensions();
    }

    if (currentDimensions === dimensionsNeeded) {
      return;
    }

    console.log(
      `Vector dimensions mismatch: database=${currentDimensions}, needed=${dimensionsNeeded}`,
    );

    if (currentDimensions === 0) {
      console.log('Setting up vector dimensions for the first time...');
    } else {
      console.log('Dimension mismatch detected. Clearing existing incompatible vector data...');
      await clearMismatchedVectorData(dimensionsNeeded);
    }

    // The index is dropped before the column type changes and rebuilt afterwards.
    await getAppDataSource().query(`DROP INDEX IF EXISTS idx_vector_embeddings_embedding;`);

    await getAppDataSource().query(`
      ALTER TABLE vector_embeddings
      ALTER COLUMN embedding TYPE vector(${dimensionsNeeded});
    `);

    try {
      await getAppDataSource().query(`
        CREATE INDEX idx_vector_embeddings_embedding
        ON vector_embeddings USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
      `);
    } catch (indexError: unknown) {
      const { code, message } = (indexError ?? {}) as { code?: string; message?: string };

      // 42P07 = duplicate_table (relation exists), 23505 = unique_violation.
      if (code === '42P07' || code === '23505') {
        console.log('Index already exists, continuing...');
      } else {
        console.warn('Warning: Failed to create index, but continuing:', message);
      }
    }

    console.log(`Successfully configured vector dimensions to ${dimensionsNeeded}`);
  } catch (error: unknown) {
    console.error('Error checking/updating vector dimensions:', error);
    const reason = (error as { message?: string } | null | undefined)?.message;
    throw new Error(`Vector dimension check failed: ${reason || 'Unknown error'}`);
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Deletes every stored embedding whose recorded `dimensions` differs from
 * `expectedDimensions`.
 *
 * `IS DISTINCT FROM` is used instead of `!=` so that rows with a NULL
 * `dimensions` value are removed as well; with `!=` the comparison yields NULL
 * for those rows and they would be kept.
 *
 * @param expectedDimensions - The only width that is allowed to remain.
 * @throws Rethrows any database error after logging it.
 */
async function clearMismatchedVectorData(expectedDimensions: number): Promise<void> {
  try {
    console.log(
      `Clearing vector embeddings with dimensions different from ${expectedDimensions}...`,
    );

    await getAppDataSource().query(
      `
      DELETE FROM vector_embeddings
      WHERE dimensions IS DISTINCT FROM $1
      `,
      [expectedDimensions],
    );

    console.log('Successfully cleared mismatched vector embeddings');
  } catch (error: unknown) {
    console.error('Error clearing mismatched vector data:', error);
    throw error;
  }
}
|
|