File size: 2,649 Bytes
8066b54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import List, Tuple
import numpy as np
from openai import AsyncOpenAI
import os

def cosine_similarity(a, b):
    """Calculate cosine similarity between two vectors.

    Args:
        a: First vector, as a sequence or numpy array of numbers.
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)``. If either vector has zero norm, returns
        ``0.0`` rather than the ``nan`` (and divide-by-zero RuntimeWarning) the
        plain formula would yield. A ``nan`` score would make any sort by
        similarity order results arbitrarily.
    """
    # asarray avoids copying inputs that are already arrays.
    a = np.asarray(a)
    b = np.asarray(b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return np.dot(a, b) / denom

class VectorDatabase:
    """In-memory collection of texts and their OpenAI embeddings, searchable by cosine similarity.

    Invariant: ``self.texts`` and ``self.embeddings`` are index-aligned.
    ``self.embeddings[i]`` is the embedding of ``self.texts[i]``, or ``None``
    when no embedding was produced for that text. That happens for blank
    texts and for empty API responses.
    """

    def __init__(self, embedding_model: str = "text-embedding-ada-002"):
        """Create an empty database.

        Args:
            embedding_model: OpenAI embedding model used for both stored texts
                and queries. Defaults to ``"text-embedding-ada-002"``.
        """
        self.embeddings = []
        self.texts = []
        self.embedding_model = embedding_model
        self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    async def _embed(self, text: str):
        """Request one embedding for ``text``; return the vector, or ``None`` if the response has no data."""
        response = await self.client.embeddings.create(
            model=self.embedding_model,
            input=text.replace("\n", " "),  # Replace newlines with spaces
        )
        if response and response.data:
            return response.data[0].embedding
        return None

    async def abuild_from_list(self, texts: List[str]):
        """Embed every text in ``texts`` and replace the database contents with them.

        Texts are embedded one request at a time, in order. Blank (whitespace-only)
        texts are not sent to the API. Blank texts, and texts whose response
        carried no data, get a ``None`` placeholder in ``self.embeddings``.
        The placeholder keeps it index-aligned with ``self.texts``.

        The database is only updated after every text has been processed, so an
        exception leaves the previous contents untouched.

        Args:
            texts: Iterable of strings to index. It is copied into a list.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            Whatever the embeddings client raises, after printing it.
        """
        # Copy so later mutation of the caller's list cannot desynchronize
        # texts from embeddings; also makes generators indexable.
        texts = list(texts)
        embeddings = []
        try:
            for text in texts:
                if not text.strip():  # Skip empty texts, but keep their slot
                    embeddings.append(None)
                    continue
                embedding = await self._embed(text)
                if embedding is None:
                    print(f"Warning: No embedding generated for text: {text[:100]}...")
                embeddings.append(embedding)
        except Exception as e:
            print(f"Error in abuild_from_list: {str(e)}")
            raise

        self.texts = texts
        self.embeddings = embeddings
        return self

    async def search_by_text(self, query: str, k: int = 4) -> List[Tuple[str, float]]:
        """Return up to ``k`` ``(text, similarity)`` pairs most similar to ``query``.

        Results are ordered by cosine similarity, highest first. Returns ``[]``
        in any of these cases:

        - the query is blank;
        - ``k`` is not positive;
        - no stored text has an embedding;
        - the API returns no embedding for the query.

        Raises:
            Whatever the embeddings client raises, after printing it.
        """
        if not query.strip() or k <= 0:
            return []

        # Pair texts with their embeddings, dropping placeholder slots.
        candidates = [
            (text, embedding)
            for text, embedding in zip(self.texts, self.embeddings)
            if embedding
        ]
        if not candidates:
            return []  # Nothing to compare against; skip the API call.

        try:
            query_embedding = await self._embed(query)
            if query_embedding is None:
                print("Warning: No embedding generated for query")
                return []

            similarities = [
                (text, cosine_similarity(query_embedding, embedding))
                for text, embedding in candidates
            ]
            similarities.sort(key=lambda pair: pair[1], reverse=True)
            return similarities[:k]
        except Exception as e:
            print(f"Error in search_by_text: {str(e)}")
            raise