|
|
|
|
|
""" |
|
|
Fix Qdrant collection dimensions for Manufacturing RAG Agent |
|
|
""" |
|
|
|
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from qdrant_client import QdrantClient |
|
|
from qdrant_client.http import models |
|
|
|
|
|
# Populate os.environ from a local .env file (if one exists) so the
# QDRANT_URL / QDRANT_API_KEY / SILICONFLOW_API_KEY lookups in the functions
# below can find values that are not exported in the shell.
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fix_qdrant_collection(vector_size: int = 4096,
                          collection_name: str = "manufacturing_docs") -> bool:
    """Ensure the Qdrant collection stores ``vector_size``-dimensional vectors.

    Connection settings are read from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables. If the collection exists with a
    different vector size, the user is asked to confirm, then the collection
    is deleted (discarding every point stored in it) and recreated with cosine
    distance plus keyword/integer payload indexes.

    Args:
        vector_size: Required embedding dimensionality. Defaults to 4096, the
            value this script also writes into the config file.
        collection_name: Name of the collection to check and (re)create.

    Returns:
        True if the collection already matched or was (re)created; False if
        QDRANT_URL is missing, the user declined, or any error occurred.
    """
    print("🔧 Fixing Qdrant Collection Dimensions")
    print("=" * 50)

    # Credentials come from the environment only -- never hard-code them.
    # getenv (rather than os.environ[...]) lets a missing variable reach the
    # friendly check below instead of raising KeyError.
    qdrant_url = os.getenv("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")

    if not qdrant_url:
        print("❌ QDRANT_URL not found in environment variables")
        return False

    try:
        print(f"🔗 Connecting to Qdrant: {qdrant_url}")
        client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

        existing = {col.name for col in client.get_collections().collections}

        if collection_name in existing:
            print(f"📋 Collection '{collection_name}' exists")

            vectors = client.get_collection(collection_name).config.params.vectors
            # A collection configured with named vectors exposes a dict here,
            # which has no single ``size``; treat that as needing a rebuild.
            current_dim = getattr(vectors, "size", None)
            print(f"📏 Current vector dimensions: {current_dim}")

            if current_dim == vector_size:
                print("✅ Collection already has correct dimensions")
                return True

            print(f"⚠️ Need to recreate collection with correct dimensions ({vector_size})")

            response = input("🗑️ Delete existing collection and recreate? (y/N): ").strip().lower()
            if response != "y":
                print("❌ Aborted by user")
                return False

            print(f"🗑️ Deleting collection '{collection_name}'...")
            client.delete_collection(collection_name)
            print("✅ Collection deleted")

        print(f"🆕 Creating collection '{collection_name}' with {vector_size} dimensions...")
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=vector_size,
                distance=models.Distance.COSINE,
            ),
        )

        print("📇 Creating payload indexes...")
        # create_payload_index accepts a PayloadSchemaType directly;
        # models.PayloadFieldSchema is a typing Union alias and cannot be
        # instantiated.
        indexes_to_create = [
            ("document_id", models.PayloadSchemaType.KEYWORD),
            ("document_type", models.PayloadSchemaType.KEYWORD),
            ("page_number", models.PayloadSchemaType.INTEGER),
            ("worksheet_name", models.PayloadSchemaType.KEYWORD),
        ]

        for field_name, field_schema in indexes_to_create:
            try:
                client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=field_schema,
                )
                print(f"✅ Created index for '{field_name}'")
            except Exception as e:
                # Best effort: report the failure and continue with the
                # remaining indexes.
                print(f"⚠️ Failed to create index for '{field_name}': {e}")

        print("✅ Collection recreated successfully with correct dimensions!")
        return True

    except Exception as e:
        print(f"❌ Error: {e}")
        return False
|
|
|
|
|
def update_config_file(config_path: str = "src/config.yaml",
                       vector_size: int = 4096) -> bool:
    """Write ``vector_size`` into the YAML config file.

    Existing integer ``vector_size:`` entries are rewritten in place. If none
    exists, the key is inserted as the first child of a ``vector_store:``
    section, using that section's child indentation. The file is edited as
    text (no YAML parser), so comments and layout are preserved.

    Args:
        config_path: Path to the config file, relative to the current working
            directory unless absolute. Defaults to ``src/config.yaml``.
        vector_size: Dimension value to write. Defaults to 4096.

    Returns:
        True if the file was changed and written back. False if the file is
        missing or the key could not be set automatically (the file is then
        left untouched), or if an I/O error occurred.
    """
    import re

    print("\n🔧 Updating Configuration")
    print("=" * 30)

    if not os.path.exists(config_path):
        print(f"❌ Config file not found: {config_path}")
        return False

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            content = f.read()

        # \b keeps keys such as ``max_vector_size`` untouched; [ \t]* stops
        # the pattern from reaching a number on a following line.
        content, replaced = re.subn(
            r"\bvector_size:[ \t]*\d+", f"vector_size: {vector_size}", content
        )

        if replaced:
            print(f"✅ Updated vector_size to {vector_size}")
        elif re.search(r"\bvector_size:", content):
            # The key exists but its value is not a plain integer; inserting a
            # second key would produce a duplicate, so leave it to the user.
            print("⚠️ vector_size found but its value is not a plain integer, please update manually")
            return False
        else:
            section = re.search(r"^([ \t]*)vector_store:[ \t]*$", content, re.MULTILINE)
            if section is None:
                print("⚠️ No vector_store section found, please add manually:")
                print("vector_store:")
                print(f"  vector_size: {vector_size}")
                return False

            parent_indent = section.group(1)
            head, tail = content[: section.end()], content[section.end():]
            # Reuse the indentation of the section's first child line so the
            # new key lines up with its siblings; otherwise indent two spaces.
            first_child = re.match(r"\n([ \t]+)\S", tail)
            if first_child and len(first_child.group(1)) > len(parent_indent):
                indent = first_child.group(1)
            else:
                indent = parent_indent + "  "
            content = f"{head}\n{indent}vector_size: {vector_size}{tail}"
            print(f"✅ Added vector_size: {vector_size} to vector_store section")

        with open(config_path, "w", encoding="utf-8") as f:
            f.write(content)

        print(f"✅ Updated {config_path}")
        return True

    except Exception as e:
        print(f"❌ Error updating config: {e}")
        return False
|
|
|
|
|
def test_embedding_dimensions(timeout: float = 10):
    """Request one embedding from SiliconFlow and report its length.

    Sends a single-string request for the ``Qwen/Qwen3-Embedding-8B`` model,
    authenticated with the ``SILICONFLOW_API_KEY`` environment variable.

    Args:
        timeout: Seconds to wait for the HTTP response (default 10).

    Returns:
        The length of the returned embedding vector as an int, or None if the
        API key is missing, the request fails, or no embedding is returned.
    """
    print("\n🧪 Testing Embedding Dimensions")
    print("=" * 35)

    api_key = os.getenv("SILICONFLOW_API_KEY")
    if not api_key:
        print("❌ SILICONFLOW_API_KEY not found")
        return None

    try:
        # Imported here (not at module top) so a missing ``requests`` package
        # is reported by the handler below instead of aborting the script.
        import requests

        response = requests.post(
            "https://api.siliconflow.com/v1/embeddings",
            json={
                "model": "Qwen/Qwen3-Embedding-8B",
                "input": ["test embedding dimension"],
                "encoding_format": "float",
            },
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            timeout=timeout,
        )

        if response.status_code != 200:
            print(f"❌ API error: {response.status_code} - {response.text}")
            return None

        items = response.json().get("data") or []
        if not items:
            print("❌ No embedding data returned")
            return None

        dim = len(items[0]["embedding"])
        print(f"✅ Actual embedding dimensions: {dim}")
        return dim

    except Exception as e:
        print(f"❌ Error testing embeddings: {e}")
        return None
|
|
|
|
|
def main():
    """Run the dimension fix end to end.

    Steps:
        1. Ask SiliconFlow for the real embedding size.
        2. Make sure the Qdrant collection uses 4096-dimensional vectors
           (may delete and recreate it after confirmation).
        3. Write ``vector_size`` into ``src/config.yaml``.
    Ends by printing follow-up steps for the user.
    """
    print("🚀 Manufacturing RAG Agent - Dimension Fix")
    print("=" * 60)

    actual_dim = test_embedding_dimensions()

    if actual_dim and actual_dim != 4096:
        print(f"⚠️ Warning: Expected 4096 dimensions, but got {actual_dim}")
        print("You may need to update the vector_size in your config")
        # The collection is (re)built with 4096-dimensional vectors below;
        # continuing would delete the existing data and still not match the
        # embedding model, so stop before anything destructive happens.
        print("❌ Aborting without changing the Qdrant collection or config")
        return

    if not fix_qdrant_collection():
        print("\n❌ Failed to fix Qdrant collection")
        return
    print("\n✅ Qdrant collection fixed successfully!")

    if update_config_file():
        print("✅ Configuration updated successfully!")
    else:
        print("⚠️ Please update config manually")

    print("\n🎉 Fix Complete!")
    print("\n📋 Next Steps:")
    print("1. Restart your Gradio demo")
    print("2. Re-upload your documents")
    print("3. Test question answering")

    print("\n🔄 To restart the demo:")
    print("python fixed_gradio_demo.py")


if __name__ == "__main__":
    main()