nazib61 committed on
Commit
2d181ba
·
verified ·
1 Parent(s): 0ae57e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -39
app.py CHANGED
@@ -2,54 +2,53 @@ import gradio as gr
2
  from datasets import load_dataset
3
  from qdrant_client import QdrantClient, models
4
  from sentence_transformers import SentenceTransformer
 
5
 
6
  # --- Configuration ---
7
- QDRANT_HOST = "localhost" # Or your Hugging Face Space Qdrant URL
8
- QDRANT_PORT = 6333
 
 
9
  COLLECTION_NAME = "my_text_collection"
10
  MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
11
 
12
- # --- Load Dataset and Model ---
13
- # Using a simple dataset from Hugging Face
14
- dataset = load_dataset("ag_news", split="test")
15
- data = [item['text'] for item in dataset]
16
- # Limiting the dataset for a quicker demo
17
- data = data[:1000]
18
-
19
- # Load a pre-trained sentence transformer model
20
- model = SentenceTransformer(MODEL_NAME)
21
 
22
  # --- Qdrant Client and Collection Setup ---
23
- # Initialize Qdrant client
24
- # In a Hugging Face Space, you might use a local in-memory instance or connect to a running Qdrant container.
25
- qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
26
 
27
- # Create a Qdrant collection if it doesn't exist
28
  try:
29
- qdrant_client.get_collection(collection_name=COLLECTION_NAME)
30
  print("Collection already exists.")
31
  except Exception as e:
32
- print("Creating collection...")
33
- qdrant_client.recreate_collection(
 
 
 
 
 
 
 
34
  collection_name=COLLECTION_NAME,
35
  vectors_config=models.VectorParams(size=model.get_sentence_embedding_dimension(), distance=models.Distance.COSINE),
36
  )
37
 
38
  # --- Generate and Index Embeddings ---
39
  print("Generating and indexing embeddings...")
40
- batch_size = 128
41
- for i in range(0, len(data), batch_size):
42
- batch_texts = data[i:i+batch_size]
43
- embeddings = model.encode(batch_texts, convert_to_tensor=True)
44
-
45
- qdrant_client.upsert(
46
- collection_name=COLLECTION_NAME,
47
- points=models.Batch(
48
- ids=list(range(i, i + len(batch_texts))),
49
- vectors=[embedding.tolist() for embedding in embeddings],
50
- payloads=[{"text": text} for text in batch_texts]
51
- )
52
- )
53
  print("Embeddings indexed successfully.")
54
 
55
 
@@ -61,18 +60,22 @@ def search_in_qdrant(query):
61
  if not query:
62
  return "Please enter a search query."
63
 
64
- query_embedding = model.encode(query).tolist()
65
-
66
- search_result = qdrant_client.search(
67
  collection_name=COLLECTION_NAME,
68
- query_vector=query_embedding,
69
- limit=5 # Return the top 5 most similar results
 
 
70
  )
71
 
72
  results_text = ""
73
- for hit in search_result:
 
 
 
74
  results_text += f"**Score:** {hit.score:.4f}\n"
75
- results_text += f"**Text:** {hit.payload['text']}\n\n"
76
 
77
  return results_text
78
 
@@ -82,7 +85,7 @@ with gr.Blocks() as demo:
82
  gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")
83
 
84
  with gr.Row():
85
- search_input = gr.Textbox(label="Search Query")
86
 
87
  search_button = gr.Button("Search")
88
  search_output = gr.Markdown()
 
2
  from datasets import load_dataset
3
  from qdrant_client import QdrantClient, models
4
  from sentence_transformers import SentenceTransformer
5
+ import torch # Ensure torch is imported
6
 
7
  # --- Configuration ---
8
+ # Use ":memory:" for a temporary, in-memory database.
9
+ # Or use a path like "./qdrant_db" to save the data to disk.
10
+ # Using a path is better for Spaces as data will be rebuilt only when the code changes.
11
+ QDRANT_PATH = "./qdrant_db"
12
  COLLECTION_NAME = "my_text_collection"
13
  MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
14
 
15
+ # --- Load Model ---
16
+ # Specify that the model should run on the CPU, which is standard for HF Spaces
17
+ device = "cpu"
18
+ model = SentenceTransformer(MODEL_NAME, device=device)
 
 
 
 
 
19
 
20
  # --- Qdrant Client and Collection Setup ---
21
+ # Initialize Qdrant client to use a local, on-disk storage
22
+ # This avoids the need to run a separate Qdrant server
23
+ qdrant_client = QdrantClient(path=QDRANT_PATH)
24
 
25
+ # Check if the collection already exists
26
  try:
27
+ collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
28
  print("Collection already exists.")
29
  except Exception as e:
30
+ print("Collection not found, creating a new one...")
31
+ # --- Load Dataset ---
32
+ # We only load the dataset and create embeddings if the collection doesn't exist
33
+ dataset = load_dataset("ag_news", split="test")
34
+ # Limiting the dataset for a quicker demo setup
35
+ data = [item['text'] for item in dataset][:1000]
36
+
37
+ # Create the collection
38
+ qdrant_client.create_collection(
39
  collection_name=COLLECTION_NAME,
40
  vectors_config=models.VectorParams(size=model.get_sentence_embedding_dimension(), distance=models.Distance.COSINE),
41
  )
42
 
43
  # --- Generate and Index Embeddings ---
44
  print("Generating and indexing embeddings...")
45
+ # This can take a moment on the first run
46
+ qdrant_client.add(
47
+ collection_name=COLLECTION_NAME,
48
+ documents=data,
49
+ ids=list(range(len(data))), # Simple sequential IDs
50
+ embedding_model=model
51
+ )
 
 
 
 
 
 
52
  print("Embeddings indexed successfully.")
53
 
54
 
 
60
  if not query:
61
  return "Please enter a search query."
62
 
63
+ # The client's search function can now take the model directly
64
+ hits = qdrant_client.search(
 
65
  collection_name=COLLECTION_NAME,
66
+ query_text=query,
67
+ query_filter=None, # No filters for now
68
+ limit=5, # Return the top 5 most similar results
69
+ embedding_model=model
70
  )
71
 
72
  results_text = ""
73
+ if not hits:
74
+ return "No results found."
75
+
76
+ for hit in hits:
77
  results_text += f"**Score:** {hit.score:.4f}\n"
78
+ results_text += f"**Text:** {hit.payload['document']}\n\n" # Payload key is 'document' when using .add()
79
 
80
  return results_text
81
 
 
85
  gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")
86
 
87
  with gr.Row():
88
+ search_input = gr.Textbox(label="Search Query", placeholder="e.g., 'Latest news on space exploration'")
89
 
90
  search_button = gr.Button("Search")
91
  search_output = gr.Markdown()