#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Vector store module: generates document embeddings and builds a vector store.
Batch processing applied + chunk-length reporting added.
"""

import os
import sys
import argparse
import logging
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain.schema.document import Document
from e5_embeddings import E5Embeddings

# Logging setup: suppress messages below ERROR
logging.getLogger().setLevel(logging.ERROR)
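
# Note: normalize_embeddings=True below produces unit-length vectors, so
# FAISS's default L2 scoring ranks results the same way cosine similarity
# would (for unit vectors, ||a - b||^2 = 2 - 2*cos(a, b)).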

def get_embeddings(model_name="intfloat/multilingual-e5-large-instruct", device="cuda"):
    print(f"[INFO] ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋””๋ฐ”์ด์Šค: {device}")
    return E5Embeddings(
        model_name=model_name,
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True}
    )

def build_vector_store_batch(documents, embeddings, save_path="vector_db", batch_size=4):
    if not documents:
        raise ValueError("๋ฌธ์„œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ๋ฌธ์„œ๊ฐ€ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ๋กœ๋“œ๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.")

    texts = [doc.page_content for doc in documents]
    metadatas = [doc.metadata for doc in documents]

    # ์ฒญํฌ ๊ธธ์ด ์ถœ๋ ฅ
    lengths = [len(t) for t in texts]
    print(f"๐Ÿ’ก ์ฒญํฌ ์ˆ˜: {len(texts)}")
    print(f"๐Ÿ’ก ๊ฐ€์žฅ ๊ธด ์ฒญํฌ ๊ธธ์ด: {max(lengths)} chars")
    print(f"๐Ÿ’ก ํ‰๊ท  ์ฒญํฌ ๊ธธ์ด: {sum(lengths) // len(lengths)} chars")

    # Split texts and metadata into batches
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    metadata_batches = [metadatas[i:i + batch_size] for i in range(0, len(metadatas), batch_size)]

    print(f"Processing {len(batches)} batches with size {batch_size}")
    print(f"Initializing vector store with batch 1/{len(batches)}")

    # Seed the vector store with the first batch using from_documents
    first_docs = [
        Document(page_content=text, metadata=meta)
        for text, meta in zip(batches[0], metadata_batches[0])
    ]
    vectorstore = FAISS.from_documents(first_docs, embeddings)

    for i in tqdm(range(1, len(batches)), desc="Processing batches"):
        try:
            docs_batch = [
                Document(page_content=text, metadata=meta)
                for text, meta in zip(batches[i], metadata_batches[i])
            ]
            vectorstore.add_documents(docs_batch)

            if i % 10 == 0:
                temp_save_path = f"{save_path}_temp"
                os.makedirs(os.path.dirname(temp_save_path) or '.', exist_ok=True)
                vectorstore.save_local(temp_save_path)
                print(f"Temporary vector store saved to {temp_save_path} after batch {i}")

        except Exception as e:
            print(f"Error processing batch {i}: {e}")
            error_save_path = f"{save_path}_error_at_batch_{i}"
            os.makedirs(os.path.dirname(error_save_path) or '.', exist_ok=True)
            vectorstore.save_local(error_save_path)
            print(f"Partial vector store saved to {error_save_path}")
            raise

    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    vectorstore.save_local(save_path)
    print(f"Vector store saved to {save_path}")

    return vectorstore
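
# Recovery sketch (a hypothetical flow, not invoked anywhere in this script):
# the periodic saves above leave partial indexes at "<save_path>_temp" or
# "<save_path>_error_at_batch_<i>", so a failed run could be resumed roughly
# like this, where remaining_chunks is the not-yet-indexed tail of the input:
#
#     partial = load_vector_store(embeddings, "vector_db_temp")
#     partial.add_documents(remaining_chunks)
#     partial.save_local("vector_db")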

def load_vector_store(embeddings, load_path="vector_db"):
    if not os.path.exists(load_path):
        raise FileNotFoundError(f"๋ฒกํ„ฐ ์Šคํ† ์–ด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {load_path}")
    return FAISS.load_local(load_path, embeddings, allow_dangerous_deserialization=True)
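
# Minimal query sketch (illustrative values; the store must be loaded with the
# same embedding model it was built with):
#
#     embeddings = get_embeddings(device="cpu")
#     store = load_vector_store(embeddings, "vector_db")
#     for doc in store.similarity_search("sample query", k=5):
#         print(doc.metadata, doc.page_content[:100])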

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Build a vector store")
    parser.add_argument("--folder", type=str, default="final_dataset", help="Path to the folder containing the documents")
    parser.add_argument("--save_path", type=str, default="vector_db", help="Path where the vector store is saved")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--model_name", type=str, default="intfloat/multilingual-e5-large-instruct", help="Embedding model name")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use ('cuda', 'cpu', or e.g. 'cuda:1')")

    args = parser.parse_args()

    # Import the document processing module
    from document_processor_image_test import load_documents, split_documents

    documents = load_documents(args.folder)
    chunks = split_documents(documents, chunk_size=800, chunk_overlap=100)

    print(f"[DEBUG] ๋ฌธ์„œ ๋กœ๋”ฉ ๋ฐ ์ฒญํฌ ๋ถ„ํ•  ์™„๋ฃŒ, ์ž„๋ฒ ๋”ฉ ๋‹จ๊ณ„ ์ง„์ž… ์ „")
    print(f"[INFO] ์„ ํƒ๋œ ๋””๋ฐ”์ด์Šค: {args.device}")

    try:
        embeddings = get_embeddings(
            model_name=args.model_name,
            device=args.device
        )
        print(f"[DEBUG] ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ์ƒ์„ฑ ์™„๋ฃŒ")
    except Exception as e:
        print(f"[ERROR] ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ์ƒ์„ฑ ์ค‘ ์—๋Ÿฌ ๋ฐœ์ƒ: {e}")
        import traceback; traceback.print_exc()
        exit(1)

    build_vector_store_batch(chunks, embeddings, args.save_path, args.batch_size)
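
# Example invocation (illustrative; substitute the actual script filename):
#
#     python build_vector_store.py --folder final_dataset --save_path vector_db \
#         --batch_size 8 --device cuda:0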