File size: 3,794 Bytes
4997aeb
 
da863bf
3ae066d
 
9c2d532
4997aeb
0fb4cf5
4997aeb
8eb3e51
4997aeb
2da6f20
4997aeb
 
2da6f20
4997aeb
2da6f20
 
4997aeb
2da6f20
 
 
4997aeb
 
 
 
 
 
 
8ce3d9b
4997aeb
 
 
 
8ce3d9b
4997aeb
 
9eb3e78
11bc07e
9eb3e78
dcca063
 
9eb3e78
 
 
 
 
79497a3
c9dd21c
9eb3e78
2376b2f
9eb3e78
 
c9dd21c
9eb3e78
 
 
 
 
 
 
 
c9dd21c
9eb3e78
 
 
 
 
c9dd21c
9eb3e78
 
f76455a
3ae066d
9eb3e78
 
8ce3d9b
9eb3e78
 
5d2299c
9eb3e78
8ce3d9b
 
 
 
dcb00f7
 
3d67d69
dcb00f7
 
909aec0
 
 
 
 
 
 
 
 
 
8ce3d9b
 
 
 
 
 
 
 
 
 
 
9eb3e78
 
 
dcca063
 
 
9eb3e78
dcca063
 
9eb3e78
dcca063
 
 
9eb3e78
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import streamlit as st
import os
from streamlit_chat import message
import numpy as np
import pandas as pd
# import json

# st.config(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python")

# from datasets import load_dataset

# dataset = load_dataset("wikipedia", "20220301.en", split="train[240000:250000]")


# wikidata = []

# for record in dataset:
#     wikidata.append(record["text"])

# wikidata = list(set(wikidata))
# # print("\n".join(wikidata[:5]))
# # print(len(wikidata))

from sentence_transformers import SentenceTransformer
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device != "cuda":
    # Tell the user up front that embedding will be slow without a GPU.
    st.text(
        f"you are using {device}. This is much slower than using "
        "a CUDA-enabled GPU. If on colab you can change this by "
        "clicking Runtime > change runtime type > GPU."
    )

# Embedding model shared by index creation (for its dimension) and querying.
model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
st.divider()

# Create an index (Pinecone vector database)
import os
# import pinecone

from pinecone.grpc import PineconeGRPC


# Pinecone settings come from the process environment; each constant is None
# when its variable is not set.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENV = os.environ.get("PINECONE_ENV")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")

# pc = PineconeGRPC( api_key=os.environ.get("PINECONE_API_KEY") ) # Now do stuff if 'my_index' not in pc.list_indexes().names(): pc.create_index( name='my_index', dimension=1536, metric='euclidean', spec=ServerlessSpec( cloud='aws', region='us-west-2' ) )

def connect_pinecone():
    """Create a gRPC Pinecone client and display connection diagnostics.

    Renders the client object, the names of the indexes the client can list,
    and a success message on the Streamlit page.

    Returns:
        The ``PineconeGRPC`` client instance.
    """
    # NOTE(review): confirm whether the installed pinecone client still uses
    # ``environment``; serverless setups are usually configured by API key alone.
    pinecone = PineconeGRPC(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)
    st.code(pinecone)
    st.divider()
    # Show which indexes this client can see.
    st.text(pinecone.list_indexes().names())
    st.divider()
    st.text("Successfully connected to the pinecone")
    return pinecone

def get_pinecone_semantic_index(pinecone):
    """Return a handle to the semantic-search index, creating it if missing.

    The index is created as a cosine-metric serverless index whose dimension
    matches the sentence-embedding model, so stored and query vectors agree.

    Args:
        pinecone: A connected Pinecone client (e.g. from ``connect_pinecone``).

    Returns:
        The index handle returned by ``pinecone.Index``.
    """
    # Imported here so the create path does not fail with NameError; the
    # module-level import only brings in the gRPC client class.
    from pinecone import ServerlessSpec

    # This exact string is both the lookup key and the created index's name;
    # changing it would target a different index.
    index_name = "sematic-search"

    # Only create the index if it does not exist yet.
    if index_name not in pinecone.list_indexes().names():
        # Human-readable purpose: "Semantic search". create_index has no
        # description parameter, so it is recorded here instead.
        pinecone.create_index(
            name=index_name,
            # Must match the width of the vectors produced by `model`.
            dimension=model.get_sentence_embedding_dimension(),
            metric="cosine",
            spec=ServerlessSpec(cloud="gcp", region="us-central1"),
        )
    # Now connect to the index.
    index = pinecone.Index(index_name)
    st.text("Successfully connected to the pinecone index")
    return index

def chat_actions():
    """Handle a submitted chat message by running a semantic search.

    Used as the ``on_submit`` callback of ``st.chat_input``. It:
    1. appends the user's message to ``st.session_state["chat_history"]``;
    2. embeds the message with ``model`` and queries the index for the 5
       nearest vectors, including metadata;
    3. shows the raw query response in the sidebar;
    4. appends one assistant entry per match, formatted as
       ``"<score rounded to 2 places>: <metadata text>"``.
    """
    pinecone = connect_pinecone()
    index = get_pinecone_semantic_index(pinecone)

    user_text = st.session_state["chat_input"]
    st.session_state["chat_history"].append(
        {"role": "user", "content": user_text},
    )

    # Embed the query and convert the array into a plain list for the client.
    query_vector = model.encode(user_text).tolist()
    # Keyword arguments avoid depending on the client's positional parameter order.
    response = index.query(vector=query_vector, top_k=5, include_metadata=True)

    with st.sidebar:
        # A plain dict renders as structured JSON; fall back to the raw object.
        st.json(response.to_dict() if hasattr(response, "to_dict") else response)

    for match in response["matches"]:
        # Tolerate matches whose metadata is missing or lacks a "text" field.
        metadata = match["metadata"] or {}
        st.session_state["chat_history"].append(
            {
                "role": "assistant",
                "content": f"{round(match['score'], 2)}: {metadata.get('text', '')}",
            },  # This can be replaced with your chat response logic
        )


# Streamlit re-executes this script on every interaction, so the transcript
# lives in session_state and is created only on the first run.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []


# Submitting a message invokes chat_actions, which appends to the transcript.
st.chat_input("Enter your message", on_submit=chat_actions, key="chat_input")

# Replay the whole transcript, one chat bubble per entry.
for entry in st.session_state["chat_history"]:
    speaker = entry["role"]
    with st.chat_message(name=speaker):
        st.write(entry["content"])

### Creating an Index (Pinecone Vector Database)
# %%writefile .env
# PINECONE_API_KEY=os.getenv("PINECONE_API_KEY")
# PINECONE_ENV=os.getenv("PINECONE_ENV")
# PINECONE_ENVIRONMENT=os.getenv("PINECONE_ENVIRONMENT")

# import os
# import pinecone

# from pinecone import Index, GRPCIndex
# pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)
# st.text(pinecone)