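"""Curly Strings fan chatbot: a small RAG app for a Hugging Face Space.

Loads a text file about the band, chunks and embeds it into a Chroma
vector store, retrieves the most relevant chunks per question, and asks
a LLaMA 3 model on Groq to answer in a playful band-member persona via
a Gradio UI.

Assumed dependencies (names are a best guess; pin versions yourself):
gradio, groq, langchain-community, langchain-huggingface,
langchain-text-splitters, chromadb, sentence-transformers.
"""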
import os

import gradio as gr
from groq import Groq
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# === Config ===
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # Groq API key from environment variable
LLM_MODEL = "llama3-70b-8192"
FILE_PATH = "./Estonia.txt"  # Relative path, so it works inside a Hugging Face Space
DB_DIR = "chroma_db"
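# GROQ_API_KEY and HF_TOKEN are expected as Space secrets (Settings ->
# Variables and secrets). Groq's model catalog changes over time, so check
# that "llama3-70b-8192" is still available before deploying.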
# === Load and Chunk Document ===
def load_and_split(filepath):
    loader = TextLoader(filepath, encoding="utf-8")
    docs = loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(docs)
    return chunks
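# Quick sanity check (assumes Estonia.txt sits next to app.py):
#   chunks = load_and_split(FILE_PATH)
#   print(len(chunks), chunks[0].page_content[:80])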
# === Create or Load Vector Store ===
def get_vector_store(chunks):
    embeddings = HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={"token": os.getenv("HF_TOKEN")},  # HuggingFace token from environment variable
    )
    vectordb = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=DB_DIR)
    return vectordb
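# Recent chromadb releases persist automatically when persist_directory is
# set; older versions may additionally need an explicit vectordb.persist().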
def load_vector_store():
    embeddings = HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={"token": os.getenv("HF_TOKEN")},  # HuggingFace token from environment variable
    )
    vectordb = Chroma(persist_directory=DB_DIR, embedding_function=embeddings)
    return vectordb
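# Note: the embedding model here must match the one used when the index was
# built; mixing models makes the similarity search meaningless.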
# === Query LLaMA via Groq ===
def query_llama(context, question):
    client = Groq(api_key=GROQ_API_KEY)
    prompt = f"""You're a cheerful, funky band member from Curly Strings 🎻🎤🎶. When fans ask you questions, respond with playful, short, and friendly answers, like you're chatting backstage after a show.

👉 Here's how to guide fans:
- **🎵 Music**: If they ask for songs or albums, suggest a tune and drop **just one** of these links (rotate each time!):
    - [Our Music Page](https://www.curlystrings.ee/music/)
    - [Spotify](https://open.spotify.com/playlist/37i9dQZF1DZ06evO3XF2F2)
    - [Apple Music](https://music.apple.com/us/artist/curly-strings/888454075)
    - [YouTube](https://www.youtube.com/@CurlyStringsEstonia)
- **🎫 Tickets & Shows**: If it's about concerts or dates, suggest one link from:
    - [Tour Dates](https://www.curlystrings.ee/tour-dates/)
    - [BandsInTown](https://www.bandsintown.com/a/6429648-curly-strings)
- **🛍️ Merch**: If they ask about merch, shirts, or goodies, share:
    - [Our Merch Store](https://www.311.ee/curly-strings)

Important: Always suggest **only one** link at a time. Rotate the links so fans get different suggestions each time!

Use the following context to answer the question:

{context}

Question: {question}
Answer:"""
    response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=300,
    )
    return response.choices[0].message.content.strip()
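# Temperature is left at Groq's default in the create() call above; passing
# temperature=0 there gives more deterministic, context-bound answers if
# the persona starts to drift.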
# === RAG Pipeline ===
def rag_pipeline(query):
    if not os.path.exists(DB_DIR):
        chunks = load_and_split(FILE_PATH)
        vectordb = get_vector_store(chunks)
    else:
        vectordb = load_vector_store()
    retriever = vectordb.as_retriever(search_kwargs={"k": 4})
    docs = retriever.invoke(query)
    context = "\n\n".join([doc.page_content for doc in docs])
    return query_llama(context, query)
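# Example round trip (hypothetical question):
#   print(rag_pipeline("Where can I buy a Curly Strings shirt?"))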
# === Gradio Interface ===
def chat_with_bot(question):
    try:
        return rag_pipeline(question)
    except Exception as e:
        return f"Oops! Something went wrong: {e}"
# Launch Gradio UI
iface = gr.Interface(
    fn=chat_with_bot,
    inputs=gr.Textbox(lines=2, placeholder="Ask me anything about Curly Strings 🎻🎤"),
    outputs="text",
    title="🎶 Curly Strings Chatbot 🎶",
    description="Talk to a cheerful band member! Ask about music, shows, or merch.",
)

iface.launch()
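# On a Hugging Face Space this top-level launch() is all that's needed; for
# local testing, iface.launch(share=True) also gives a temporary public URL.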