import os
import logging
from typing import Any, Dict, List

from pydantic import Extra

import pinecone
import google.generativeai as genai

from llama_index import (
    ServiceContext,
    PromptHelper,
    VectorStoreIndex,
)
from llama_index.vector_stores import PineconeVectorStore
from llama_index.node_parser import SimpleNodeParser
from llama_index.text_splitter import TokenTextSplitter
from llama_index.embeddings.base import BaseEmbedding
from llama_index.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback


class LlamaIndexPaLMEmbeddings(BaseEmbedding, extra=Extra.allow):
    """LlamaIndex embedding model backed by the PaLM embedding API."""

    def __init__(
        self,
        model_name: str = 'models/embedding-gecko-001',
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name

    @classmethod
    def class_name(cls) -> str:
        return 'PaLMEmbeddings'

    def gen_embeddings(self, text: str) -> Dict[str, List[float]]:
        # genai.generate_embeddings returns a dict with an 'embedding' key.
        return genai.generate_embeddings(self._model_name, text)

    def _get_query_embedding(self, query: str) -> List[float]:
        return self.gen_embeddings(query)['embedding']

    def _get_text_embedding(self, text: str) -> List[float]:
        return self.gen_embeddings(text)['embedding']

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        return [self.gen_embeddings(text)['embedding'] for text in texts]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)


class LlamaIndexPaLMText(CustomLLM, extra=Extra.allow):
    """LlamaIndex LLM backed by the PaLM text generation API."""

    def __init__(
        self,
        model_name: str = 'models/text-bison-001',
        context_window: int = 8196,
        num_output: int = 1024,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._model_name = model_name
        self._context_window = context_window
        self._num_output = num_output

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self._context_window,
            num_output=self._num_output,
            model_name=self._model_name,
        )

    def gen_texts(self, prompt: str) -> str:
        logging.debug(f"prompt: {prompt}")
        response = genai.generate_text(
            model=self._model_name,
            prompt=prompt,
            safety_settings=[
                {
                    'category': genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
                    'threshold': genai.types.HarmBlockThreshold.BLOCK_NONE,
                },
            ],
        )
        logging.debug(f"response:\n{response}")
        return response.candidates[0]['output']

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        return CompletionResponse(text=self.gen_texts(prompt))

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()


class LlamaIndexPaLM:
    """PaLM LLM and embeddings wired to a Pinecone-backed LlamaIndex vector index."""

    def __init__(
        self,
        emb_model: LlamaIndexPaLMEmbeddings = LlamaIndexPaLMEmbeddings(),
        model: LlamaIndexPaLMText = LlamaIndexPaLMText(),
    ) -> None:
        self.emb_model = emb_model
        self.llm = model

        # Google Generative AI
        genai.configure(api_key=os.environ['PALM_API_KEY'])

        # Pinecone
        pinecone.init(
            api_key=os.environ['PINECONE_API_KEY'],
            environment=os.getenv('PINECONE_ENV', 'us-west1-gcp-free'),
        )

        # model metadata (environment variables arrive as strings, so cast them)
        CONTEXT_WINDOW = int(os.getenv('CONTEXT_WINDOW', 8196))
        NUM_OUTPUT = int(os.getenv('NUM_OUTPUT', 1024))
        TEXT_CHUNK_SIZE = int(os.getenv('TEXT_CHUNK_SIZE', 512))
        TEXT_CHUNK_OVERLAP = int(os.getenv('TEXT_CHUNK_OVERLAP', 20))
        TEXT_CHUNK_OVERLAP_RATIO = float(os.getenv('TEXT_CHUNK_OVERLAP_RATIO', 0.1))
        TEXT_CHUNK_SIZE_LIMIT = os.getenv('TEXT_CHUNK_SIZE_LIMIT')
        if TEXT_CHUNK_SIZE_LIMIT is not None:
            TEXT_CHUNK_SIZE_LIMIT = int(TEXT_CHUNK_SIZE_LIMIT)

        self.node_parser = SimpleNodeParser.from_defaults(
            text_splitter=TokenTextSplitter(
                chunk_size=TEXT_CHUNK_SIZE,
                chunk_overlap=TEXT_CHUNK_OVERLAP,
            )
        )

        self.prompt_helper = PromptHelper(
            context_window=CONTEXT_WINDOW,
            num_output=NUM_OUTPUT,
            chunk_overlap_ratio=TEXT_CHUNK_OVERLAP_RATIO,
            chunk_size_limit=TEXT_CHUNK_SIZE_LIMIT,
        )

        self.service_context = ServiceContext.from_defaults(
            llm=self.llm,
            embed_model=self.emb_model,
            node_parser=self.node_parser,
            prompt_helper=self.prompt_helper,
        )

    def set_index_from_pinecone(
        self,
        index_name: str = 'experience',
    ) -> None:
        # Wrap an existing Pinecone index in a LlamaIndex VectorStoreIndex.
        pinecone_index = pinecone.Index(index_name)
        self.vector_store = PineconeVectorStore(
            pinecone_index=pinecone_index,
            add_sparse_vector=True,
        )
        self.pinecone_index = VectorStoreIndex.from_vector_store(
            vector_store=self.vector_store,
            service_context=self.service_context,
        )

    def generate_response(
        self,
        query: str,
    ) -> str:
        response = self.pinecone_index.as_query_engine().query(query)
        return response.response
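

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): this assumes PALM_API_KEY and
# PINECONE_API_KEY are set in the environment and that a populated Pinecone
# index named 'experience' already exists; the query string is a placeholder.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    palm = LlamaIndexPaLM()
    palm.set_index_from_pinecone(index_name='experience')
    print(palm.generate_response('Summarize the indexed experience.'))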