# RAG / app.py
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
import tiktoken
import groq
import faiss
import numpy as np
import gradio as gr
import json
import os
import pickle
# == Create the models folder ==
os.makedirs("models", exist_ok=True)
# == Load the API key from a file ==
def load_api_key():
    with open("config.json", "r") as f:
        config = json.load(f)
    return config["GROQ_API_KEY"]

GROQ_API_KEY = load_api_key()
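# Based on the lookup above, config.json is assumed to look like:
#   {"GROQ_API_KEY": "<your-groq-api-key>"}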
# == Extract text from a PDF ==
def extract_text_from_pdf(pdf_file: str) -> str:
    with open(pdf_file, "rb") as pdf:
        reader = PdfReader(pdf)
        # extract_text() can return None for image-only pages; fall back to ""
        text = " ".join(page.extract_text() or "" for page in reader.pages)
    return text
# == Text chunking ==
def chunk_text(text: str, max_tokens: int = 512) -> list:
    tokenizer = tiktoken.get_encoding("cl100k_base")
    tokens = tokenizer.encode(text)
    chunks = []
    for i in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[i:i + max_tokens]
        # decode directly; avoid a local named chunk_text that shadows the function
        chunks.append(tokenizer.decode(chunk_tokens))
    return chunks
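# Illustrative example: for a text of ~1300 cl100k_base tokens,
#   chunk_text(long_text, max_tokens=512)
# yields three chunks of 512, 512, and ~276 tokens. Slicing the token list
# (rather than raw characters) keeps every chunk boundary token-aligned.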
# == Embeddings via SentenceTransformer ==
model = SentenceTransformer("all-MiniLM-L6-v2")  # global model

def get_embedding(text: str) -> np.ndarray:
    return np.array(model.encode(text), dtype=np.float32)
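# e.g. get_embedding("hello").shape == (384,): all-MiniLM-L6-v2 produces
# 384-dimensional vectors, which is why d = 384 below.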
# == FAISS setup ==
d = 384  # embedding dimension of the model above
index = faiss.IndexFlatL2(d)
text_chunks = []

def add_to_db(text_chunks_local):
    global text_chunks
    text_chunks = text_chunks_local
    embeddings = np.array([get_embedding(t) for t in text_chunks], dtype=np.float32).reshape(-1, d)
    index.add(embeddings)
def search_db(query, k=5):
    if index.ntotal == 0:
        return ["The database is still empty; please add data first."]
    query_embedding = np.array([get_embedding(query)], dtype=np.float32).reshape(1, -1)
    distances, indices = index.search(query_embedding, k)
    # FAISS pads missing results with -1, so guard against negative indices
    return [text_chunks[i] for i in indices[0] if 0 <= i < len(text_chunks)]
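# Example: search_db("annual revenue", k=3) returns up to 3 chunk strings
# ranked by L2 distance in embedding space (fewer if the index holds fewer
# than k vectors).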
def save_to_faiss(index_path="vector_index.faiss"):
    faiss.write_index(index, index_path)

def load_faiss(index_path="vector_index.faiss"):
    global index
    index = faiss.read_index(index_path)
def save_embeddings(embeddings_path="models/embeddings.pkl"):
    # Persist the chunk texts that map FAISS row ids back to strings; the
    # index itself is saved via save_to_faiss, and a raw faiss.Index is not
    # reliably picklable anyway.
    with open(embeddings_path, "wb") as f:
        pickle.dump(text_chunks, f)

def load_embeddings(embeddings_path="models/embeddings.pkl"):
    global text_chunks
    with open(embeddings_path, "rb") as f:
        text_chunks = pickle.load(f)
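# A minimal restore-on-startup sketch using the default paths above (assumes
# both files were written by save_to_faiss/save_embeddings in a previous run):
#
#   if os.path.exists("vector_index.faiss") and os.path.exists("models/embeddings.pkl"):
#       load_faiss()
#       load_embeddings()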
# == LLaMA integration via the Groq API ==
client = groq.Client(api_key=GROQ_API_KEY)

def query_llama(prompt):
    response = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=512,
    )
    return response.choices[0].message.content.strip()
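# Illustrative call (model and token limit as configured above):
#   query_llama("Summarize retrieval-augmented generation in one sentence.")
# returns the completion as a plain string.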
# == Main workflow ==
if __name__ == '__main__':
    pdf_text = extract_text_from_pdf('dini_anggriyani_synthetic_data.pdf')
    text_chunks = chunk_text(pdf_text, max_tokens=1024)
    add_to_db(text_chunks)
    save_to_faiss()
    save_embeddings()
    retrieved_chunks = search_db("What are the contents of this document?")
    context = "\n".join(retrieved_chunks)
    prompt = f"Use the following information to answer:\n{context}\n\nQuestion: What are the contents of this document?"
    answer = query_llama(prompt)
    print(answer)
# == Chatbot interface ==
def chatbot_interface(user_query):
    retrieved_chunks = search_db(user_query)
    context = "\n".join(retrieved_chunks)
    prompt = f"Use the following information to answer:\n{context}\n\nQuestion: {user_query}"
    answer = query_llama(prompt)
    return answer

iface = gr.Interface(fn=chatbot_interface, inputs="text", outputs="text", title="RAG Chatbot")
iface.launch()