import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import chromadb
from sentence_transformers import SentenceTransformer
import logging

# Initialize logging
logging.basicConfig(level=logging.INFO)

# Initialize the Llama model
try:
    llm = Llama(
        # model_path="./models/Phi-3-mini-4k-instruct-gguf",  # local-file alternative
        # llama_cpp needs a local file path, so fetch the GGUF from the Hub first
        # (repo_id/filename inferred from the original string; adjust if they differ).
        model_path=hf_hub_download(
            repo_id="Ankitajadhav/Phi-3-mini-4k-instruct-q4.gguf",
            filename="Phi-3-mini-4k-instruct-q4.gguf",
        ),
        n_ctx=2048,
        n_gpu_layers=50,  # Adjust based on your VRAM
    )
    logging.info("Llama model loaded successfully.")
except Exception as e:
    logging.error(f"Error loading Llama model: {e}")
    raise

# Initialize ChromaDB Vector Store
class VectorStore:
    def __init__(self, collection_name):
        self.embedding_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
        self.chroma_client = chromadb.Client()
        # get_or_create avoids a crash if the collection already exists on reload
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def populate_vectors(self, texts, ids):
        embeddings = self.embedding_model.encode(texts, batch_size=32).tolist()
        # One batched add is cheaper than inserting documents one at a time
        self.collection.add(embeddings=embeddings, documents=texts, ids=ids)

    def search_context(self, query, n_results=1):
        query_embedding = self.embedding_model.encode([query]).tolist()
        results = self.collection.query(query_embeddings=query_embedding, n_results=n_results)
        return results['documents']

# Example initialization (assuming you've already populated the vector store)
vector_store = VectorStore("embedding_vector")
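
# Illustrative population step (the texts/ids below are placeholders, not from
# the original — swap in a real recipe corpus before relying on retrieval):
# vector_store.populate_vectors(
#     texts=[
#         "Fried rice: stir-fry day-old rice with egg, soy sauce, and scallions.",
#         "Cake trifle: layer leftover cake with whipped cream and berries.",
#     ],
#     ids=["recipe_0", "recipe_1"],
# )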

def generate_text(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Retrieve context from vector store
    context_results = vector_store.search_context(message, n_results=1)
    # query() returns one list of documents per query, so unwrap both levels
    context = context_results[0][0] if context_results and context_results[0] else ""

    input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n {context}\n"
    for interaction in history:
        input_prompt += f"{interaction[0]} [/INST] {interaction[1]} </s><s> [INST] "
    input_prompt += f"{message} [/INST] "

    logging.info("Input prompt:\n%s", input_prompt)  # Debugging output

    temp = ""
    try:
        output = llm(
            input_prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=40,
            repeat_penalty=1.1,
            max_tokens=max_tokens,
            stop=["", " \n", "ASSISTANT:", "USER:", "SYSTEM:"],
            stream=True,
        )
        for out in output:
            temp += out["choices"][0]["text"]
            yield temp
        logging.info("Model output:\n%s", temp)  # Log the final output once, not per token
    except Exception as e:
        logging.error(f"Error during text generation: {e}")
        yield "An error occurred during text generation."

# Define the Gradio interface
demo = gr.ChatInterface(
    generate_text,
    # generate_text takes four parameters beyond (message, history), so matching
    # controls must be exposed here; the default values below are illustrative.
    additional_inputs=[
        gr.Textbox(value="You are a helpful recipe assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    examples=[
        ["I have leftover rice, what can I make out of it?"],
        ["Can I make lunch for two people with this?"],
        ["Some good dessert with leftover cake"]
    ],
    cache_examples=False,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)

if __name__ == "__main__":
    demo.launch()