File size: 3,420 Bytes
4997aeb
 
da863bf
3ae066d
 
4997aeb
0fb4cf5
4997aeb
8eb3e51
4997aeb
2da6f20
4997aeb
 
2da6f20
4997aeb
2da6f20
 
4997aeb
2da6f20
 
 
4997aeb
 
 
 
 
 
 
8ce3d9b
4997aeb
 
 
 
8ce3d9b
4997aeb
 
9eb3e78
11bc07e
9eb3e78
dcca063
 
9eb3e78
 
 
 
 
79497a3
c9dd21c
9eb3e78
2376b2f
9eb3e78
 
c9dd21c
9eb3e78
 
 
 
 
 
 
 
c9dd21c
9eb3e78
 
 
 
 
c9dd21c
9eb3e78
 
dcca063
3ae066d
9eb3e78
 
8ce3d9b
9eb3e78
 
5d2299c
9eb3e78
8ce3d9b
 
 
 
9eb3e78
8ce3d9b
 
 
a37e99b
8ce3d9b
 
 
 
 
 
 
 
 
 
 
 
 
9eb3e78
 
 
dcca063
 
 
9eb3e78
dcca063
 
9eb3e78
dcca063
 
 
9eb3e78
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
import os
from streamlit_chat import message
import numpy as np
import pandas as pd

# st.config(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION="python")

# from datasets import load_dataset

# dataset = load_dataset("wikipedia", "20220301.en", split="train[240000:250000]")


# wikidata = []

# for record in dataset:
#     wikidata.append(record["text"])

# wikidata = list(set(wikidata))
# # print("\n".join(wikidata[:5]))
# # print(len(wikidata))

from sentence_transformers import SentenceTransformer
import torch

# Prefer the GPU when torch reports one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device != 'cuda':
    # Surface the slow path on the page so the user knows why encoding is slow.
    st.text(f"you are using {device}. This is much slower than using "
    "a CUDA-enabled GPU. If on colab you can chnage this by "
    "clicking Runtime > change runtime type > GPU.")


@st.cache_resource
def _load_model(device_name):
    """Build the sentence-embedding model once and reuse it afterwards.

    Streamlit re-executes the script body on each interaction; without the
    cache decorator the model would be constructed again on every run.

    Args:
        device_name: Device string handed to SentenceTransformer
            ('cuda' or 'cpu' as chosen above).

    Returns:
        The SentenceTransformer instance for "all-MiniLM-L6-v2".
    """
    return SentenceTransformer("all-MiniLM-L6-v2", device=device_name)


# Module-level name used by the functions further down in this file.
model = _load_model(device)
st.divider()

# Creating an index (Pinecone vector database)
import os
# import pinecone

from pinecone.grpc import PineconeGRPC


# Pinecone settings are read from the process environment; each value is
# None when the corresponding variable is not set.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENV = os.environ.get("PINECONE_ENV")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")

# pc = PineconeGRPC( api_key=os.environ.get("PINECONE_API_KEY") ) # Now do stuff if 'my_index' not in pc.list_indexes().names(): pc.create_index( name='my_index', dimension=1536, metric='euclidean', spec=ServerlessSpec( cloud='aws', region='us-west-2' ) )

def connect_pinecone():
    """Create a Pinecone gRPC client and echo its details to the page.

    The client is built from the module-level PINECONE_API_KEY and
    PINECONE_ENV values. The client object and the names of its indexes are
    written to the Streamlit page before the client is returned.

    Returns:
        The PineconeGRPC client instance.
    """
    client = PineconeGRPC(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)

    # Show the client itself, then the index names, separated by dividers.
    st.code(client)
    st.divider()
    index_names = client.list_indexes().names()
    st.text(index_names)
    st.divider()

    st.text("Succesfully connected to the pinecone")
    return client

def get_pinecone_semantic_index(pinecone):
    """Return a handle to the semantic-search index, creating it if missing.

    Args:
        pinecone: A connected Pinecone client, such as the one returned by
            connect_pinecone().

    Returns:
        The index handle obtained from ``pinecone.Index(index_name)``.
    """
    # NOTE(review): "sematic" looks like a typo for "semantic"; the name is
    # kept unchanged because renaming it would target a different index.
    index_name = "sematic-search"

    # Only create the index if it does not exist yet.
    if index_name not in pinecone.list_indexes().names():
        # Imported locally: the file does not import ServerlessSpec at module
        # level, so the original call failed with NameError on this path.
        from pinecone import ServerlessSpec

        pinecone.create_index(
            name=index_name,
            # The index dimension is taken from the embedding model.
            dimension=model.get_sentence_embedding_dimension(),
            metric="cosine",
            spec=ServerlessSpec(cloud='gcp', region='us-central1'),
        )

    # Open the index through the already-authenticated client instead of
    # constructing a second client with the index name as its API key.
    index = pinecone.Index(index_name)
    st.text("Succesfully connected to the pinecone index")
    return index

def chat_actions():
    """Handle a submitted chat message.

    Connects to Pinecone and resolves the semantic index, records the user's
    message in the session chat history, then encodes the message with the
    module-level model and records the embedding (as a DataFrame) as the
    assistant's reply.
    """
    client = connect_pinecone()
    # The index handle is not used yet; the call is kept for its side effects
    # (index creation and the status text it writes).
    index = get_pinecone_semantic_index(client)

    history = st.session_state["chat_history"]
    user_text = st.session_state["chat_input"]
    history.append({"role": "user", "content": user_text})

    # This can be replaced with your chat response logic.
    embedding = model.encode(st.session_state["chat_input"])
    assistant_entry = {
        "role": "assistant",
        "content": pd.DataFrame(embedding),
    }
    st.session_state["chat_history"].append(assistant_entry)


# Seed the transcript the first time this session runs the script.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []


# chat_actions is registered as the submit callback for the input box.
st.chat_input("Enter your message", key="chat_input", on_submit=chat_actions)

# Replay every stored message under its role.
for entry in st.session_state["chat_history"]:
    role = entry["role"]
    with st.chat_message(name=role):
        st.write(entry["content"])

### Creating an index (Pinecone vector database)
# %%writefile .env
# PINECONE_API_KEY=os.getenv("PINECONE_API_KEY")
# PINECONE_ENV=os.getenv("PINECONE_ENV")
# PINECONE_ENVIRONMENT=os.getenv("PINECONE_ENVIRONMENT")

# import os
# import pinecone

# from pinecone import Index, GRPCIndex
# pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)
# st.text(pinecone)