File size: 3,586 Bytes
a5245e5
 
05c83f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5245e5
 
 
 
05c83f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5245e5
 
 
05c83f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5245e5
 
 
 
 
 
 
 
 
 
05c83f3
 
 
 
 
 
 
 
 
 
 
 
a5245e5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
# Loading
import os
from datasets import load_dataset
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import uuid
from qdrant_client import models, QdrantClient
from itertools import islice

# Helper: chunk any iterable into lists of at most n items (last may be short).
def batched(iterable, n):
    """Yield successive lists of up to *n* items drawn from *iterable*."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, n))
        if not chunk:
            return
        yield chunk

# Number of points sent to Qdrant per upsert call.
batch_size = 100

# Create a local, file-backed Qdrant instance.
client2 = QdrantClient(path="database.db")

# Embedding dimension of sentence-transformers/all-MiniLM-L6-v2.
# NOTE(review): the original called model.get_sentence_embedding_dimension()
# here, but `model` is only defined further down the file, so importing the
# module raised NameError. Hard-coding the model's known dimension fixes the
# ordering bug without moving definitions around.
EMBEDDING_DIM = 384

# Create the Qdrant collection for the embeddings. create_collection raises
# if the collection already exists (e.g. on a second start against the same
# database.db), so guard with an existence check.
if not client2.collection_exists("law"):
    client2.create_collection(
        collection_name="law",
        vectors_config=models.VectorParams(
            size=EMBEDDING_DIM,
            distance=models.Distance.COSINE,
        ),
    )

# Run inference on the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Source parquet file ingested by /uploadfile/ and the datasets cache dir.
FILEPATH_PATTERN = "structured_data_doc.parquet"
CACHE_DIR = "/.cache"
# os.cpu_count() can return None on some platforms; fall back to 1 so that
# arithmetic such as NUM_PROC * 2 below can never raise TypeError.
NUM_PROC = os.cpu_count() or 1


app = FastAPI()



# Sentence-embedding model shared by ingestion and search, placed on the
# device selected above.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)
# Create function to generate embeddings (in batches) for a given dataset split
def generate_embeddings(dataset, batch_size=32):
    embeddings = []

    with tqdm(total=len(dataset), desc=f"Generating embeddings for dataset") as pbar:
        for i in range(0, len(dataset), batch_size):
            batch_sentences = dataset['content'][i:i+batch_size]
            batch_embeddings = model.encode(batch_sentences)
            embeddings.extend(batch_embeddings)
            pbar.update(len(batch_sentences))

    return embeddings
    
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
    """Load the parquet dataset, embed its rows, and upsert them into Qdrant.

    Returns the uploaded filename plus a completion message.

    NOTE(review): the uploaded *file* itself is never read — the ingested data
    comes from FILEPATH_PATTERN on disk. Confirm this is intentional.
    """
    full_dataset = load_dataset(
        "parquet",
        data_files=FILEPATH_PATTERN,
        split="train",
        keep_in_memory=True,
        cache_dir=CACHE_DIR,
        # Guard: os.cpu_count() (hence NUM_PROC) can be None, and None * 2
        # raises TypeError; fall back to a single process in that case.
        num_proc=(NUM_PROC or 1) * 2,
    )

    # Generate and append embeddings to the train split.
    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)

    # Give every row a stable point id for Qdrant if one is not present yet.
    if 'uuid' not in full_dataset.column_names:
        full_dataset = full_dataset.add_column(
            'uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))]
        )

    # Upsert in batches; after popping id and vector, the remaining row
    # fields become the point payloads.
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]

        client2.upsert(
            collection_name="law",
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )
    return {"filename": file.filename, "message": "Done"}

# Open CORS policy: accept requests from any origin, method, and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject wildcard + credentials);
# tighten allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/search")
def search(prompt: str):
    """Embed *prompt* and return the 5 nearest points from the 'law' collection.

    Each result carries the stored payload and its cosine-similarity score.
    """
    hits = client2.search(
        collection_name="law",
        query_vector=model.encode(prompt).tolist(),
        limit=5,
    )
    for hit in hits:
        print(hit.payload, "score:", hit.score)
    # BUGFIX: the original return mixed dict and set literal syntax
    # ({'detail': 'hit.payload', 'score:', hit.score}) — a SyntaxError — and
    # quoted hit.payload as a string. Return every hit, not just the last.
    return {
        'results': [{'payload': h.payload, 'score': h.score} for h in hits]
    }
    
@app.get("/")
def api_home():
    """Root endpoint: return a simple welcome/health message."""
    message = 'Welcome to FastAPI Qdrant importer!'
    return {'detail': message}