hoshingakag committed commit 1871bfe (v0.1)
Parent(s): 108d5cf

Files changed:
- app.py +52 -49
- requirements.txt +2 -1
- src/llamaindex_palm.py +171 -0
app.py
CHANGED
@@ -1,63 +1,74 @@
 import os
 import time
-import itertools
-from typing import Literal
 import gradio as gr
 import google.generativeai as genai
+from src.llamaindex_palm import LlamaIndexPaLM
+
+import logging
+logging.basicConfig(format='%(asctime)s %(clientip)-15s %(user)-8s %(message)s', level=logging.INFO)
+logger = logging.getLogger('llm')
+
+# Llama-Index LLM
+llm = LlamaIndexPaLM()
+llm.set_index_from_pinecone()
 
 # Credentials
 genai.configure(api_key=os.getenv('PALM_API_KEY'))
 
 # Gradio
-chat_defaults = {
-    'model': 'models/chat-bison-001',
-    'temperature': 0.25,
-    'candidate_count': 1,
-    'top_k': 40,
-    'top_p': 0,
-}
-
 chat_history = []
 
-def clear_chat():
+def clear_chat() -> None:
+    global chat_history
     chat_history = []
     return None
 
-def generate_chat(prompt: str,
-
-
-
-
+def generate_chat(prompt: str, llamaindex_llm: LlamaIndexPaLM):
+    global chat_history
+    # get chat history
+    context_chat_history = "\n".join(chat_history)
+
+    logger.info("Generating Message...")
+    logger.info(f"User Message:\n{prompt}\n")
+    chat_history.append(prompt)
+
+    # get context
+    context_from_index = llamaindex_llm.generate_response(prompt)
+    logger.info(f"Context from Llama-Index:\n{context_from_index}\n")
+
+    prompt_with_context = f"""
+    Rule:
+    You are in a role play of Gerard Lee and you need to pretend to be him to answer questions from people who are interested in Gerard's background.
+    You are going to reply to their messages given the context below and without hallucinations. If you don't know the answer, simply say "I have no idea how to answer this question".
 
-
-
-
-
+    Chat History:
+    {context_chat_history}
+
+    Context:
+    {context_from_index}
+
+    User Query:
+    {prompt}
     """
-
-    print(f"User Message:\n{prompt}\n")
-    chat_messages.append(prompt)
-
+
     try:
-        response = genai.
-
-
-
+        response = genai.generate_text(
+            prompt=prompt_with_context,
+            safety_settings=[
+                {
+                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
+                },
+            ]
         )
-        result = response.
-
-            result = "Apologies but something went wrong. Please try again later."
-            chat_messages = chat_messages[:-1]
-        else:
-            chat_messages.append(result)
-
+        result = response.result
+
     except Exception as e:
-        result = "
-
-
-
-
-    print(f"Bot Message:\n{result}\n")
+        result = "Seems something went wrong. Please try again later."
+        logger.error(f"Exception {e} occurred\n")
+
+    chat_history.append(result)
+    logger.info(f"Bot Message:\n{result}\n")
     return result
 
 with gr.Blocks() as app:
@@ -74,11 +85,6 @@ with gr.Blocks() as app:
             placeholder="Hi Gerard, can you introduce yourself?",
             container=False,
             scale=6)
-    # send = gr.Button(
-    #     value="",
-    #     icon="./send-message.png",
-    #     scale=1
-    # )
     clear = gr.Button("Clear")
 
     def user(user_message, history):
@@ -95,9 +101,6 @@ with gr.Blocks() as app:
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
         bot, chatbot, chatbot
     )
-    # send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-    #     bot, chatbot, chatbot
-    # )
     clear.click(clear_chat, None, chatbot, queue=False)
 
     gr.HTML("<p><center>Hosted on 🤗 Spaces. Powered by Google PaLM 🌴</center></p>")
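The core of this change is that generate_chat() now builds a retrieval-augmented prompt: it joins the running chat history, asks the LlamaIndex/Pinecone layer for supporting context, interpolates both into a role-play prompt, and sends that to PaLM via genai.generate_text(). The sketch below illustrates that pattern in isolation; it is not part of the commit. The canned index_context string and the one-line history are placeholders standing in for llm.generate_response(prompt) and the real conversation, and PALM_API_KEY is assumed to be set.

```python
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["PALM_API_KEY"])

# Placeholders standing in for the real chat history and the Pinecone-backed retrieval step.
chat_history = ["Hi Gerard, what do you do?"]
prompt = "Can you introduce yourself?"
index_context = "Gerard Lee is a data scientist focused on GenAI applications."

history_block = "\n".join(chat_history)
prompt_with_context = f"""
Rule:
You are role playing Gerard Lee and answering questions about his background.
Answer only from the context below; if you don't know, say so.

Chat History:
{history_block}

Context:
{index_context}

User Query:
{prompt}
"""

# Same call the new generate_chat() makes; the chat-bison defaults from the old version are gone.
response = genai.generate_text(prompt=prompt_with_context)
print(response.result)
```

Compared with the removed chat-bison flow, failures are now logged through the module-level logger rather than printed, and the fallback message is appended to chat_history like a normal reply.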
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
 google-generativeai
-
+llama-index
+pinecone-client
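Neither new package is pinned, so the Space picks up whatever llama-index and pinecone-client versions are current at build time. A small sanity check like the sketch below (hypothetical helper, not part of the commit) can be run locally to confirm the imports resolve and that the credentials read by src/llamaindex_palm.py are present.

```python
import importlib
import os

# Import names corresponding to the (unversioned) requirements.
for module in ("google.generativeai", "llama_index", "pinecone"):
    importlib.import_module(module)

# Environment variables read by src/llamaindex_palm.py at construction time.
for var in ("PALM_API_KEY", "PINECONE_API_KEY"):
    if not os.getenv(var):
        raise SystemExit(f"missing environment variable: {var}")

print("dependencies import and credentials are set")
```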
src/llamaindex_palm.py
ADDED
@@ -0,0 +1,171 @@
+import os
+import logging
+
+from typing import Any, List
+from pydantic import Extra
+
+import pinecone
+import google.generativeai as genai
+
+from llama_index import (
+    ServiceContext,
+    PromptHelper,
+    VectorStoreIndex
+)
+from llama_index.vector_stores import PineconeVectorStore
+from llama_index.storage.storage_context import StorageContext
+from llama_index.node_parser import SimpleNodeParser
+from llama_index.text_splitter import TokenTextSplitter
+from llama_index.embeddings.base import BaseEmbedding
+from llama_index.llms import (
+    CustomLLM,
+    CompletionResponse,
+    CompletionResponseGen,
+    LLMMetadata,
+)
+from llama_index.llms.base import llm_completion_callback
+
+class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
+    def __init__(
+        self,
+        model_name: str = 'models/embedding-gecko-001',
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self._model_name = model_name
+
+    @classmethod
+    def class_name(cls) -> str:
+        return 'PaLMEmbeddings'
+
+    def gen_embeddings(self, text: str) -> List[float]:
+        return genai.generate_embeddings(self._model_name, text)
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        embeddings = self.gen_embeddings(query)
+        return embeddings['embedding']
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        embeddings = self.gen_embeddings(text)
+        return embeddings['embedding']
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        embeddings = [
+            self.gen_embeddings(text)['embedding'] for text in texts
+        ]
+        return embeddings
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        return self._get_query_embedding(query)
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        return self._get_text_embedding(text)
+
+class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
+    def __init__(
+        self,
+        model_name: str = 'models/text-bison-001',
+        context_window: int = 8196,
+        num_output: int = 1024,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self._model_name = model_name
+        self._context_window = context_window
+        self._num_output = num_output
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(
+            context_window=self._context_window,
+            num_output=self._num_output,
+            model_name=self._model_name
+        )
+
+    def gen_texts(self, prompt):
+        logging.debug(f"prompt: {prompt}")
+        response = genai.generate_text(
+            model=self._model_name,
+            prompt=prompt,
+            safety_settings=[
+                {
+                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
+                },
+            ]
+        )
+        logging.debug(f"response:\n{response}")
+        return response.candidates[0]['output']
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        text = self.gen_texts(prompt)
+        return CompletionResponse(text=text)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        raise NotImplementedError()
+
+class LlamaIndexPaLM():
+    def __init__(
+        self,
+        emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
+        model: LlamaIndexPaLMText = LlamaIndexPaLMText()
+    ) -> None:
+        self.emb_model = emb_model
+        self.llm = model
+
+        # Google Generative AI
+        genai.configure(api_key=os.environ['PALM_API_KEY'])
+
+        # Pinecone
+        pinecone.init(
+            api_key=os.environ['PINECONE_API_KEY'],
+            environment=os.getenv('PINECONE_ENV', 'us-west1-gcp-free')
+        )
+
+        # model metadata
+        CONTEXT_WINDOW = os.getenv('CONTEXT_WINDOW', 8196)
+        NUM_OUTPUT = os.getenv('NUM_OUTPUT', 1024)
+        TEXT_CHUNK_SIZE = os.getenv('TEXT_CHUNK_SIZE', 512)
+        TEXT_CHUNK_OVERLAP = os.getenv('TEXT_CHUNK_OVERLAP', 20)
+        TEXT_CHUNK_OVERLAP_RATIO = os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1)
+        TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT', None)
+
+        self.node_parser = SimpleNodeParser.from_defaults(
+            text_splitter=TokenTextSplitter(
+                chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
+            )
+        )
+
+        self.prompt_helper = PromptHelper(
+            context_window=CONTEXT_WINDOW,
+            num_output=NUM_OUTPUT,
+            chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
+            chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
+        )
+
+        self.service_context = ServiceContext.from_defaults(
+            llm=self.llm,
+            embed_model=self.emb_model,
+            node_parser=self.node_parser,
+            prompt_helper=self.prompt_helper,
+        )
+
+    def set_index_from_pinecone(
+        self,
+        index_name: str = 'experience'
+    ) -> None:
+        # Pinecone VectorStore
+        pinecone_index = pinecone.Index(index_name)
+        self.vector_store = PineconeVectorStore(pinecone_index=pinecone_index, add_sparse_vector=True)
+        self.pinecone_index = VectorStoreIndex.from_vector_store(self.vector_store, self.service_context)
+        return None
+
+    def generate_response(
+        self,
+        query: str
+    ) -> str:
+        response = self.pinecone_index.as_query_engine().query(query)
+        return response.response
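Taken together, the module is used exactly as app.py does at startup: construct LlamaIndexPaLM (which configures PaLM and Pinecone from environment variables), attach the existing Pinecone index, then query it. A minimal sketch, assuming PALM_API_KEY and PINECONE_API_KEY are set and that a populated Pinecone index named 'experience' (the default) already exists:

```python
from src.llamaindex_palm import LlamaIndexPaLM

llm = LlamaIndexPaLM()            # configures google.generativeai and pinecone from env vars
llm.set_index_from_pinecone()     # wraps the existing 'experience' Pinecone index
print(llm.generate_response("What is Gerard's professional background?"))
```

Note that generate_response() builds a fresh query engine on every call; if latency matters, caching the engine built by as_query_engine() would be a straightforward follow-up.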