hoshingakag committed
Commit 1871bfe
1 Parent(s): 108d5cf
Files changed (3):
  1. app.py +52 -49
  2. requirements.txt +2 -1
  3. src/llamaindex_palm.py +171 -0
app.py CHANGED
@@ -1,63 +1,74 @@
 import os
 import time
-import itertools
-from typing import Literal
 import gradio as gr
 import google.generativeai as genai
+from src.llamaindex_palm import LlamaIndexPaLM
+
+import logging
+logging.basicConfig(format='%(asctime)s %(clientip)-15s %(user)-8s %(message)s', level=logging.INFO)
+logger = logging.getLogger('llm')
+
+# Llama-Index LLM
+llm = LlamaIndexPaLM()
+llm.set_index_from_pinecone()
 
 # Credentials
 genai.configure(api_key=os.getenv('PALM_API_KEY'))
 
 # Gradio
-chat_defaults = {
-    'model': 'models/chat-bison-001',
-    'temperature': 0.25,
-    'candidate_count': 1,
-    'top_k': 40,
-    'top_p': 0,
-}
-
 chat_history = []
 
-def clear_chat():
+def clear_chat() -> None:
+    global chat_history
     chat_history = []
     return None
 
-def generate_chat(prompt: str, chat_messages=chat_history):
-    print(chat_messages)
-    context = """
-    You are a proxy to Gerard Lee to answer questions as himself.
-    I am chatting with someone who might interest in my background. I am going to reply their messages given the context below and without hallucinations.
+def generate_chat(prompt: str, llamaindex_llm: LlamaIndexPaLM):
+    global chat_history
+    # get chat history
+    context_chat_history = "\n".join(chat_history)
+
+    logger.info("Generating Message...")
+    logger.info(f"User Message:\n{prompt}\n")
+    chat_history.append(prompt)
+
+    # get context
+    context_from_index = llamaindex_llm.generate_response(prompt)
+    logger.info(f"Context from Llama-Index:\n{context_from_index}\n")
+
+    prompt_with_context = f"""
+    Rule:
+    You are in a role play of Gerard Lee and you need to pretend to be him to answer questions from people who are interested in Gerard's background.
+    You are going to reply to their messages given the context below and without hallucinations. If you don't know the answer, simply say "I have no idea how to answer this question".
 
-    Context: '
-    I am a data enthusiast with more than 5 years of experience on data analytics domain. Currently working under DBS as a data scientist.
-    I drive NLP and ML use cases also lead 2 contract data analysts to deliver analytical solutions like developing Tableau dashboards.
-    '
+    Chat History:
+    {context_chat_history}
+
+    Context:
+    {context_from_index}
+
+    User Query:
+    {prompt}
     """
-    print("Generating Chat Message...")
-    print(f"User Message:\n{prompt}\n")
-    chat_messages.append(prompt)
-
+
     try:
-        response = genai.chat(
-            **chat_defaults,
-            context=context,
-            messages=chat_messages
+        response = genai.generate_text(
+            prompt=prompt_with_context,
+            safety_settings=[
+                {
+                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
+                },
+            ]
         )
-        result = response.last
-        if result is None:
-            result = "Apologies but something went wrong. Please try again later."
-            chat_messages = chat_messages[:-1]
-        else:
-            chat_messages.append(result)
-
+        result = response.result
+
     except Exception as e:
-        result = "Apologies but something went wrong. Please try again later."
-        chat_messages = chat_messages[:-1]
-        print(f"Exception {e} occured\n")
-
-    chat_history = chat_messages
-    print(f"Bot Message:\n{result}\n")
+        result = "Seems something went wrong. Please try again later."
+        logger.error(f"Exception {e} occurred\n")
+
+    chat_history.append(result)
+    logger.info(f"Bot Message:\n{result}\n")
     return result
 
 with gr.Blocks() as app:
@@ -74,11 +85,6 @@ with gr.Blocks() as app:
             placeholder="Hi Gerard, can you introduce yourself?",
             container=False,
             scale=6)
-    # send = gr.Button(
-    #     value="",
-    #     icon="./send-message.png",
-    #     scale=1
-    # )
     clear = gr.Button("Clear")
 
     def user(user_message, history):
@@ -95,9 +101,6 @@ with gr.Blocks() as app:
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
-    # send.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-    #     bot, chatbot, chatbot
-    # )
     clear.click(clear_chat, None, chatbot, queue=False)
 
     gr.HTML("<p><center>Hosted on 🤗 Spaces. Powered by Google PaLM 🌴</center></p>")
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 google-generativeai
-langchain
+llama-index
+pinecone-client
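Swapping langchain for llama-index and adding pinecone-client means the Space now needs credentials for both services at startup. A quick sanity check, reusing the same environment variables that src/llamaindex_palm.py reads (a sketch, not part of the commit):

    import os
    import pinecone
    import google.generativeai as genai

    # fail fast if the Space is missing either secret
    genai.configure(api_key=os.environ['PALM_API_KEY'])
    pinecone.init(
        api_key=os.environ['PINECONE_API_KEY'],
        environment=os.getenv('PINECONE_ENV', 'us-west1-gcp-free'),
    )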
src/llamaindex_palm.py ADDED
@@ -0,0 +1,171 @@
+import os
+import logging
+
+from typing import Any, List
+from pydantic import Extra
+
+import pinecone
+import google.generativeai as genai
+
+from llama_index import (
+    ServiceContext,
+    PromptHelper,
+    VectorStoreIndex
+)
+from llama_index.vector_stores import PineconeVectorStore
+from llama_index.storage.storage_context import StorageContext
+from llama_index.node_parser import SimpleNodeParser
+from llama_index.text_splitter import TokenTextSplitter
+from llama_index.embeddings.base import BaseEmbedding
+from llama_index.llms import (
+    CustomLLM,
+    CompletionResponse,
+    CompletionResponseGen,
+    LLMMetadata,
+)
+from llama_index.llms.base import llm_completion_callback
+
+class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
+    def __init__(
+        self,
+        model_name: str = 'models/embedding-gecko-001',
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self._model_name = model_name
+
+    @classmethod
+    def class_name(cls) -> str:
+        return 'PaLMEmbeddings'
+
+    def gen_embeddings(self, text: str) -> List[float]:
+        return genai.generate_embeddings(self._model_name, text)
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        embeddings = self.gen_embeddings(query)
+        return embeddings['embedding']
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        embeddings = self.gen_embeddings(text)
+        return embeddings['embedding']
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        embeddings = [
+            self.gen_embeddings(text)['embedding'] for text in texts
+        ]
+        return embeddings
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        return self._get_query_embedding(query)
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        return self._get_text_embedding(text)
+
+class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
+    def __init__(
+        self,
+        model_name: str = 'models/text-bison-001',
+        context_window: int = 8196,
+        num_output: int = 1024,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self._model_name = model_name
+        self._context_window = context_window
+        self._num_output = num_output
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(
+            context_window=self._context_window,
+            num_output=self._num_output,
+            model_name=self._model_name
+        )
+
+    def gen_texts(self, prompt):
+        logging.debug(f"prompt: {prompt}")
+        response = genai.generate_text(
+            model=self._model_name,
+            prompt=prompt,
+            safety_settings=[
+                {
+                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
+                },
+            ]
+        )
+        logging.debug(f"response:\n{response}")
+        return response.candidates[0]['output']
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        text = self.gen_texts(prompt)
+        return CompletionResponse(text=text)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        raise NotImplementedError()
+
+class LlamaIndexPaLM():
+    def __init__(
+        self,
+        emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
+        model: LlamaIndexPaLMText = LlamaIndexPaLMText()
+    ) -> None:
+        self.emb_model = emb_model
+        self.llm = model
+
+        # Google Generative AI
+        genai.configure(api_key=os.environ['PALM_API_KEY'])
+
+        # Pinecone
+        pinecone.init(
+            api_key=os.environ['PINECONE_API_KEY'],
+            environment=os.getenv('PINECONE_ENV', 'us-west1-gcp-free')
+        )
+
+        # model metadata
+        CONTEXT_WINDOW = os.getenv('CONTEXT_WINDOW', 8196)
+        NUM_OUTPUT = os.getenv('NUM_OUTPUT', 1024)
+        TEXT_CHUNK_SIZE = os.getenv('TEXT_CHUNK_SIZE', 512)
+        TEXT_CHUNK_OVERLAP = os.getenv('TEXT_CHUNK_OVERLAP', 20)
+        TEXT_CHUNK_OVERLAP_RATIO = os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1)
+        TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT', None)
+
+        self.node_parser = SimpleNodeParser.from_defaults(
+            text_splitter=TokenTextSplitter(
+                chunk_size=TEXT_CHUNK_SIZE, chunk_overlap=TEXT_CHUNK_OVERLAP
+            )
+        )
+
+        self.prompt_helper = PromptHelper(
+            context_window=CONTEXT_WINDOW,
+            num_output=NUM_OUTPUT,
+            chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
+            chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT
+        )
+
+        self.service_context = ServiceContext.from_defaults(
+            llm=self.llm,
+            embed_model=self.emb_model,
+            node_parser=self.node_parser,
+            prompt_helper=self.prompt_helper,
+        )
+
+    def set_index_from_pinecone(
+        self,
+        index_name: str = 'experience'
+    ) -> None:
+        # Pinecone VectorStore
+        pinecone_index = pinecone.Index(index_name)
+        self.vector_store = PineconeVectorStore(pinecone_index=pinecone_index, add_sparse_vector=True)
+        self.pinecone_index = VectorStoreIndex.from_vector_store(self.vector_store, self.service_context)
+        return None
+
+    def generate_response(
+        self,
+        query: str
+    ) -> str:
+        response = self.pinecone_index.as_query_engine().query(query)
+        return response.response
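End to end, app.py uses the new class in two steps: build the index handle once at startup, then retrieve context per message. A minimal usage sketch, assuming a Pinecone index named 'experience' (the default) already populated with embeddings and both API keys set; the query string is purely illustrative:

    from src.llamaindex_palm import LlamaIndexPaLM

    llm = LlamaIndexPaLM()         # configures PaLM and Pinecone from env vars
    llm.set_index_from_pinecone()  # wraps the existing 'experience' index
    answer = llm.generate_response("Can you introduce yourself?")
    print(answer)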