petrojm committed
Commit a6c26b1 · 1 Parent(s): 08cf07c

add EKR files
app.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import yaml
5
+ import gradio as gr
6
+ import time
7
+
8
+ current_dir = os.path.dirname(os.path.abspath(__file__))
9
+ print(current_dir)
10
+
11
+ from src.document_retrieval import DocumentRetrieval
12
+ from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials
13
+ from utils.parsing.sambaparse import parse_doc_universal # added Petro
14
+ from utils.vectordb.vector_db import VectorDb
15
+
16
+ CONFIG_PATH = os.path.join(current_dir,'config.yaml')
17
+ PERSIST_DIRECTORY = os.path.join(current_dir, "data/my-vector-db") # changed to current_dir
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+ logging.info("Gradio app is running")
21
+
22
+ class ChatState:
23
+ def __init__(self):
24
+ self.conversation = None
25
+ self.chat_history = []
26
+ self.show_sources = True
27
+ self.sources_history = []
28
+ self.vectorstore = None
29
+ self.input_disabled = True
30
+ self.document_retrieval = None
31
+
32
+ chat_state = ChatState()
33
+
34
+ chat_state.document_retrieval = DocumentRetrieval()
35
+
36
+ def handle_userinput(user_question):
37
+ if user_question:
38
+ try:
39
+ response_time = time.time()
40
+ response = chat_state.conversation.invoke({"question": user_question})
41
+ response_time = time.time() - response_time
42
+ chat_state.chat_history.append((user_question, response["answer"]))
43
+
44
+ #sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
45
+ #sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
46
+ #state.sources_history.append(sources_text)
47
+
48
+ return chat_state.chat_history, "" #, state.sources_history
49
+ except Exception as e:
50
+ return chat_state.chat_history + [(user_question, f"An error occurred: {str(e)}")], "" #, state.sources_history
51
+ return chat_state.chat_history, "" #, state.sources_history
52
+
53
+ def process_documents(files, save_location=None):
54
+ try:
55
+ #for doc in files:
56
+ _, _, text_chunks = parse_doc_universal(doc=files)
57
+ print(text_chunks)
58
+ #text_chunks = chat_state.document_retrieval.parse_doc(files)
59
+ embeddings = chat_state.document_retrieval.load_embedding_model()
60
+ collection_name = 'ekr_default_collection' if not config['prod_mode'] else None
61
+ vectorstore = chat_state.document_retrieval.create_vector_store(text_chunks, embeddings, output_db=save_location, collection_name=collection_name)
62
+ chat_state.vectorstore = vectorstore
63
+ chat_state.document_retrieval.init_retriever(vectorstore)
64
+ chat_state.conversation = chat_state.document_retrieval.get_qa_retrieval_chain()
65
+ chat_state.input_disabled = False
66
+ return "Documents processed successfully. You can now ask questions."
67
+ except Exception as e:
68
+ return f"An error occurred while processing: {str(e)}"
69
+
70
+ def reset_conversation():
71
+ chat_state.chat_history = []
72
+ #chat_state.sources_history = []
73
+ return chat_state.chat_history, ""
74
+
75
+ def show_selection(model):
76
+ return f"You selected: {model}"
77
+
78
+ # Read config file
79
+ with open(CONFIG_PATH, 'r') as yaml_file:
80
+ config = yaml.safe_load(yaml_file)
81
+
82
+ prod_mode = config.get('prod_mode', False)
83
+ default_collection = 'ekr_default_collection'
84
+
85
+ # Load env variables
86
+ initialize_env_variables(prod_mode)
87
+
88
+ caution_text = """⚠️ Note: depending on the size of your document, this could take several minutes.
89
+ """
90
+
91
+ with gr.Blocks() as demo:
92
+ #gr.Markdown("# SambaNova Analyst Assistant") # title
93
+ gr.Markdown("# 🟠 SambaNova Analyst Assistant",
94
+ elem_id="title")
95
+
96
+ gr.Markdown("Powered by SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")
97
+
98
+ api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")
99
+
100
+ # Step 1: Add PDF file
101
+ gr.Markdown("## 1️⃣ Pick a datasource")
102
+ docs = gr.File(label="Add PDF file", file_types=["pdf"], file_count="single")
103
+
104
+ # Step 2: Process PDF file
105
+ gr.Markdown("## 2️⃣ Process your documents and create vector store")
106
+ process_btn = gr.Button("🔄 Process")
107
+ gr.Markdown(caution_text)
108
+ setup_output = gr.Textbox(label="Setup Output", visible=True)
109
+
110
+ process_btn.click(process_documents, inputs=[docs], outputs=setup_output, concurrency_limit=10)
111
+ #process_save_btn.click(process_documents, inputs=[file_upload, save_location], outputs=setup_output)
112
+ #load_db_btn.click(load_existing_db, inputs=[db_path], outputs=setup_output)
113
+
114
+ # Step 3: Chat with your data
115
+ gr.Markdown("## 3️⃣ Chat")
116
+ chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
117
+ msg = gr.Textbox(label="Ask questions about your data", placeholder="Enter your message...")
118
+ clear = gr.Button("Clear chat")
119
+ #show_sources = gr.Checkbox(label="Show sources", value=True)
120
+ sources_output = gr.Textbox(label="Sources", visible=False)
121
+
122
+ #msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, sources_output])
123
+ msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
124
+ clear.click(reset_conversation, outputs=[chatbot,msg])
125
+ #show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
126
+
127
+ if __name__ == "__main__":
128
+ demo.launch()
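
Reviewer note: the Gradio callbacks above are a thin wrapper around DocumentRetrieval. A minimal sketch of the same pipeline without the UI, assuming it is run from the kit root and that data/sample.pdf is a placeholder input file:

from src.document_retrieval import DocumentRetrieval
from utils.parsing.sambaparse import parse_doc_universal

document_retrieval = DocumentRetrieval()
# parse a document into text chunks (path is a placeholder)
_, _, text_chunks = parse_doc_universal(doc="data/sample.pdf")
embeddings = document_retrieval.load_embedding_model()
vectorstore = document_retrieval.create_vector_store(
    text_chunks, embeddings, output_db=None, collection_name="ekr_default_collection"
)
document_retrieval.init_retriever(vectorstore)
conversation = document_retrieval.get_qa_retrieval_chain()
print(conversation.invoke({"question": "What is this document about?"})["answer"])
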
config.yaml ADDED
@@ -0,0 +1,29 @@
1
+ api: "sncloud" # set either sambastudio or sncloud
2
+
3
+ embedding_model:
4
+ "type": "cpu" # set either sambastudio or cpu
5
+ "batch_size": 1 #set depending on your endpoint configuration (1 if CoE embedding expert)
6
+ "coe": True #set true if using Sambastudio embeddings in a CoE endpoint
7
+ "select_expert": "e5-mistral-7b-instruct" #set if using SambaStudio CoE embedding expert
8
+
9
+ llm:
10
+ "temperature": 0.0
11
+ "do_sample": False
12
+ "max_tokens_to_generate": 1200
13
+ "coe": True #set as true if using Sambastudio CoE endpoint
14
+ "select_expert": "llama3-8b" #set if using sncloud, SambaStudio CoE llm expert
15
+ #sncloud CoE expert name -> "llama3-8b"
16
+
17
+ retrieval:
18
+ "k_retrieved_documents": 15 #set if rerank enabled
19
+ "score_threshold": 0.2
20
+ "rerank": False # set if you want to rerank retriever results
21
+ "reranker": 'BAAI/bge-reranker-large' # set if rerank is enabled
22
+ "final_k_retrieved_documents": 5
23
+
24
+ pdf_only_mode: True # Set to true for PDF-only mode, false for all file types
25
+ prod_mode: False
26
+
27
+ prompts:
28
+ "qa_prompt": "prompts/qa_prompt.yaml"
29
+ "final_chain_prompt": "prompts/final_chain_prompt.yaml"
prompts/final_chain_prompt.yaml ADDED
@@ -0,0 +1,17 @@
1
+ _type: prompt
2
+ input_types: {}
3
+ input_variables:
4
+ - question
5
+ - answers
6
+ name: null
7
+ output_parser: null
8
+ partial_variables: {}
9
+ template: |
10
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
11
+ Use the following intermediate answers, provide a final answer to the original question. If you cannot answer based on the intermediate answers provided to you, say "Whoops! I don't know!". <|eot_id|><|start_header_id|>user<|end_header_id|>
12
+ Original Question: {question}
13
+ Intermediate Answers: {answers}
14
+ \n ------- \n
15
+ Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>
16
+ template_format: f-string
17
+ validate_template: false
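
Note: the prompt files are serialized LangChain prompt templates, loaded with load_prompt (as src/document_retrieval.py and src/bulkQA.py do). A small sketch of loading and formatting this one; the question and answers values are placeholders:

from langchain.prompts import load_prompt

prompt = load_prompt("prompts/final_chain_prompt.yaml")  # path relative to the kit root
print(prompt.format(
    question="What was total revenue last quarter?",          # placeholder question
    answers="Doc A says $1.2M; Doc B says revenue grew 8%.",  # placeholder intermediate answers
))
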
prompts/llama7b-knowledge_retriever-custom_qa_prompt.yaml ADDED
@@ -0,0 +1,19 @@
1
+ _type: prompt
2
+ input_types: {}
3
+ input_variables:
4
+ - context
5
+ - question
6
+ name: null
7
+ output_parser: null
8
+ partial_variables: {}
9
+ template: "[INST]<<SYS>> You are a helpful assistant for question-answering tasks.\
10
+ \ Use the following pieces of retrieved context to answer the question.\n \
11
+ \ each piece of context includes the Source for reference\n if the question \
12
+ \ references a specific source then filter out that source and give a response based on that source\n If\
13
+ \ the answer is not in the context, say that you don't know. Cross check if the\
14
+ \ answer is contained in provided context. If not than say \"I do not have information\
15
+ \ regarding this.\n Do not use images or emojis in your answer. Keep the answer\
16
+ \ conversational and professional.<</SYS>>\n\n {context} \n \n Question:\
17
+ \ {question} \n Helpful answer: [/INST]"
18
+ template_format: f-string
19
+ validate_template: false
prompts/qa_prompt.yaml ADDED
@@ -0,0 +1,21 @@
1
+ _type: prompt
2
+ input_types: {}
3
+ input_variables:
4
+ - context
5
+ - question
6
+ name: null
7
+ output_parser: null
8
+ partial_variables: {}
9
+ template: |
10
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a knowledge base assistant chatbot powered by Sambanova's AI chip accelerator, designed to answer questions based on user-uploaded documents.
11
+ Use the following pieces of retrieved context to answer the question. Each piece of context includes the Source for reference. If the question references a specific source, then filter out that source and give a response based on that source.
12
+ If the answer is not in the context, say: "This information isn't in my current knowledge base." Then, suggest a related topic you can discuss based on the available context.
13
+ Maintain a professional yet conversational tone. Do not use images or emojis in your answer.
14
+ Prioritize accuracy and only provide information directly supported by the context. <|eot_id|><|start_header_id|>user<|end_header_id|>
15
+ Question: {question}
16
+ Context: {context}
17
+ \n ------- \n
18
+ Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>
19
+
20
+ template_format: f-string
21
+ validate_template: false
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ streamlit==1.36.0
2
+ pydantic==2.7.0
3
+ pydantic_core==2.18.1
4
+
5
+ langchain==0.2.16
6
+ langchain-core==0.2.38
7
+ langchain-community==0.2.16
8
+
9
+ sentence_transformers==2.2.2
10
+ instructorembedding==1.0.1
11
+ faiss-cpu==1.7.4
12
+ python-dotenv==1.0.0
13
+ streamlit-extras==0.4.3
14
+ pillow==10.4.0
15
+ sseclient-py==1.8.0
16
+ # unstructured==0.14.9
17
+ # unstructured_inference==0.7.36
18
+ # unstructured_pytesseract==0.3.12
19
+ # pytesseract==0.3.10
20
+ chromadb==0.5.3
21
+ langgraph==0.0.55
22
+ openpyxl==3.1.4
23
+ psutil==6.0.0
24
+ pillow_heif==0.16.0
25
+ ipython==8.26.0
26
+ PyMuPDF==1.23.4
27
+ PyMuPDFb==1.23.3
28
+
29
+ #LLM Eval
30
+ weave==0.51.1
src/bulkQA.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import pandas as pd
5
+ import time
6
+ from typing import Any, Dict, Optional
7
+ from langchain_core.callbacks import CallbackManagerForChainRun
8
+ from langchain.prompts import load_prompt
9
+ from langchain_core.output_parsers import StrOutputParser
10
+ from transformers import AutoTokenizer
11
+
12
+ current_dir = os.path.dirname(os.path.abspath(__file__))
13
+ kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
14
+ repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))
15
+
16
+ sys.path.append(kit_dir)
17
+ sys.path.append(repo_dir)
18
+
19
+ from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain
20
+
21
+ class TimedRetrievalQAChain(RetrievalQAChain):
22
+ #override call method to return times
23
+ def _call(self,
24
+ inputs: Dict[str, Any],
25
+ run_manager: Optional[CallbackManagerForChainRun] = None,
26
+ ) -> Dict[str, Any]:
27
+ qa_chain = self.qa_prompt | self.llm | StrOutputParser()
28
+ response = {}
29
+ start_time = time.time()
30
+ documents = self.retriever.invoke(inputs["question"])
31
+ if self.rerank:
32
+ documents = self.rerank_docs(inputs["question"], documents, self.final_k_retrieved_documents)
33
+ docs = self._format_docs(documents)
34
+ end_preprocessing_time=time.time()
35
+ response["answer"] = qa_chain.invoke({"question": inputs["question"], "context": docs})
36
+ end_llm_time=time.time()
37
+ response["source_documents"] = documents
38
+ response["start_time"] = start_time
39
+ response["end_preprocessing_time"] = end_preprocessing_time
40
+ response["end_llm_time"] = end_llm_time
41
+ return response
42
+
43
+ def analyze_times(answer, start_time, end_preprocessing_time, end_llm_time, tokenizer):
44
+ preprocessing_time=end_preprocessing_time-start_time
45
+ llm_time=end_llm_time-end_preprocessing_time
46
+ token_count=len(tokenizer.encode(answer))
47
+ tokens_per_second = token_count / llm_time
48
+ perf = {"preprocessing_time": preprocessing_time,
49
+ "llm_time": llm_time,
50
+ "token_count": token_count,
51
+ "tokens_per_second": tokens_per_second}
52
+ return perf
53
+
54
+ def generate(qa_chain, question, tokenizer):
55
+ response = qa_chain.invoke({"question": question})
56
+ answer = response.get('answer')
57
+ sources = set([
58
+ f'{sd.metadata["filename"]}'
59
+ for sd in response["source_documents"]
60
+ ])
61
+ times = analyze_times(
62
+ answer,
63
+ response.get("start_time"),
64
+ response.get("end_preprocessing_time"),
65
+ response.get("end_llm_time"),
66
+ tokenizer
67
+ )
68
+ return answer, sources, times
69
+
70
+ def process_bulk_QA(vectordb_path, questions_file_path):
71
+ documentRetrieval = DocumentRetrieval()
72
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
73
+ if os.path.exists(vectordb_path):
74
+ # load the vectorstore
75
+ embeddings = documentRetrieval.load_embedding_model()
76
+ vectorstore = documentRetrieval.load_vdb(vectordb_path, embeddings)
77
+ print("Database loaded")
78
+ documentRetrieval.init_retriever(vectorstore)
79
+ print("retriever initialized")
80
+ #get qa chain
81
+ qa_chain = TimedRetrievalQAChain(
82
+ retriever=documentRetrieval.retriever,
83
+ llm=documentRetrieval.llm,
84
+ qa_prompt = load_prompt(os.path.join(kit_dir, documentRetrieval.prompts["qa_prompt"])),
85
+ rerank = documentRetrieval.retrieval_info["rerank"],
86
+ final_k_retrieved_documents = documentRetrieval.retrieval_info["final_k_retrieved_documents"]
87
+
88
+ )
89
+ else:
90
+ raise FileNotFoundError(f"vector db path {vectordb_path} does not exist")
91
+ if os.path.exists(questions_file_path):
92
+ df = pd.read_excel(questions_file_path)
93
+ print(df)
94
+ output_file_path = questions_file_path.replace('.xlsx', '_output.xlsx')
95
+ if 'Answer' not in df.columns:
96
+ df['Answer'] = ''
97
+ df['Sources'] = ''
98
+ df['preprocessing_time'] = ''
99
+ df['llm_time'] = ''
100
+ df['token_count'] = ''
101
+ df['tokens_per_second'] = ''
102
+ for index, row in df.iterrows():
103
+ if row['Answer'].strip()=='': # Only process if 'Answer' is empty
104
+ try:
105
+ # Generate the answer
106
+ print(f"Generating answer for row {index}")
107
+ answer, sources, times = generate(qa_chain, row['Questions'], tokenizer)
108
+ df.at[index, 'Answer'] = answer
109
+ df.at[index, 'Sources'] = sources
110
+ df.at[index, 'preprocessing_time'] = times.get("preprocessing_time")
111
+ df.at[index, 'llm_time'] = times.get("llm_time")
112
+ df.at[index, 'token_count'] = times.get("token_count")
113
+ df.at[index, 'tokens_per_second'] = times.get("tokens_per_second")
114
+ except Exception as e:
115
+ print(f"Error processing row {index}: {e}")
116
+ # Save the file after each iteration to avoid data loss
117
+ df.to_excel(output_file_path, index=False)
118
+ else:
119
+ print(f"Skipping row {index} because 'Answer' is already in the document")
120
+ return output_file_path
121
+ else:
122
+ raise FileNotFoundError(f"questions file path {questions_file_path} does not exist")
123
+
124
+ if __name__ == "__main__":
125
+ # Parse the arguments
126
+ parser = argparse.ArgumentParser(description='use a vectordb and an excel file with questions in the first column and generate answers for all the questions')
127
+ parser.add_argument('vectordb_path', type=str, help='vector db path with stored documents for RAG')
128
+ parser.add_argument('questions_path', type=str, help='xlsx file containing questions in a column named Questions')
129
+ args = parser.parse_args()
130
+ # process in bulk
131
+ out_file = process_bulk_QA(args.vectordb_path, args.questions_path)
132
+ print(f"Finished, responses in: {out_file}")
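
Note: the script expects an .xlsx workbook with a column named Questions. A minimal sketch of producing one with pandas (file name and questions are placeholders; openpyxl from requirements.txt handles the xlsx writing):

import pandas as pd

questions = ["What is the refund policy?", "Who signed the agreement?"]  # placeholder questions
pd.DataFrame({"Questions": questions}).to_excel("questions.xlsx", index=False)

# then, for example:
#   python src/bulkQA.py data/my-vector-db questions.xlsx
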
src/document_retrieval.py ADDED
@@ -0,0 +1,311 @@
1
+ import os
2
+ import shutil
3
+ import sys
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import torch
7
+ import yaml
8
+ from dotenv import load_dotenv
9
+ from langchain.chains.base import Chain
10
+ from langchain.docstore.document import Document
11
+ from langchain.prompts import BasePromptTemplate, load_prompt
12
+ from langchain_core.callbacks import CallbackManagerForChainRun
13
+ from langchain_core.language_models import BaseLanguageModel
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from langchain_core.retrievers import BaseRetriever
16
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
17
+
18
+ current_dir = os.path.dirname(os.path.abspath(__file__)) # src/ directory
19
+ kit_dir = os.path.abspath(os.path.join(current_dir, '..')) # EKR/ directory
20
+ repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
21
+ sys.path.append(kit_dir)
22
+ sys.path.append(repo_dir)
23
+
24
+ import streamlit as st
25
+
26
+ from utils.model_wrappers.api_gateway import APIGateway
27
+ from utils.vectordb.vector_db import VectorDb
28
+ from utils.visual.env_utils import get_wandb_key
29
+
30
+ CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
31
+ PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
32
+
33
+ load_dotenv(os.path.join(kit_dir, '.env'))
34
+
35
+
36
+ from utils.parsing.sambaparse import parse_doc_universal
37
+
38
+ # Handle the WANDB_API_KEY resolution before importing weave
39
+ #wandb_api_key = get_wandb_key()
40
+
41
+ # If WANDB_API_KEY is set, proceed with weave initialization
42
+ #if wandb_api_key:
43
+ # import weave
44
+
45
+ # Initialize Weave with your project name
46
+ # weave.init('sambanova_ekr')
47
+ #else:
48
+ # print('WANDB_API_KEY is not set. Weave initialization skipped.')
49
+
50
+
51
+ class RetrievalQAChain(Chain):
52
+ """class for question-answering."""
53
+
54
+ retriever: BaseRetriever
55
+ rerank: bool = True
56
+ llm: BaseLanguageModel
57
+ qa_prompt: BasePromptTemplate
58
+ final_k_retrieved_documents: int = 3
59
+
60
+ @property
61
+ def input_keys(self) -> List[str]:
62
+ """Input keys.
63
+ :meta private:
64
+ """
65
+ return ['question']
66
+
67
+ @property
68
+ def output_keys(self) -> List[str]:
69
+ """Output keys.
70
+ :meta private:
71
+ """
72
+ return ['answer', 'source_documents']
73
+
74
+ def _format_docs(self, docs):
75
+ return '\n\n'.join(doc.page_content for doc in docs)
76
+
77
+ def rerank_docs(self, query, docs, final_k):
78
+ # Lazy hardcoding for now
79
+ tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
80
+ reranker = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
81
+ pairs = []
82
+ for d in docs:
83
+ pairs.append([query, d.page_content])
84
+
85
+ with torch.no_grad():
86
+ inputs = tokenizer(
87
+ pairs,
88
+ padding=True,
89
+ truncation=True,
90
+ return_tensors='pt',
91
+ max_length=512,
92
+ )
93
+ scores = (
94
+ reranker(**inputs, return_dict=True)
95
+ .logits.view(
96
+ -1,
97
+ )
98
+ .float()
99
+ )
100
+
101
+ scores_list = scores.tolist()
102
+ scores_sorted_idx = sorted(range(len(scores_list)), key=lambda k: scores_list[k], reverse=True)
103
+
104
+ docs_sorted = [docs[k] for k in scores_sorted_idx]
105
+ # docs_sorted = [docs[k] for k in scores_sorted_idx if scores_list[k]>0]
106
+ docs_sorted = docs_sorted[:final_k]
107
+
108
+ return docs_sorted
109
+
110
+ def _call(
111
+ self,
112
+ inputs: Dict[str, Any],
113
+ run_manager: Optional[CallbackManagerForChainRun] = None,
114
+ ) -> Dict[str, Any]:
115
+ qa_chain = self.qa_prompt | self.llm | StrOutputParser()
116
+ response = {}
117
+ documents = self.retriever.invoke(inputs['question'])
118
+ if self.rerank:
119
+ documents = self.rerank_docs(inputs['question'], documents, self.final_k_retrieved_documents)
120
+ docs = self._format_docs(documents)
121
+ response['answer'] = qa_chain.invoke({'question': inputs['question'], 'context': docs})
122
+ response['source_documents'] = documents
123
+ return response
124
+
125
+
126
+ class DocumentRetrieval:
127
+ def __init__(self):
128
+ self.vectordb = VectorDb()
129
+ config_info = self.get_config_info()
130
+ self.api_info = config_info[0]
131
+ self.llm_info = config_info[1]
132
+ self.embedding_model_info = config_info[2]
133
+ self.retrieval_info = config_info[3]
134
+ self.prompts = config_info[4]
135
+ self.prod_mode = config_info[5]
136
+ self.retriever = None
137
+ self.llm = self.set_llm()
138
+
139
+ def get_config_info(self):
140
+ """
141
+ Loads the YAML config file
142
+ """
143
+ # Read config file
144
+ with open(CONFIG_PATH, 'r') as yaml_file:
145
+ config = yaml.safe_load(yaml_file)
146
+ api_info = config['api']
147
+ llm_info = config['llm']
148
+ embedding_model_info = config['embedding_model']
149
+ retrieval_info = config['retrieval']
150
+ prompts = config['prompts']
151
+ prod_mode = config['prod_mode']
152
+
153
+ return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
154
+
155
+ def set_llm(self):
156
+ if self.prod_mode:
157
+ sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
158
+ else:
159
+ if 'SAMBANOVA_API_KEY' in st.session_state:
160
+ sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
161
+ else:
162
+ sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
163
+
164
+ llm = APIGateway.load_llm(
165
+ type=self.api_info,
166
+ streaming=True,
167
+ coe=self.llm_info['coe'],
168
+ do_sample=self.llm_info['do_sample'],
169
+ max_tokens_to_generate=self.llm_info['max_tokens_to_generate'],
170
+ temperature=self.llm_info['temperature'],
171
+ select_expert=self.llm_info['select_expert'],
172
+ process_prompt=False,
173
+ sambanova_api_key=sambanova_api_key,
174
+ )
175
+ return llm
176
+
177
+ def parse_doc(self, docs: List, additional_metadata: Optional[Dict] = None) -> List[Document]:
178
+ """
179
+ Parse the uploaded documents and return a list of LangChain documents.
180
+
181
+ Args:
182
+ docs (List[UploadFile]): A list of uploaded files.
183
+ additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
184
+ Defaults to an empty dictionary.
185
+
186
+ Returns:
187
+ List[Document]: A list of LangChain documents.
188
+ """
189
+ if additional_metadata is None:
190
+ additional_metadata = {}
191
+
192
+ # Create the data/tmp folder if it doesn't exist
193
+ temp_folder = os.path.join(kit_dir, 'data/tmp')
194
+ if not os.path.exists(temp_folder):
195
+ os.makedirs(temp_folder)
196
+ else:
197
+ # If there are already files there, delete them
198
+ for filename in os.listdir(temp_folder):
199
+ file_path = os.path.join(temp_folder, filename)
200
+ try:
201
+ if os.path.isfile(file_path) or os.path.islink(file_path):
202
+ os.unlink(file_path)
203
+ elif os.path.isdir(file_path):
204
+ shutil.rmtree(file_path)
205
+ except Exception as e:
206
+ print(f'Failed to delete {file_path}. Reason: {e}')
207
+
208
+ # Save all selected files to the tmp dir with their file names
209
+ #for doc in docs:
210
+ # temp_file = os.path.join(temp_folder, doc.name)
211
+ # with open(temp_file, 'wb') as f:
212
+ # f.write(doc.getvalue())
213
+
214
+ for doc_info in docs:
215
+ file_name, file_obj = doc_info
216
+ temp_file = os.path.join(temp_folder, file_name)
217
+ with open(temp_file, 'wb') as f:
218
+ f.write(file_obj.read())
219
+
220
+ # Pass in the temp folder for processing into the parse_doc_universal function
221
+ _, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
222
+ return langchain_docs
223
+
224
+ def load_embedding_model(self):
225
+ embeddings = APIGateway.load_embedding_model(
226
+ type=self.embedding_model_info['type'],
227
+ batch_size=self.embedding_model_info['batch_size'],
228
+ coe=self.embedding_model_info['coe'],
229
+ select_expert=self.embedding_model_info['select_expert'],
230
+ )
231
+ return embeddings
232
+
233
+ def create_vector_store(self, text_chunks, embeddings, output_db=None, collection_name=None):
234
+ print(f'Collection name is {collection_name}')
235
+ vectorstore = self.vectordb.create_vector_store(
236
+ text_chunks, embeddings, output_db=output_db, collection_name=collection_name, db_type='chroma'
237
+ )
238
+ return vectorstore
239
+
240
+ def load_vdb(self, db_path, embeddings, collection_name=None):
241
+ print(f'Loading collection name is {collection_name}')
242
+ vectorstore = self.vectordb.load_vdb(db_path, embeddings, db_type='chroma', collection_name=collection_name)
243
+ return vectorstore
244
+
245
+ def init_retriever(self, vectorstore):
246
+ if self.retrieval_info['rerank']:
247
+ self.retriever = vectorstore.as_retriever(
248
+ search_type='similarity_score_threshold',
249
+ search_kwargs={
250
+ 'score_threshold': self.retrieval_info['score_threshold'],
251
+ 'k': self.retrieval_info['k_retrieved_documents'],
252
+ },
253
+ )
254
+ else:
255
+ self.retriever = vectorstore.as_retriever(
256
+ search_type='similarity_score_threshold',
257
+ search_kwargs={
258
+ 'score_threshold': self.retrieval_info['score_threshold'],
259
+ 'k': self.retrieval_info['final_k_retrieved_documents'],
260
+ },
261
+ )
262
+
263
+ def get_qa_retrieval_chain(self):
264
+ """
265
+ Generate a qa_retrieval chain using a language model.
266
+
267
+ This function uses a language model, specifically a SambaNova LLM, to generate a qa_retrieval chain
268
+ based on the input vector store of text chunks.
269
+
270
+ Note: uses the retriever previously set with init_retriever; this method takes no parameters.
273
+
274
+ Returns:
275
+ RetrievalQA: A chain ready for QA without memory
276
+ """
277
+ # customprompt = load_prompt(os.path.join(kit_dir, self.prompts["qa_prompt"]))
278
+ # qa_chain = customprompt | self.llm | StrOutputParser()
279
+
280
+ # response = {}
281
+ # documents = self.retriever.invoke(question)
282
+ # if self.retrieval_info["rerank"]:
283
+ # documents = self.rerank_docs(question, documents, self.retrieval_info["final_k_retrieved_documents"])
284
+ # docs = self._format_docs(documents)
285
+
286
+ # response["answer"] = qa_chain.invoke({"question": question, "context": docs})
287
+ # response["source_documents"] = documents
288
+
289
+ retrievalQAChain = RetrievalQAChain(
290
+ retriever=self.retriever,
291
+ llm=self.llm,
292
+ qa_prompt=load_prompt(os.path.join(kit_dir, self.prompts['qa_prompt'])),
293
+ rerank=self.retrieval_info['rerank'],
294
+ final_k_retrieved_documents=self.retrieval_info['final_k_retrieved_documents'],
295
+ )
296
+ return retrievalQAChain
297
+
298
+ def get_conversational_qa_retrieval_chain(self):
299
+ """
300
+ Generate a conversational retrieval qa chain using a language model.
301
+
302
+ This function uses a language model, specifically a SambaNova LLM, to generate a conversational_qa_retrieval chain
303
+ based on the chat history and the relevant retrieved content from the input vector store of text chunks.
304
+
305
+ Note: uses the retriever previously set with init_retriever; this method takes no parameters.
308
+
309
+ Returns:
310
+ RetrievalQA: A chain ready for QA with memory
311
+ """
utils/model_wrappers/api_gateway.py ADDED
@@ -0,0 +1,260 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ from typing import Optional, Dict
5
+
6
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
7
+ from langchain_core.embeddings import Embeddings
8
+ from langchain_core.language_models.llms import LLM
9
+ from langchain_core.language_models.chat_models import BaseChatModel
10
+
11
+ current_dir = os.path.dirname(os.path.abspath(__file__))
12
+ utils_dir = os.path.abspath(os.path.join(current_dir, '..'))
13
+ repo_dir = os.path.abspath(os.path.join(utils_dir, '..'))
14
+ sys.path.append(utils_dir)
15
+ sys.path.append(repo_dir)
16
+
17
+ from utils.model_wrappers.langchain_embeddings import SambaStudioEmbeddings
18
+ from utils.model_wrappers.langchain_llms import SambaStudio
19
+ from utils.model_wrappers.langchain_llms import SambaNovaCloud
20
+ from utils.model_wrappers.langchain_chat_models import ChatSambaNovaCloud
21
+
22
+ EMBEDDING_MODEL = 'intfloat/e5-large-v2'
23
+ NORMALIZE_EMBEDDINGS = True
24
+
25
+ # Configure the logger
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format='%(asctime)s [%(levelname)s] - %(message)s',
29
+ handlers=[
30
+ logging.StreamHandler(),
31
+ ],
32
+ )
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class APIGateway:
37
+ @staticmethod
38
+ def load_embedding_model(
39
+ type: str = 'cpu',
40
+ batch_size: Optional[int] = None,
41
+ coe: bool = False,
42
+ select_expert: Optional[str] = None,
43
+ sambastudio_embeddings_base_url: Optional[str] = None,
44
+ sambastudio_embeddings_base_uri: Optional[str] = None,
45
+ sambastudio_embeddings_project_id: Optional[str] = None,
46
+ sambastudio_embeddings_endpoint_id: Optional[str] = None,
47
+ sambastudio_embeddings_api_key: Optional[str] = None,
48
+ ) -> Embeddings:
49
+ """Loads a langchain embedding model given a type and parameters
50
+ Args:
51
+ type (str): whether to use the sambastudio embedding model or a local cpu model
52
+ batch_size (int, optional): batch size for sambastudio model. Defaults to None.
53
+ coe (bool, optional): whether to use coe model. Defaults to False. only for sambastudio models
54
+ select_expert (str, optional): expert model to be used when coe selected. Defaults to None.
55
+ only for sambastudio models.
56
+ sambastudio_embeddings_base_url (str, optional): base url for sambastudio model. Defaults to None.
57
+ sambastudio_embeddings_base_uri (str, optional): endpoint base uri for sambastudio model. Defaults to None.
58
+ sambastudio_embeddings_project_id (str, optional): project id for sambastudio model. Defaults to None.
59
+ sambastudio_embeddings_endpoint_id (str, optional): endpoint id for sambastudio model. Defaults to None.
60
+ sambastudio_embeddings_api_key (str, optional): api key for sambastudio model. Defaults to None.
61
+ Returns:
62
+ langchain embedding model
63
+ """
64
+
65
+ if type == 'sambastudio':
66
+ envs = {
67
+ 'sambastudio_embeddings_base_url': sambastudio_embeddings_base_url,
68
+ 'sambastudio_embeddings_base_uri': sambastudio_embeddings_base_uri,
69
+ 'sambastudio_embeddings_project_id': sambastudio_embeddings_project_id,
70
+ 'sambastudio_embeddings_endpoint_id': sambastudio_embeddings_endpoint_id,
71
+ 'sambastudio_embeddings_api_key': sambastudio_embeddings_api_key,
72
+ }
73
+ envs = {k: v for k, v in envs.items() if v is not None}
74
+
75
+ if coe:
76
+ if batch_size is None:
77
+ batch_size = 1
78
+ embeddings = SambaStudioEmbeddings(
79
+ **envs, batch_size=batch_size, model_kwargs={'select_expert': select_expert}
80
+ )
81
+ else:
82
+ if batch_size is None:
83
+ batch_size = 32
84
+ embeddings = SambaStudioEmbeddings(**envs, batch_size=batch_size)
85
+ elif type == 'cpu':
86
+ encode_kwargs = {'normalize_embeddings': NORMALIZE_EMBEDDINGS}
87
+ embedding_model = EMBEDDING_MODEL
88
+ embeddings = HuggingFaceInstructEmbeddings(
89
+ model_name=embedding_model,
90
+ embed_instruction='', # no instruction is needed for candidate passages
91
+ query_instruction='Represent this sentence for searching relevant passages: ',
92
+ encode_kwargs=encode_kwargs,
93
+ )
94
+ else:
95
+ raise ValueError(f'{type} is not a valid embedding model type')
96
+
97
+ return embeddings
98
+
99
+ @staticmethod
100
+ def load_llm(
101
+ type: str,
102
+ streaming: bool = False,
103
+ coe: bool = False,
104
+ do_sample: Optional[bool] = None,
105
+ max_tokens_to_generate: Optional[int] = None,
106
+ temperature: Optional[float] = None,
107
+ select_expert: Optional[str] = None,
108
+ top_p: Optional[float] = None,
109
+ top_k: Optional[int] = None,
110
+ repetition_penalty: Optional[float] = None,
111
+ stop_sequences: Optional[str] = None,
112
+ process_prompt: Optional[bool] = False,
113
+ sambastudio_base_url: Optional[str] = None,
114
+ sambastudio_base_uri: Optional[str] = None,
115
+ sambastudio_project_id: Optional[str] = None,
116
+ sambastudio_endpoint_id: Optional[str] = None,
117
+ sambastudio_api_key: Optional[str] = None,
118
+ sambanova_url: Optional[str] = None,
119
+ sambanova_api_key: Optional[str] = None,
120
+ ) -> LLM:
121
+ """Loads a langchain Sambanova llm model given a type and parameters
122
+ Args:
123
+ type (str): whether to use "sambastudio" or the SambaNova Cloud model "sncloud"
124
+ streaming (bool): whether to use the streaming method. Defaults to False.
125
+ coe (bool): whether to use coe model. Defaults to False.
126
+
127
+ do_sample (bool) : Optional whether to sample.
128
+ max_tokens_to_generate (int) : Optional max number of tokens to generate.
129
+ temperature (float) : Optional model temperature.
130
+ select_expert (str) : Optional expert to use when using CoE models.
131
+ top_p (float) : Optional model top_p.
132
+ top_k (int) : Optional model top_k.
133
+ repetition_penalty (float) : Optional model repetition penalty.
134
+ stop_sequences (str) : Optional model stop sequences.
135
+ process_prompt (bool) : Optional default to false.
136
+
137
+ sambastudio_base_url (str): Optional SambaStudio environment URL.
138
+ sambastudio_base_uri (str): Optional SambaStudio base URI.
139
+ sambastudio_project_id (str): Optional SambaStudio project ID.
140
+ sambastudio_endpoint_id (str): Optional SambaStudio endpoint ID.
141
+ sambastudio_api_key (str): Optional SambaStudio endpoint API key.
142
+
143
+ sambanova_url (str): Optional SambaNova Cloud URL.
144
+ sambanova_api_key (str): Optional SambaNovaCloud API key.
145
+
146
+ Returns:
147
+ langchain llm model
148
+ """
149
+
150
+ if type == 'sambastudio':
151
+ envs = {
152
+ 'sambastudio_base_url': sambastudio_base_url,
153
+ 'sambastudio_base_uri': sambastudio_base_uri,
154
+ 'sambastudio_project_id': sambastudio_project_id,
155
+ 'sambastudio_endpoint_id': sambastudio_endpoint_id,
156
+ 'sambastudio_api_key': sambastudio_api_key,
157
+ }
158
+ envs = {k: v for k, v in envs.items() if v is not None}
159
+ if coe:
160
+ model_kwargs = {
161
+ 'do_sample': do_sample,
162
+ 'max_tokens_to_generate': max_tokens_to_generate,
163
+ 'temperature': temperature,
164
+ 'select_expert': select_expert,
165
+ 'top_p': top_p,
166
+ 'top_k': top_k,
167
+ 'repetition_penalty': repetition_penalty,
168
+ 'stop_sequences': stop_sequences,
169
+ 'process_prompt': process_prompt,
170
+ }
171
+ model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
172
+
173
+ llm = SambaStudio(
174
+ **envs,
175
+ streaming=streaming,
176
+ model_kwargs=model_kwargs,
177
+ )
178
+ else:
179
+ model_kwargs = {
180
+ 'do_sample': do_sample,
181
+ 'max_tokens_to_generate': max_tokens_to_generate,
182
+ 'temperature': temperature,
183
+ 'top_p': top_p,
184
+ 'top_k': top_k,
185
+ 'repetition_penalty': repetition_penalty,
186
+ 'stop_sequences': stop_sequences,
187
+ }
188
+ model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
189
+ llm = SambaStudio(
190
+ **envs,
191
+ streaming=streaming,
192
+ model_kwargs=model_kwargs,
193
+ )
194
+
195
+ elif type == 'sncloud':
196
+ envs = {
197
+ 'sambanova_url': sambanova_url,
198
+ 'sambanova_api_key': sambanova_api_key,
199
+ }
200
+ envs = {k: v for k, v in envs.items() if v is not None}
201
+ llm = SambaNovaCloud(
202
+ **envs,
203
+ max_tokens=max_tokens_to_generate,
204
+ model=select_expert,
205
+ temperature=temperature,
206
+ top_k=top_k,
207
+ top_p=top_p,
208
+ )
209
+
210
+ else:
211
+ raise ValueError(f"Invalid LLM API: {type}, only 'sncloud' and 'sambastudio' are supported.")
212
+
213
+ return llm
214
+
215
+ @staticmethod
216
+ def load_chat(
217
+ model: str,
218
+ streaming: bool = False,
219
+ max_tokens: int = 1024,
220
+ temperature: Optional[float] = 0.0,
221
+ top_p: Optional[float] = None,
222
+ top_k: Optional[int] = None,
223
+ stream_options: Optional[Dict[str, bool]] = {"include_usage": True},
224
+ sambanova_url: Optional[str] = None,
225
+ sambanova_api_key: Optional[str] = None,
226
+ ) -> BaseChatModel:
227
+ """
228
+ Loads a langchain SambanovaCloud chat model given some parameters
229
+ Args:
230
+ model (str): The name of the model to use, e.g., llama3-8b.
231
+ streaming (bool): whether to use streaming method. Defaults to False.
232
+ max_tokens (int) : Optional max number of tokens to generate.
233
+ temperature (float) : Optional model temperature.
234
+ top_p (float) : Optional model top_p.
235
+ top_k (int) : Optional model top_k.
236
+ stream_options (dict) : stream options, include usage to get generation metrics
237
+
238
+ sambanova_url (str): Optional SambaNova Cloud URL.
239
+ sambanova_api_key (str): Optional SambaNovaCloud API key.
240
+
241
+ Returns:
242
+ langchain BaseChatModel
243
+ """
244
+
245
+ envs = {
246
+ 'sambanova_url': sambanova_url,
247
+ 'sambanova_api_key': sambanova_api_key,
248
+ }
249
+ envs = {k: v for k, v in envs.items() if v is not None}
250
+ model = ChatSambaNovaCloud(
251
+ **envs,
252
+ model= model,
253
+ streaming=streaming,
254
+ max_tokens=max_tokens,
255
+ temperature=temperature,
256
+ top_k=top_k,
257
+ top_p=top_p,
258
+ stream_options=stream_options
259
+ )
260
+ return model
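
Note: DocumentRetrieval wires the config.yaml values into these factory methods roughly as follows; a sketch with the shipped defaults (passing None for the API key assumes SAMBANOVA_API_KEY is set in the environment and picked up downstream):

from utils.model_wrappers.api_gateway import APIGateway

llm = APIGateway.load_llm(
    type="sncloud",              # config["api"]
    streaming=True,
    coe=True,
    do_sample=False,
    max_tokens_to_generate=1200,
    temperature=0.0,
    select_expert="llama3-8b",
    process_prompt=False,
    sambanova_api_key=None,      # assumption: falls back to the SAMBANOVA_API_KEY env var
)
embeddings = APIGateway.load_embedding_model(type="cpu", batch_size=1)
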
utils/model_wrappers/langchain_chat_models.py ADDED
@@ -0,0 +1,465 @@
1
+ import json
2
+ from typing import Any, Dict, Iterator, List, Optional
3
+
4
+ import requests
5
+ from langchain_core.callbacks import (
6
+ CallbackManagerForLLMRun,
7
+ )
8
+ from langchain_core.language_models.chat_models import (
9
+ BaseChatModel,
10
+ generate_from_stream,
11
+ )
12
+ from langchain_core.messages import (
13
+ AIMessage,
14
+ AIMessageChunk,
15
+ BaseMessage,
16
+ ChatMessage,
17
+ HumanMessage,
18
+ SystemMessage,
19
+ ToolMessage,
20
+ )
21
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
22
+ from langchain_core.pydantic_v1 import Field, SecretStr
23
+ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
24
+
25
+
26
+ class ChatSambaNovaCloud(BaseChatModel):
27
+ """
28
+ SambaNova Cloud chat model.
29
+
30
+ Setup:
31
+ To use, you should have the environment variables
32
+ ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
33
+ ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.
34
+ http://cloud.sambanova.ai/
35
+ Example:
36
+ .. code-block:: python
37
+ ChatSambaNovaCloud(
38
+ sambanova_url = SambaNova cloud endpoint URL,
39
+ sambanova_api_key = set with your SambaNova cloud API key,
40
+ model = model name,
41
+ streaming = set True for use streaming API
42
+ max_tokens = max number of tokens to generate,
43
+ temperature = model temperature,
44
+ top_p = model top p,
45
+ top_k = model top k,
46
+ stream_options = include usage to get generation metrics
47
+ )
48
+
49
+ Key init args — completion params:
50
+ model: str
51
+ The name of the model to use, e.g., llama3-8b.
52
+ streaming: bool
53
+ Whether to use streaming or not
54
+ max_tokens: int
55
+ max tokens to generate
56
+ temperature: float
57
+ model temperature
58
+ top_p: float
59
+ model top p
60
+ top_k: int
61
+ model top k
62
+ stream_options: dict
63
+ stream options, include usage to get generation metrics
64
+
65
+ Key init args — client params:
66
+ sambanova_url: str
67
+ SambaNova Cloud Url
68
+ sambanova_api_key: str
69
+ SambaNova Cloud api key
70
+
71
+ Instantiate:
72
+ .. code-block:: python
73
+
74
+ from langchain_community.chat_models import ChatSambaNovaCloud
75
+
76
+ chat = ChatSambaNovaCloud(
77
+ sambanova_url = SambaNova cloud endpoint URL,
78
+ sambanova_api_key = set with your SambaNova cloud API key,
79
+ model = model name,
80
+ streaming = set True for streaming
81
+ max_tokens = max number of tokens to generate,
82
+ temperature = model temperature,
83
+ top_p = model top p,
84
+ top_k = model top k,
85
+ stream_options = include usage to get generation metrics
86
+ )
87
+ Invoke:
88
+ .. code-block:: python
89
+ messages = [
90
+ SystemMessage(content="your are an AI assistant."),
91
+ HumanMessage(content="tell me a joke."),
92
+ ]
93
+ response = chat.invoke(messages)
94
+
95
+ Stream:
96
+ .. code-block:: python
97
+
98
+ for chunk in chat.stream(messages):
99
+ print(chunk.content, end="", flush=True)
100
+
101
+ Async:
102
+ .. code-block:: python
103
+
104
+ response = await chat.ainvoke(messages)
106
+
107
+ Token usage:
108
+ .. code-block:: python
109
+ response = chat.invoke(messages)
110
+ print(response.response_metadata["usage"]["prompt_tokens"])
111
+ print(response.response_metadata["usage"]["total_tokens"])
112
+
113
+ Response metadata
114
+ .. code-block:: python
115
+
116
+ response = chat.invoke(messages)
117
+ print(response.response_metadata)
118
+ """
119
+
120
+ sambanova_url: str = Field(default="")
121
+ """SambaNova Cloud Url"""
122
+
123
+ sambanova_api_key: SecretStr = Field(default="")
124
+ """SambaNova Cloud api key"""
125
+
126
+ model: str = Field(default="llama3-8b")
127
+ """The name of the model"""
128
+
129
+ streaming: bool = Field(default=False)
130
+ """Whether to use streaming or not"""
131
+
132
+ max_tokens: int = Field(default=1024)
133
+ """max tokens to generate"""
134
+
135
+ temperature: float = Field(default=0.7)
136
+ """model temperature"""
137
+
138
+ top_p: float = Field(default=0.0)
139
+ """model top p"""
140
+
141
+ top_k: int = Field(default=1)
142
+ """model top k"""
143
+
144
+ stream_options: dict = Field(default={"include_usage": True})
145
+ """stream options, include usage to get generation metrics"""
146
+
147
+ class Config:
148
+ allow_population_by_field_name = True
149
+
150
+ @classmethod
151
+ def is_lc_serializable(cls) -> bool:
152
+ """Return whether this model can be serialized by Langchain."""
153
+ return False
154
+
155
+ @property
156
+ def lc_secrets(self) -> Dict[str, str]:
157
+ return {"sambanova_api_key": "sambanova_api_key"}
158
+
159
+ @property
160
+ def _identifying_params(self) -> Dict[str, Any]:
161
+ """Return a dictionary of identifying parameters.
162
+
163
+ This information is used by the LangChain callback system, which
164
+ is used for tracing purposes to make it possible to monitor LLMs.
165
+ """
166
+ return {
167
+ "model": self.model,
168
+ "streaming": self.streaming,
169
+ "max_tokens": self.max_tokens,
170
+ "temperature": self.temperature,
171
+ "top_p": self.top_p,
172
+ "top_k": self.top_k,
173
+ "stream_options": self.stream_options,
174
+ }
175
+
176
+ @property
177
+ def _llm_type(self) -> str:
178
+ """Get the type of language model used by this chat model."""
179
+ return "sambanovacloud-chatmodel"
180
+
181
+ def __init__(self, **kwargs: Any) -> None:
182
+ """init and validate environment variables"""
183
+ kwargs["sambanova_url"] = get_from_dict_or_env(
184
+ kwargs,
185
+ "sambanova_url",
186
+ "SAMBANOVA_URL",
187
+ default="https://api.sambanova.ai/v1/chat/completions",
188
+ )
189
+ kwargs["sambanova_api_key"] = convert_to_secret_str(
190
+ get_from_dict_or_env(kwargs, "sambanova_api_key", "SAMBANOVA_API_KEY")
191
+ )
192
+ super().__init__(**kwargs)
193
+
194
+ def _handle_request(
195
+ self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
196
+ ) -> Dict[str, Any]:
197
+ """
198
+ Performs a post request to the LLM API.
199
+
200
+ Args:
201
+ messages_dicts: List of role / content dicts to use as input.
202
+ stop: list of stop tokens
203
+
204
+ Returns:
205
+ The response dict.
206
+ """
207
+ data = {
208
+ "messages": messages_dicts,
209
+ "max_tokens": self.max_tokens,
210
+ "stop": stop,
211
+ "model": self.model,
212
+ "temperature": self.temperature,
213
+ "top_p": self.top_p,
214
+ "top_k": self.top_k,
215
+ }
216
+ http_session = requests.Session()
217
+ response = http_session.post(
218
+ self.sambanova_url,
219
+ headers={
220
+ "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
221
+ "Content-Type": "application/json",
222
+ },
223
+ json=data,
224
+ )
225
+ if response.status_code != 200:
226
+ raise RuntimeError(
227
+ f"Sambanova /complete call failed with status code "
228
+ f"{response.status_code}."
229
+ f"{response.text}."
230
+ )
231
+ response_dict = response.json()
232
+ if response_dict.get("error"):
233
+ raise RuntimeError(
234
+ f"Sambanova /complete call failed with status code "
235
+ f"{response.status_code}."
236
+ f"{response_dict}."
237
+ )
238
+ return response_dict
239
+
240
+ def _handle_streaming_request(
241
+ self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
242
+ ) -> Iterator[Dict]:
243
+ """
244
+ Performs a streaming post request to the LLM API.
245
+
246
+ Args:
247
+ messages_dicts: List of role / content dicts to use as input.
248
+ stop: list of stop tokens
249
+
250
+ Returns:
251
+ An iterator of response dicts.
252
+ """
253
+ try:
254
+ import sseclient
255
+ except ImportError:
256
+ raise ImportError(
257
+ "could not import sseclient library. "
258
+ "Please install it with `pip install sseclient-py`."
259
+ )
260
+ data = {
261
+ "messages": messages_dicts,
262
+ "max_tokens": self.max_tokens,
263
+ "stop": stop,
264
+ "model": self.model,
265
+ "temperature": self.temperature,
266
+ "top_p": self.top_p,
267
+ "top_k": self.top_k,
268
+ "stream": True,
269
+ "stream_options": self.stream_options,
270
+ }
271
+ http_session = requests.Session()
272
+ response = http_session.post(
273
+ self.sambanova_url,
274
+ headers={
275
+ "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
276
+ "Content-Type": "application/json",
277
+ },
278
+ json=data,
279
+ stream=True,
280
+ )
281
+
282
+ client = sseclient.SSEClient(response)
283
+
284
+ if response.status_code != 200:
285
+ raise RuntimeError(
286
+ f"Sambanova /complete call failed with status code "
287
+ f"{response.status_code}."
288
+ f"{response.text}."
289
+ )
290
+
291
+ for event in client.events():
292
+ chunk = {
293
+ "event": event.event,
294
+ "data": event.data,
295
+ "status_code": response.status_code,
296
+ }
297
+
298
+ if chunk["event"] == "error_event" or chunk["status_code"] != 200:
299
+ raise RuntimeError(
300
+ f"Sambanova /complete call failed with status code "
301
+ f"{chunk['status_code']}."
302
+ f"{chunk}."
303
+ )
304
+
305
+ try:
306
+ # check if the response is a final event
307
+ # in that case event data response is '[DONE]'
308
+ if chunk["data"] != "[DONE]":
309
+ if isinstance(chunk["data"], str):
310
+ data = json.loads(chunk["data"])
311
+ else:
312
+ raise RuntimeError(
313
+ f"Sambanova /complete call failed with status code "
314
+ f"{chunk['status_code']}."
315
+ f"{chunk}."
316
+ )
317
+ if data.get("error"):
318
+ raise RuntimeError(
319
+ f"Sambanova /complete call failed with status code "
320
+ f"{chunk['status_code']}."
321
+ f"{chunk}."
322
+ )
323
+ yield data
324
+ except Exception:
325
+ raise Exception(
326
+ f"Error getting content chunk raw streamed response: {chunk}"
327
+ )
328
+
329
+ def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
330
+ """
331
+ convert a BaseMessage to a dictionary with Role / content
332
+
333
+ Args:
334
+ message: BaseMessage
335
+
336
+ Returns:
337
+ messages_dict: role / content dict
338
+ """
339
+ if isinstance(message, ChatMessage):
340
+ message_dict = {"role": message.role, "content": message.content}
341
+ elif isinstance(message, SystemMessage):
342
+ message_dict = {"role": "system", "content": message.content}
343
+ elif isinstance(message, HumanMessage):
344
+ message_dict = {"role": "user", "content": message.content}
345
+ elif isinstance(message, AIMessage):
346
+ message_dict = {"role": "assistant", "content": message.content}
347
+ elif isinstance(message, ToolMessage):
348
+ message_dict = {"role": "tool", "content": message.content}
349
+ else:
350
+ raise TypeError(f"Got unknown type {message}")
351
+ return message_dict
352
+
353
+ def _create_message_dicts(
354
+ self, messages: List[BaseMessage]
355
+ ) -> List[Dict[str, Any]]:
356
+ """
357
+ convert a list of BaseMessages to a list of dictionaries with role / content
358
+
359
+ Args:
360
+ messages: list of BaseMessages
361
+
362
+ Returns:
363
+ messages_dicts: list of role / content dicts
364
+ """
365
+ message_dicts = [self._convert_message_to_dict(m) for m in messages]
366
+ return message_dicts
367
+
368
+ def _generate(
369
+ self,
370
+ messages: List[BaseMessage],
371
+ stop: Optional[List[str]] = None,
372
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
373
+ **kwargs: Any,
374
+ ) -> ChatResult:
375
+ """
376
+ SambaNovaCloud chat model logic.
377
+
378
+ Call SambaNovaCloud API.
379
+
380
+ Args:
381
+ messages: the prompt composed of a list of messages.
382
+ stop: a list of strings on which the model should stop generating.
383
+ If generation stops due to a stop token, the stop token itself
384
+ SHOULD BE INCLUDED as part of the output. This is not enforced
385
+ across models right now, but it's a good practice to follow since
386
+ it makes it much easier to parse the output of the model
387
+ downstream and understand why generation stopped.
388
+ run_manager: A run manager with callbacks for the LLM.
389
+ """
390
+ if self.streaming:
391
+ stream_iter = self._stream(
392
+ messages, stop=stop, run_manager=run_manager, **kwargs
393
+ )
394
+ if stream_iter:
395
+ return generate_from_stream(stream_iter)
396
+ messages_dicts = self._create_message_dicts(messages)
397
+ response = self._handle_request(messages_dicts, stop)
398
+ message = AIMessage(
399
+ content=response["choices"][0]["message"]["content"],
400
+ additional_kwargs={},
401
+ response_metadata={
402
+ "finish_reason": response["choices"][0]["finish_reason"],
403
+ "usage": response.get("usage"),
404
+ "model_name": response["model"],
405
+ "system_fingerprint": response["system_fingerprint"],
406
+ "created": response["created"],
407
+ },
408
+ id=response["id"],
409
+ )
410
+
411
+ generation = ChatGeneration(message=message)
412
+ return ChatResult(generations=[generation])
413
+
414
+ def _stream(
415
+ self,
416
+ messages: List[BaseMessage],
417
+ stop: Optional[List[str]] = None,
418
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
419
+ **kwargs: Any,
420
+ ) -> Iterator[ChatGenerationChunk]:
421
+ """
422
+ Stream the output of the SambaNovaCloud chat model.
423
+
424
+ Args:
425
+ messages: the prompt composed of a list of messages.
426
+ stop: a list of strings on which the model should stop generating.
427
+ If generation stops due to a stop token, the stop token itself
428
+ SHOULD BE INCLUDED as part of the output. This is not enforced
429
+ across models right now, but it's a good practice to follow since
430
+ it makes it much easier to parse the output of the model
431
+ downstream and understand why generation stopped.
432
+ run_manager: A run manager with callbacks for the LLM.
433
+ """
434
+ messages_dicts = self._create_message_dicts(messages)
435
+ finish_reason = None
436
+ for partial_response in self._handle_streaming_request(messages_dicts, stop):
437
+ if len(partial_response["choices"]) > 0:
438
+ finish_reason = partial_response["choices"][0].get("finish_reason")
439
+ content = partial_response["choices"][0]["delta"]["content"]
440
+ id = partial_response["id"]
441
+ chunk = ChatGenerationChunk(
442
+ message=AIMessageChunk(content=content, id=id, additional_kwargs={})
443
+ )
444
+ else:
445
+ content = ""
446
+ id = partial_response["id"]
447
+ metadata = {
448
+ "finish_reason": finish_reason,
449
+ "usage": partial_response.get("usage"),
450
+ "model_name": partial_response["model"],
451
+ "system_fingerprint": partial_response["system_fingerprint"],
452
+ "created": partial_response["created"],
453
+ }
454
+ chunk = ChatGenerationChunk(
455
+ message=AIMessageChunk(
456
+ content=content,
457
+ id=id,
458
+ response_metadata=metadata,
459
+ additional_kwargs={},
460
+ )
461
+ )
462
+
463
+ if run_manager:
464
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
465
+ yield chunk
utils/model_wrappers/langchain_embeddings.py ADDED
@@ -0,0 +1,309 @@
1
+ """Langchain Wrapper around Sambanova embedding APIs."""
2
+
3
+ import json
4
+ from typing import Dict, Generator, List, Optional
5
+
6
+ import requests
7
+ from langchain_core.embeddings import Embeddings
8
+ from langchain_core.pydantic_v1 import BaseModel
9
+ from langchain_core.utils import get_from_dict_or_env, pre_init
10
+
11
+
12
+ class SambaStudioEmbeddings(BaseModel, Embeddings):
13
+ """SambaNova embedding models.
14
+
15
+ To use, you should have the environment variables
16
+ ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``
17
+ ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
18
+ ``SAMBASTUDIO_EMBEDDINGS_API_KEY``
19
+ set with your personal sambastudio variable or pass it as a named parameter
20
+ to the constructor.
21
+
22
+ Example:
23
+ .. code-block:: python
24
+
25
+ from langchain_community.embeddings import SambaStudioEmbeddings
26
+
27
+ embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
28
+ sambastudio_embeddings_base_uri=base_uri,
29
+ sambastudio_embeddings_project_id=project_id,
30
+ sambastudio_embeddings_endpoint_id=endpoint_id,
31
+ sambastudio_embeddings_api_key=api_key,
32
+ batch_size=32)
33
+ (or)
34
+
35
+ embeddings = SambaStudioEmbeddings(batch_size=32)
36
+
37
+ (or)
38
+
39
+ # CoE example
40
+ embeddings = SambaStudioEmbeddings(
41
+ batch_size=1,
42
+ model_kwargs={
43
+ 'select_expert':'e5-mistral-7b-instruct'
44
+ }
45
+ )
46
+ """
47
+
48
+ sambastudio_embeddings_base_url: str = ''
49
+ """Base url to use"""
50
+
51
+ sambastudio_embeddings_base_uri: str = ''
52
+ """endpoint base uri"""
53
+
54
+ sambastudio_embeddings_project_id: str = ''
55
+ """Project id on sambastudio for model"""
56
+
57
+ sambastudio_embeddings_endpoint_id: str = ''
58
+ """endpoint id on sambastudio for model"""
59
+
60
+ sambastudio_embeddings_api_key: str = ''
61
+ """sambastudio api key"""
62
+
63
+ model_kwargs: dict = {}
64
+ """Key word arguments to pass to the model."""
65
+
66
+ batch_size: int = 32
67
+ """Batch size for the embedding models"""
68
+
69
+ @pre_init
70
+ def validate_environment(cls, values: Dict) -> Dict:
71
+ """Validate that api key and python package exists in environment."""
72
+ values['sambastudio_embeddings_base_url'] = get_from_dict_or_env(
73
+ values, 'sambastudio_embeddings_base_url', 'SAMBASTUDIO_EMBEDDINGS_BASE_URL'
74
+ )
75
+ values['sambastudio_embeddings_base_uri'] = get_from_dict_or_env(
76
+ values,
77
+ 'sambastudio_embeddings_base_uri',
78
+ 'SAMBASTUDIO_EMBEDDINGS_BASE_URI',
79
+ default='api/predict/generic',
80
+ )
81
+ values['sambastudio_embeddings_project_id'] = get_from_dict_or_env(
82
+ values,
83
+ 'sambastudio_embeddings_project_id',
84
+ 'SAMBASTUDIO_EMBEDDINGS_PROJECT_ID',
85
+ )
86
+ values['sambastudio_embeddings_endpoint_id'] = get_from_dict_or_env(
87
+ values,
88
+ 'sambastudio_embeddings_endpoint_id',
89
+ 'SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID',
90
+ )
91
+ values['sambastudio_embeddings_api_key'] = get_from_dict_or_env(
92
+ values, 'sambastudio_embeddings_api_key', 'SAMBASTUDIO_EMBEDDINGS_API_KEY'
93
+ )
94
+ return values
95
+
96
+ def _get_tuning_params(self) -> str:
97
+ """
98
+ Get the tuning parameters to use when calling the model
99
+
100
+ Returns:
101
+ The tuning parameters as a JSON string.
102
+ """
103
+ if 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
104
+ tuning_params_dict = self.model_kwargs
105
+ else:
106
+ tuning_params_dict = {
107
+ k: {'type': type(v).__name__, 'value': str(v)} for k, v in (self.model_kwargs.items())
108
+ }
109
+ tuning_params = json.dumps(tuning_params_dict)
110
+ return tuning_params
111
+
112
+ def _get_full_url(self, path: str) -> str:
113
+ """
114
+ Return the full API URL for a given path.
115
+
116
+ :param str path: the sub-path
117
+ :returns: the full API URL for the sub-path
118
+ :rtype: str
119
+ """
120
+ return f'{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}' # noqa: E501
121
+
122
+ def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
123
+ """Generator for creating batches in the embed documents method
124
+ Args:
125
+ texts (List[str]): list of strings to embed
126
+ batch_size (int, optional): batch size to be used for the embedding model.
127
+ Will depend on the RDU endpoint used.
128
+ Yields:
129
+ List[str]: list (batch) of strings of size batch size
130
+ """
131
+ for i in range(0, len(texts), batch_size):
132
+ yield texts[i : i + batch_size]
133
+
134
+ def embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
135
+ """Returns a list of embeddings for the given sentences.
136
+ Args:
137
+ texts (`List[str]`): List of texts to encode
138
+ batch_size (`int`): Batch size for the encoding
139
+
140
+ Returns:
141
+ `List[List[float]]`: List of embeddings
142
+ for the given texts
143
+ """
144
+ if batch_size is None:
145
+ batch_size = self.batch_size
146
+ http_session = requests.Session()
147
+ url = self._get_full_url(f'{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}')
148
+ params = json.loads(self._get_tuning_params())
149
+ embeddings = []
150
+
151
+ if 'api/predict/nlp' in self.sambastudio_embeddings_base_uri:
152
+ for batch in self._iterate_over_batches(texts, batch_size):
153
+ data = {'inputs': batch, 'params': params}
154
+ response = http_session.post(
155
+ url,
156
+ headers={'key': self.sambastudio_embeddings_api_key},
157
+ json=data,
158
+ )
159
+ if response.status_code != 200:
160
+ raise RuntimeError(
161
+ f'Sambanova /complete call failed with status code '
162
+ f'{response.status_code}.\n Details: {response.text}'
163
+ )
164
+ try:
165
+ embedding = response.json()['data']
166
+ embeddings.extend(embedding)
167
+ except KeyError:
168
+ raise KeyError(
169
+ "'data' not found in endpoint response",
170
+ response.json(),
171
+ )
172
+
173
+ elif 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
174
+ for batch in self._iterate_over_batches(texts, batch_size):
175
+ items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(batch)]
176
+ data = {'items': items, 'params': params}
177
+ response = http_session.post(
178
+ url,
179
+ headers={'key': self.sambastudio_embeddings_api_key},
180
+ json=data,
181
+ )
182
+ if response.status_code != 200:
183
+ raise RuntimeError(
184
+ f'Sambanova /complete call failed with status code '
185
+ f'{response.status_code}.\n Details: {response.text}'
186
+ )
187
+ try:
188
+ embedding = [item['value'] for item in response.json()['items']]
189
+ embeddings.extend(embedding)
190
+ except KeyError:
191
+ raise KeyError(
192
+ "'items' not found in endpoint response",
193
+ response.json(),
194
+ )
195
+
196
+ elif 'api/predict/generic' in self.sambastudio_embeddings_base_uri:
197
+ for batch in self._iterate_over_batches(texts, batch_size):
198
+ data = {'instances': batch, 'params': params}
199
+ response = http_session.post(
200
+ url,
201
+ headers={'key': self.sambastudio_embeddings_api_key},
202
+ json=data,
203
+ )
204
+ if response.status_code != 200:
205
+ raise RuntimeError(
206
+ f'Sambanova /complete call failed with status code '
207
+ f'{response.status_code}.\n Details: {response.text}'
208
+ )
209
+ try:
210
+ if params.get('select_expert'):
211
+ embedding = response.json()['predictions']
212
+ else:
213
+ embedding = response.json()['predictions']
214
+ embeddings.extend(embedding)
215
+ except KeyError:
216
+ raise KeyError(
217
+ "'predictions' not found in endpoint response",
218
+ response.json(),
219
+ )
220
+
221
+ else:
222
+ raise ValueError(
223
+ f'handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented' # noqa: E501
224
+ )
225
+
226
+ return embeddings
227
+
228
+ def embed_query(self, text: str) -> List[float]:
229
+ """Returns a list of embeddings for the given sentences.
230
+ Args:
231
+ sentences (`List[str]`): List of sentences to encode
232
+
233
+ Returns:
234
+ `List[np.ndarray]` or `List[tensor]`: List of embeddings
235
+ for the given sentences
236
+ """
237
+ http_session = requests.Session()
238
+ url = self._get_full_url(f'{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}')
239
+ params = json.loads(self._get_tuning_params())
240
+
241
+ if 'api/predict/nlp' in self.sambastudio_embeddings_base_uri:
242
+ data = {'inputs': [text], 'params': params}
243
+ response = http_session.post(
244
+ url,
245
+ headers={'key': self.sambastudio_embeddings_api_key},
246
+ json=data,
247
+ )
248
+ if response.status_code != 200:
249
+ raise RuntimeError(
250
+ f'Sambanova /complete call failed with status code '
251
+ f'{response.status_code}.\n Details: {response.text}'
252
+ )
253
+ try:
254
+ embedding = response.json()['data'][0]
255
+ except KeyError:
256
+ raise KeyError(
257
+ "'data' not found in endpoint response",
258
+ response.json(),
259
+ )
260
+
261
+ elif 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
262
+ data = {'items': [{'id': 'item0', 'value': text}], 'params': params}
263
+ response = http_session.post(
264
+ url,
265
+ headers={'key': self.sambastudio_embeddings_api_key},
266
+ json=data,
267
+ )
268
+ if response.status_code != 200:
269
+ raise RuntimeError(
270
+ f'Sambanova /complete call failed with status code '
271
+ f'{response.status_code}.\n Details: {response.text}'
272
+ )
273
+ try:
274
+ embedding = response.json()['items'][0]['value']
275
+ except KeyError:
276
+ raise KeyError(
277
+ "'items' not found in endpoint response",
278
+ response.json(),
279
+ )
280
+
281
+ elif 'api/predict/generic' in self.sambastudio_embeddings_base_uri:
282
+ data = {'instances': [text], 'params': params}
283
+ response = http_session.post(
284
+ url,
285
+ headers={'key': self.sambastudio_embeddings_api_key},
286
+ json=data,
287
+ )
288
+ if response.status_code != 200:
289
+ raise RuntimeError(
290
+ f'Sambanova /complete call failed with status code '
291
+ f'{response.status_code}.\n Details: {response.text}'
292
+ )
293
+ try:
294
+ if params.get('select_expert'):
295
+ embedding = response.json()['predictions'][0]
296
+ else:
297
+ embedding = response.json()['predictions'][0]
298
+ except KeyError:
299
+ raise KeyError(
300
+ "'predictions' not found in endpoint response",
301
+ response.json(),
302
+ )
303
+
304
+ else:
305
+ raise ValueError(
306
+ f'handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented' # noqa: E501
307
+ )
308
+
309
+ return embedding
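The embeddings wrapper above selects its payload shape from the endpoint base URI (api/predict/nlp, api/v2/predict/generic, or api/predict/generic) and batches embed_documents requests through _iterate_over_batches. A minimal usage sketch (illustrative only, assuming the SAMBASTUDIO_EMBEDDINGS_* variables listed in the class docstring are exported and the module is importable as langchain_embeddings, as in the usage notebook below):

# Hedged sketch: relies on SAMBASTUDIO_EMBEDDINGS_* being set in the environment.
from langchain_embeddings import SambaStudioEmbeddings

embeddings = SambaStudioEmbeddings(batch_size=32)

# one vector per input text, requested in batches of batch_size
doc_vectors = embeddings.embed_documents(['first document', 'second document'])

# a single vector for a query string
query_vector = embeddings.embed_query('what do the documents say?')

print(len(doc_vectors), len(query_vector))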
utils/model_wrappers/langchain_llms.py ADDED
@@ -0,0 +1,770 @@
1
+ """Langchain Wrapper around Sambanova LLM APIs."""
2
+
3
+ import json
4
+ from typing import Any, Dict, Generator, Iterator, List, Optional, Union
5
+
6
+ import requests
7
+ from langchain_core.callbacks.manager import CallbackManagerForLLMRun
8
+ from langchain_core.language_models.llms import LLM
9
+ from langchain_core.outputs import GenerationChunk
10
+ from langchain_core.pydantic_v1 import Extra
11
+ from langchain_core.utils import get_from_dict_or_env, pre_init
12
+
13
+
14
+ class SSEndpointHandler:
15
+ """
16
+ SambaNova Systems Interface for SambaStudio model endpoints.
17
+
18
+ :param str host_url: Base URL of the DaaS API service
19
+ """
20
+
21
+ def __init__(self, host_url: str, api_base_uri: str):
22
+ """
23
+ Initialize the SSEndpointHandler.
24
+
25
+ :param str host_url: Base URL of the DaaS API service
26
+ :param str api_base_uri: Base URI of the DaaS API service
27
+ """
28
+ self.host_url = host_url
29
+ self.api_base_uri = api_base_uri
30
+ self.http_session = requests.Session()
31
+
32
+ def _process_response(self, response: requests.Response) -> Dict:
33
+ """
34
+ Processes the API response and returns the resulting dict.
35
+
36
+ All resulting dicts, regardless of success or failure, will contain the
37
+ `status_code` key with the API response status code.
38
+
39
+ If the API returned an error, the resulting dict will contain the key
40
+ `detail` with the error message.
41
+
42
+ If the API call was successful, the resulting dict will contain the key
43
+ `data` with the response data.
44
+
45
+ :param requests.Response response: the response object to process
46
+ :return: the response dict
47
+ :type: dict
48
+ """
49
+ result: Dict[str, Any] = {}
50
+ try:
51
+ result = response.json()
52
+ except Exception as e:
53
+ result['detail'] = str(e)
54
+ if 'status_code' not in result:
55
+ result['status_code'] = response.status_code
56
+ return result
57
+
58
+ def _process_streaming_response(
59
+ self,
60
+ response: requests.Response,
61
+ ) -> Generator[Dict, None, None]:
62
+ """Process the streaming response"""
63
+ if 'api/predict/nlp' in self.api_base_uri:
64
+ try:
65
+ import sseclient
66
+ except ImportError:
67
+ raise ImportError(
68
+ 'could not import sseclient library. ' 'Please install it with `pip install sseclient-py`.'
69
+ )
70
+ client = sseclient.SSEClient(response)
71
+ close_conn = False
72
+ for event in client.events():
73
+ if event.event == 'error_event':
74
+ close_conn = True
75
+ chunk = {
76
+ 'event': event.event,
77
+ 'data': event.data,
78
+ 'status_code': response.status_code,
79
+ }
80
+ yield chunk
81
+ if close_conn:
82
+ client.close()
83
+ elif 'api/v2/predict/generic' in self.api_base_uri or 'api/predict/generic' in self.api_base_uri:
84
+ try:
85
+ for line in response.iter_lines():
86
+ chunk = json.loads(line)
87
+ if 'status_code' not in chunk:
88
+ chunk['status_code'] = response.status_code
89
+ yield chunk
90
+ except Exception as e:
91
+ raise RuntimeError(f'Error processing streaming response: {e}')
92
+ else:
93
+ raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
94
+
95
+ def _get_full_url(self, path: str) -> str:
96
+ """
97
+ Return the full API URL for a given path.
98
+
99
+ :param str path: the sub-path
100
+ :returns: the full API URL for the sub-path
101
+ :type: str
102
+ """
103
+ return f'{self.host_url}/{self.api_base_uri}/{path}'
104
+
105
+ def nlp_predict(
106
+ self,
107
+ project: str,
108
+ endpoint: str,
109
+ key: str,
110
+ input: Union[List[str], str],
111
+ params: Optional[str] = '',
112
+ stream: bool = False,
113
+ ) -> Dict:
114
+ """
115
+ NLP predict using inline input string.
116
+
117
+ :param str project: Project ID in which the endpoint exists
118
+ :param str endpoint: Endpoint ID
119
+ :param str key: API Key
120
+ :param input: Input string or list of input strings
121
+ :param str params: Input params string
122
+ :returns: Prediction results
123
+ :type: dict
124
+ """
125
+ if isinstance(input, str):
126
+ input = [input]
127
+ if 'api/predict/nlp' in self.api_base_uri:
128
+ if params:
129
+ data = {'inputs': input, 'params': json.loads(params)}
130
+ else:
131
+ data = {'inputs': input}
132
+ elif 'api/v2/predict/generic' in self.api_base_uri:
133
+ items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
134
+ if params:
135
+ data = {'items': items, 'params': json.loads(params)}
136
+ else:
137
+ data = {'items': items}
138
+ elif 'api/predict/generic' in self.api_base_uri:
139
+ if params:
140
+ data = {'instances': input, 'params': json.loads(params)}
141
+ else:
142
+ data = {'instances': input}
143
+ else:
144
+ raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
145
+ response = self.http_session.post(
146
+ self._get_full_url(f'{project}/{endpoint}'),
147
+ headers={'key': key},
148
+ json=data,
149
+ )
150
+ return self._process_response(response)
151
+
152
+ def nlp_predict_stream(
153
+ self,
154
+ project: str,
155
+ endpoint: str,
156
+ key: str,
157
+ input: Union[List[str], str],
158
+ params: Optional[str] = '',
159
+ ) -> Iterator[Dict]:
160
+ """
161
+ NLP predict using inline input string.
162
+
163
+ :param str project: Project ID in which the endpoint exists
164
+ :param str endpoint: Endpoint ID
165
+ :param str key: API Key
166
+ :param input: Input string or list of input strings
167
+ :param str params: Input params string
168
+ :returns: Prediction results
169
+ :type: dict
170
+ """
171
+ if 'api/predict/nlp' in self.api_base_uri:
172
+ if isinstance(input, str):
173
+ input = [input]
174
+ if params:
175
+ data = {'inputs': input, 'params': json.loads(params)}
176
+ else:
177
+ data = {'inputs': input}
178
+ elif 'api/v2/predict/generic' in self.api_base_uri:
179
+ if isinstance(input, str):
180
+ input = [input]
181
+ items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
182
+ if params:
183
+ data = {'items': items, 'params': json.loads(params)}
184
+ else:
185
+ data = {'items': items}
186
+ elif 'api/predict/generic' in self.api_base_uri:
187
+ if isinstance(input, list):
188
+ input = input[0]
189
+ if params:
190
+ data = {'instance': input, 'params': json.loads(params)}
191
+ else:
192
+ data = {'instance': input}
193
+ else:
194
+ raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
195
+ # Streaming output
196
+ response = self.http_session.post(
197
+ self._get_full_url(f'stream/{project}/{endpoint}'),
198
+ headers={'key': key},
199
+ json=data,
200
+ stream=True,
201
+ )
202
+ for chunk in self._process_streaming_response(response):
203
+ yield chunk
204
+
205
+
206
+ class SambaStudio(LLM):
207
+ """
208
+ SambaStudio large language models.
209
+
210
+ To use, you should have the environment variables
211
+ ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
212
+ ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
213
+ ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
214
+ ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
215
+ ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
216
+
217
+ https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
218
+
219
+ See additional documentation at https://docs.sambanova.ai/sambastudio/latest/index.html
220
+
221
+ Example:
222
+ .. code-block:: python
223
+
224
+ from langchain_community.llms.sambanova import SambaStudio
225
+ SambaStudio(
226
+ sambastudio_base_url="your-SambaStudio-environment-URL",
227
+ sambastudio_base_uri="your-SambaStudio-base-URI",
228
+ sambastudio_project_id="your-SambaStudio-project-ID",
229
+ sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
230
+ sambastudio_api_key="your-SambaStudio-endpoint-API-key",
231
+ streaming=False,
232
+ model_kwargs={
233
+ "do_sample": False,
234
+ "max_tokens_to_generate": 1000,
235
+ "temperature": 0.7,
236
+ "top_p": 1.0,
237
+ "repetition_penalty": 1,
238
+ "top_k": 50,
239
+ #"process_prompt": False,
240
+ #"select_expert": "Meta-Llama-3-8B-Instruct"
241
+ },
242
+ )
243
+ """
244
+
245
+ sambastudio_base_url: str = ''
246
+ """Base url to use"""
247
+
248
+ sambastudio_base_uri: str = ''
249
+ """endpoint base uri"""
250
+
251
+ sambastudio_project_id: str = ''
252
+ """Project id on sambastudio for model"""
253
+
254
+ sambastudio_endpoint_id: str = ''
255
+ """endpoint id on sambastudio for model"""
256
+
257
+ sambastudio_api_key: str = ''
258
+ """sambastudio api key"""
259
+
260
+ model_kwargs: Optional[dict] = None
261
+ """Key word arguments to pass to the model."""
262
+
263
+ streaming: Optional[bool] = False
264
+ """Streaming flag to get streamed response."""
265
+
266
+ class Config:
267
+ """Configuration for this pydantic object."""
268
+
269
+ extra = Extra.forbid
270
+
271
+ @classmethod
272
+ def is_lc_serializable(cls) -> bool:
273
+ return True
274
+
275
+ @property
276
+ def _identifying_params(self) -> Dict[str, Any]:
277
+ """Get the identifying parameters."""
278
+ return {**{'model_kwargs': self.model_kwargs}}
279
+
280
+ @property
281
+ def _llm_type(self) -> str:
282
+ """Return type of llm."""
283
+ return 'Sambastudio LLM'
284
+
285
+ @pre_init
286
+ def validate_environment(cls, values: Dict) -> Dict:
287
+ """Validate that api key and python package exists in environment."""
288
+ values['sambastudio_base_url'] = get_from_dict_or_env(values, 'sambastudio_base_url', 'SAMBASTUDIO_BASE_URL')
289
+ values['sambastudio_base_uri'] = get_from_dict_or_env(
290
+ values,
291
+ 'sambastudio_base_uri',
292
+ 'SAMBASTUDIO_BASE_URI',
293
+ default='api/predict/generic',
294
+ )
295
+ values['sambastudio_project_id'] = get_from_dict_or_env(
296
+ values, 'sambastudio_project_id', 'SAMBASTUDIO_PROJECT_ID'
297
+ )
298
+ values['sambastudio_endpoint_id'] = get_from_dict_or_env(
299
+ values, 'sambastudio_endpoint_id', 'SAMBASTUDIO_ENDPOINT_ID'
300
+ )
301
+ values['sambastudio_api_key'] = get_from_dict_or_env(values, 'sambastudio_api_key', 'SAMBASTUDIO_API_KEY')
302
+ return values
303
+
304
+ def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
305
+ """
306
+ Get the tuning parameters to use when calling the LLM.
307
+
308
+ Args:
309
+ stop: Stop words to use when generating. Model output is cut off at the
310
+ first occurrence of any of the stop substrings.
311
+
312
+ Returns:
313
+ The tuning parameters as a JSON string.
314
+ """
315
+ _model_kwargs = self.model_kwargs or {}
316
+ _kwarg_stop_sequences = _model_kwargs.get('stop_sequences', [])
317
+ _stop_sequences = stop or _kwarg_stop_sequences
318
+ # if not _kwarg_stop_sequences:
319
+ # _model_kwargs["stop_sequences"] = ",".join(
320
+ # f'"{x}"' for x in _stop_sequences
321
+ # )
322
+ if 'api/v2/predict/generic' in self.sambastudio_base_uri:
323
+ tuning_params_dict = _model_kwargs
324
+ else:
325
+ tuning_params_dict = {k: {'type': type(v).__name__, 'value': str(v)} for k, v in (_model_kwargs.items())}
326
+ # _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
327
+ tuning_params = json.dumps(tuning_params_dict)
328
+ return tuning_params
329
+
330
+ def _handle_nlp_predict(self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str) -> str:
331
+ """
332
+ Perform an NLP prediction using the SambaStudio endpoint handler.
333
+
334
+ Args:
335
+ sdk: The SSEndpointHandler to use for the prediction.
336
+ prompt: The prompt to use for the prediction.
337
+ tuning_params: The tuning parameters to use for the prediction.
338
+
339
+ Returns:
340
+ The prediction result.
341
+
342
+ Raises:
343
+ ValueError: If the prediction fails.
344
+ """
345
+ response = sdk.nlp_predict(
346
+ self.sambastudio_project_id,
347
+ self.sambastudio_endpoint_id,
348
+ self.sambastudio_api_key,
349
+ prompt,
350
+ tuning_params,
351
+ )
352
+ if response['status_code'] != 200:
353
+ optional_detail = response.get('detail')
354
+ if optional_detail:
355
+ raise RuntimeError(
356
+ f"Sambanova /complete call failed with status code "
357
+ f"{response['status_code']}.\n Details: {optional_detail}"
358
+ )
359
+ else:
360
+ raise RuntimeError(
361
+ f"Sambanova /complete call failed with status code "
362
+ f"{response['status_code']}.\n response {response}"
363
+ )
364
+ if 'api/predict/nlp' in self.sambastudio_base_uri:
365
+ return response['data'][0]['completion']
366
+ elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
367
+ return response['items'][0]['value']['completion']
368
+ elif 'api/predict/generic' in self.sambastudio_base_uri:
369
+ return response['predictions'][0]['completion']
370
+ else:
371
+ raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')
372
+
373
+ def _handle_completion_requests(self, prompt: Union[List[str], str], stop: Optional[List[str]]) -> str:
374
+ """
375
+ Perform a prediction using the SambaStudio endpoint handler.
376
+
377
+ Args:
378
+ prompt: The prompt to use for the prediction.
379
+ stop: stop sequences.
380
+
381
+ Returns:
382
+ The prediction result.
383
+
384
+ Raises:
385
+ ValueError: If the prediction fails.
386
+ """
387
+ ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
388
+ tuning_params = self._get_tuning_params(stop)
389
+ return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
390
+
391
+ def _handle_nlp_predict_stream(
392
+ self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
393
+ ) -> Iterator[GenerationChunk]:
394
+ """
395
+ Perform a streaming request to the LLM.
396
+
397
+ Args:
398
+ sdk: The SSEndpointHandler to use for the prediction.
399
+ prompt: The prompt to use for the prediction.
400
+ tuning_params: The tuning parameters to use for the prediction.
401
+
402
+ Returns:
403
+ An iterator of GenerationChunks.
404
+ """
405
+ for chunk in sdk.nlp_predict_stream(
406
+ self.sambastudio_project_id,
407
+ self.sambastudio_endpoint_id,
408
+ self.sambastudio_api_key,
409
+ prompt,
410
+ tuning_params,
411
+ ):
412
+ if chunk['status_code'] != 200:
413
+ error = chunk.get('error')
414
+ if error:
415
+ optional_code = error.get('code')
416
+ optional_details = error.get('details')
417
+ optional_message = error.get('message')
418
+ raise ValueError(
419
+ f"Sambanova /complete call failed with status code "
420
+ f"{chunk['status_code']}.\n"
421
+ f"Message: {optional_message}\n"
422
+ f"Details: {optional_details}\n"
423
+ f"Code: {optional_code}\n"
424
+ )
425
+ else:
426
+ raise RuntimeError(
427
+ f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
428
+ )
429
+ if 'api/predict/nlp' in self.sambastudio_base_uri:
430
+ text = json.loads(chunk['data'])['stream_token']
431
+ elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
432
+ text = chunk['result']['items'][0]['value']['stream_token']
433
+ elif 'api/predict/generic' in self.sambastudio_base_uri:
434
+ if len(chunk['result']['responses']) > 0:
435
+ text = chunk['result']['responses'][0]['stream_token']
436
+ else:
437
+ text = ''
438
+ else:
439
+ raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} ' f'not implemented')
440
+ generated_chunk = GenerationChunk(text=text)
441
+ yield generated_chunk
442
+
443
+ def _stream(
444
+ self,
445
+ prompt: Union[List[str], str],
446
+ stop: Optional[List[str]] = None,
447
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
448
+ **kwargs: Any,
449
+ ) -> Iterator[GenerationChunk]:
450
+ """Call out to Sambanova's complete endpoint.
451
+
452
+ Args:
453
+ prompt: The prompt to pass into the model.
454
+ stop: Optional list of stop words to use when generating.
455
+
456
+ Returns:
457
+ The string generated by the model.
458
+ """
459
+ ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
460
+ tuning_params = self._get_tuning_params(stop)
461
+ try:
462
+ if self.streaming:
463
+ for chunk in self._handle_nlp_predict_stream(ss_endpoint, prompt, tuning_params):
464
+ if run_manager:
465
+ run_manager.on_llm_new_token(chunk.text)
466
+ yield chunk
467
+ else:
468
+ return
469
+ except Exception as e:
470
+ # Handle any errors raised by the inference endpoint
471
+ raise ValueError(f'Error raised by the inference endpoint: {e}') from e
472
+
473
+ def _handle_stream_request(
474
+ self,
475
+ prompt: Union[List[str], str],
476
+ stop: Optional[List[str]],
477
+ run_manager: Optional[CallbackManagerForLLMRun],
478
+ kwargs: Dict[str, Any],
479
+ ) -> str:
480
+ """
481
+ Perform a streaming request to the LLM.
482
+
483
+ Args:
484
+ prompt: The prompt to generate from.
485
+ stop: Stop words to use when generating. Model output is cut off at the
486
+ first occurrence of any of the stop substrings.
487
+ run_manager: Callback manager for the run.
488
+ **kwargs: Additional keyword arguments. directly passed
489
+ to the sambastudio model in API call.
490
+
491
+ Returns:
492
+ The model output as a string.
493
+ """
494
+ completion = ''
495
+ for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
496
+ completion += chunk.text
497
+ return completion
498
+
499
+ def _call(
500
+ self,
501
+ prompt: Union[List[str], str],
502
+ stop: Optional[List[str]] = None,
503
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
504
+ **kwargs: Any,
505
+ ) -> str:
506
+ """Call out to Sambanova's complete endpoint.
507
+
508
+ Args:
509
+ prompt: The prompt to pass into the model.
510
+ stop: Optional list of stop words to use when generating.
511
+
512
+ Returns:
513
+ The string generated by the model.
514
+ """
515
+ if stop is not None:
516
+ raise Exception('stop not implemented')
517
+ try:
518
+ if self.streaming:
519
+ return self._handle_stream_request(prompt, stop, run_manager, kwargs)
520
+ return self._handle_completion_requests(prompt, stop)
521
+ except Exception as e:
522
+ # Handle any errors raised by the inference endpoint
523
+ raise ValueError(f'Error raised by the inference endpoint: {e}') from e
524
+
525
+
526
+ class SambaNovaCloud(LLM):
527
+ """
528
+ SambaNova Cloud large language models.
529
+
530
+ To use, you should have the environment variables
531
+ ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
532
+ ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.
533
+
534
+ http://cloud.sambanova.ai/
535
+
536
+ Example:
537
+ .. code-block:: python
538
+
539
+ SambaNovaCloud(
540
+ sambanova_url = SambaNova cloud endpoint URL,
541
+ sambanova_api_key = set with your SambaNova cloud API key,
542
+ max_tokens = max number of tokens to generate
543
+ stop_tokens = list of stop tokens
544
+ model = model name
545
+ )
546
+ """
547
+
548
+ sambanova_url: str = ''
549
+ """SambaNova Cloud Url"""
550
+
551
+ sambanova_api_key: str = ''
552
+ """SambaNova Cloud api key"""
553
+
554
+ max_tokens: int = 1024
555
+ """max tokens to generate"""
556
+
557
+ stop_tokens: list = ['<|eot_id|>']
558
+ """Stop tokens"""
559
+
560
+ model: str = 'llama3-8b'
561
+ """LLM model expert to use"""
562
+
563
+ temperature: float = 0.0
564
+ """model temperature"""
565
+
566
+ top_p: float = 0.0
567
+ """model top p"""
568
+
569
+ top_k: int = 1
570
+ """model top k"""
571
+
572
+ stream_api: bool = True
573
+ """use stream api"""
574
+
575
+ stream_options: dict = {'include_usage': True}
576
+ """stream options, include usage to get generation metrics"""
577
+
578
+ class Config:
579
+ """Configuration for this pydantic object."""
580
+
581
+ extra = Extra.forbid
582
+
583
+ @classmethod
584
+ def is_lc_serializable(cls) -> bool:
585
+ return True
586
+
587
+ @property
588
+ def _identifying_params(self) -> Dict[str, Any]:
589
+ """Get the identifying parameters."""
590
+ return {
591
+ 'model': self.model,
592
+ 'max_tokens': self.max_tokens,
593
+ 'stop': self.stop_tokens,
594
+ 'temperature': self.temperature,
595
+ 'top_p': self.top_p,
596
+ 'top_k': self.top_k,
597
+ }
598
+
599
+ @property
600
+ def _llm_type(self) -> str:
601
+ """Return type of llm."""
602
+ return 'SambaNova Cloud'
603
+
604
+ @pre_init
605
+ def validate_environment(cls, values: Dict) -> Dict:
606
+ """Validate that api key and python package exists in environment."""
607
+ values['sambanova_url'] = get_from_dict_or_env(
608
+ values, 'sambanova_url', 'SAMBANOVA_URL', default='https://api.sambanova.ai/v1/chat/completions'
609
+ )
610
+ values['sambanova_api_key'] = get_from_dict_or_env(values, 'sambanova_api_key', 'SAMBANOVA_API_KEY')
611
+ return values
612
+
613
+ def _handle_nlp_predict_stream(
614
+ self,
615
+ prompt: Union[List[str], str],
616
+ stop: List[str],
617
+ ) -> Iterator[GenerationChunk]:
618
+ """
619
+ Perform a streaming request to the LLM.
620
+
621
+ Args:
622
+ prompt: The prompt to use for the prediction.
623
+ stop: list of stop tokens
624
+
625
+ Returns:
626
+ An iterator of GenerationChunks.
627
+ """
628
+ try:
629
+ import sseclient
630
+ except ImportError:
631
+ raise ImportError('could not import sseclient library. ' 'Please install it with `pip install sseclient-py`.')
632
+ try:
633
+ formatted_prompt = json.loads(prompt)
634
+ except Exception:
635
+ formatted_prompt = [{'role': 'user', 'content': prompt}]
636
+
637
+ http_session = requests.Session()
638
+ if not stop:
639
+ stop = self.stop_tokens
640
+ data = {
641
+ 'messages': formatted_prompt,
642
+ 'max_tokens': self.max_tokens,
643
+ 'stop': stop,
644
+ 'model': self.model,
645
+ 'temperature': self.temperature,
646
+ 'top_p': self.top_p,
647
+ 'top_k': self.top_k,
648
+ 'stream': self.stream_api,
649
+ 'stream_options': self.stream_options,
650
+ }
651
+ # Streaming output
652
+ response = http_session.post(
653
+ self.sambanova_url,
654
+ headers={'Authorization': f'Bearer {self.sambanova_api_key}', 'Content-Type': 'application/json'},
655
+ json=data,
656
+ stream=True,
657
+ )
658
+
659
+ client = sseclient.SSEClient(response)
660
+ close_conn = False
661
+
662
+ if response.status_code != 200:
663
+ raise RuntimeError(
664
+ f'Sambanova /complete call failed with status code ' f'{response.status_code}.' f'{response.text}.'
665
+ )
666
+
667
+ for event in client.events():
668
+ if event.event == 'error_event':
669
+ close_conn = True
670
+ chunk = {
671
+ 'event': event.event,
672
+ 'data': event.data,
673
+ 'status_code': response.status_code,
674
+ }
675
+
676
+ if chunk.get('error'):
677
+ raise RuntimeError(
678
+ f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
679
+ )
680
+
681
+ try:
682
+ # check if the response is a final event in that case event data response is '[DONE]'
683
+ if chunk['data'] != '[DONE]':
684
+ data = json.loads(chunk['data'])
685
+ if data.get('error'):
686
+ raise RuntimeError(
687
+ f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
688
+ )
689
+ # check if the response is a final response with usage stats (not includes content)
690
+ if data.get('usage') is None:
691
+ # check is not "end of text" response
692
+ if data['choices'][0]['finish_reason'] is None:
693
+ text = data['choices'][0]['delta']['content']
694
+ generated_chunk = GenerationChunk(text=text)
695
+ yield generated_chunk
696
+ except Exception as e:
697
+ raise Exception(f'Error getting content chunk from raw streamed response: {chunk}') from e
698
+
699
+ def _stream(
700
+ self,
701
+ prompt: Union[List[str], str],
702
+ stop: Optional[List[str]] = None,
703
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
704
+ **kwargs: Any,
705
+ ) -> Iterator[GenerationChunk]:
706
+ """Call out to Sambanova's complete endpoint.
707
+
708
+ Args:
709
+ prompt: The prompt to pass into the model.
710
+ stop: Optional list of stop words to use when generating.
711
+
712
+ Returns:
713
+ The string generated by the model.
714
+ """
715
+ try:
716
+ for chunk in self._handle_nlp_predict_stream(prompt, stop):
717
+ if run_manager:
718
+ run_manager.on_llm_new_token(chunk.text)
719
+ yield chunk
720
+ except Exception as e:
721
+ # Handle any errors raised by the inference endpoint
722
+ raise ValueError(f'Error raised by the inference endpoint: {e}') from e
723
+
724
+ def _handle_stream_request(
725
+ self,
726
+ prompt: Union[List[str], str],
727
+ stop: Optional[List[str]],
728
+ run_manager: Optional[CallbackManagerForLLMRun],
729
+ kwargs: Dict[str, Any],
730
+ ) -> str:
731
+ """
732
+ Perform a streaming request to the LLM.
733
+
734
+ Args:
735
+ prompt: The prompt to generate from.
736
+ stop: Stop words to use when generating. Model output is cut off at the
737
+ first occurrence of any of the stop substrings.
738
+ run_manager: Callback manager for the run.
739
+ **kwargs: Additional keyword arguments. directly passed
740
+ to the Sambanova Cloud model in API call.
741
+
742
+ Returns:
743
+ The model output as a string.
744
+ """
745
+ completion = ''
746
+ for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
747
+ completion += chunk.text
748
+ return completion
749
+
750
+ def _call(
751
+ self,
752
+ prompt: Union[List[str], str],
753
+ stop: Optional[List[str]] = None,
754
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
755
+ **kwargs: Any,
756
+ ) -> str:
757
+ """Call out to Sambanova's complete endpoint.
758
+
759
+ Args:
760
+ prompt: The prompt to pass into the model.
761
+ stop: Optional list of stop words to use when generating.
762
+
763
+ Returns:
764
+ The string generated by the model.
765
+ """
766
+ try:
767
+ return self._handle_stream_request(prompt, stop, run_manager, kwargs)
768
+ except Exception as e:
769
+ # Handle any errors raised by the inference endpoint
770
+ raise ValueError(f'Error raised by the inference endpoint: {e}') from e
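Both LLM wrappers above expose the standard LangChain LLM interface. SambaNovaCloud in particular wraps a plain string prompt into a single user message, passes a JSON-encoded message list through as the chat history, and always generates via the streaming endpoint. A minimal usage sketch (illustrative only, assuming SAMBANOVA_URL / SAMBANOVA_API_KEY are exported and the module is importable as langchain_llms, as in the usage notebook below):

# Hedged sketch: module path, model name, and credentials are assumptions.
import json

from langchain_llms import SambaNovaCloud

llm = SambaNovaCloud(model='llama3-70b', max_tokens=256)

# a plain string prompt is wrapped into a single user message internally
print(llm.invoke('hello'))

# a JSON-encoded message list is passed through as the chat history
print(llm.invoke(json.dumps([{'role': 'user', 'content': 'tell me a joke'}])))

# streaming yields the generated text piece by piece
for piece in llm.stream('tell me a 50 word tale'):
    print(piece, end='', flush=True)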
utils/model_wrappers/usage.ipynb ADDED
@@ -0,0 +1,878 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# SambaNova Langchain Wrappers Usage"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 2,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "data": {
17
+ "text/plain": [
18
+ "True"
19
+ ]
20
+ },
21
+ "execution_count": 2,
22
+ "metadata": {},
23
+ "output_type": "execute_result"
24
+ }
25
+ ],
26
+ "source": [
27
+ "import os\n",
28
+ "\n",
29
+ "from dotenv import load_dotenv\n",
30
+ "from langchain_embeddings import SambaStudioEmbeddings\n",
31
+ "from langchain_llms import SambaStudio, SambaNovaCloud\n",
32
+ "from langchain_chat_models import ChatSambaNovaCloud\n",
33
+ "from langchain_core.messages import SystemMessage, HumanMessage\n",
34
+ "\n",
35
+ "current_dir = os.getcwd()\n",
36
+ "utils_dir = os.path.abspath(os.path.join(current_dir, '..'))\n",
37
+ "repo_dir = os.path.abspath(os.path.join(utils_dir, '..'))\n",
38
+ "\n",
39
+ "load_dotenv(os.path.join(repo_dir, '.env'), override=True)"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "metadata": {},
45
+ "source": [
46
+ "# SambaStudio LLM"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "metadata": {},
52
+ "source": [
53
+ "## Non streaming"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 9,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "llm = SambaStudio(\n",
63
+ " streaming=False,\n",
64
+ " # base_uri=\"api/predict/generic\",\n",
65
+ " model_kwargs={\n",
66
+ " 'do_sample': False,\n",
67
+ " 'temperature': 0.01,\n",
68
+ " 'max_tokens_to_generate': 256,\n",
69
+ " 'process_prompt': False,\n",
70
+ " 'select_expert': 'Meta-Llama-3-70B-Instruct-4096',\n",
71
+ " },\n",
72
+ ")"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 11,
78
+ "metadata": {},
79
+ "outputs": [
80
+ {
81
+ "data": {
82
+ "text/plain": [
83
+ "' of a brave knight\\nSir Valoric, the fearless knight, charged into the dark forest, his armor shining like the sun. He battled the dragon, its fiery breath singeing his beard, but he stood tall, his sword flashing in the moonlight, until the beast lay defeated at his feet, its treasure his noble reward.'"
84
+ ]
85
+ },
86
+ "execution_count": 11,
87
+ "metadata": {},
88
+ "output_type": "execute_result"
89
+ }
90
+ ],
91
+ "source": [
92
+ "llm.invoke('tell me a 50 word tale')"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "markdown",
97
+ "metadata": {},
98
+ "source": [
99
+ "## Streaming"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "llm = SambaStudio(\n",
109
+ " streaming=True,\n",
110
+ " model_kwargs={\n",
111
+ " 'do_sample': False,\n",
112
+ " 'max_tokens_to_generate': 256,\n",
113
+ " 'temperature': 0.01,\n",
114
+ " 'process_prompt': False,\n",
115
+ " 'select_expert': 'Meta-Llama-3-70B-Instruct-4096',\n",
116
+ " },\n",
117
+ ")"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {},
124
+ "outputs": [
125
+ {
126
+ "name": "stdout",
127
+ "output_type": "stream",
128
+ "text": [
129
+ " of a character who is a master of disguise\n",
130
+ "\n",
131
+ "Sure! Here is a 50-word tale of a character who is a master of disguise:\n",
132
+ "\n",
133
+ "\"Araxys, the skilled disguise artist, transformed into a stunning mermaid to infiltrate a pirate's lair. With a flick of her tail, she charmed the pirates and stole their treasure.\""
134
+ ]
135
+ }
136
+ ],
137
+ "source": [
138
+ "for chunk in llm.stream('tell me a 50 word tale'):\n",
139
+ " print(chunk, end='', flush=True)"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "metadata": {},
145
+ "source": [
146
+ "# SambaNovaCloud LLM"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "markdown",
151
+ "metadata": {},
152
+ "source": [
153
+ "## Non Streaming"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 4,
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "llm = SambaNovaCloud(model='llama3-70b')"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": 5,
168
+ "metadata": {},
169
+ "outputs": [
170
+ {
171
+ "data": {
172
+ "text/plain": [
173
+ "'Hello. How can I assist you today?'"
174
+ ]
175
+ },
176
+ "execution_count": 5,
177
+ "metadata": {},
178
+ "output_type": "execute_result"
179
+ }
180
+ ],
181
+ "source": [
182
+ "import json\n",
183
+ "\n",
184
+ "llm.invoke(json.dumps([{'role': 'user', 'content': 'hello'}]))"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 6,
190
+ "metadata": {},
191
+ "outputs": [
192
+ {
193
+ "data": {
194
+ "text/plain": [
195
+ "'Hello. How can I assist you today?'"
196
+ ]
197
+ },
198
+ "execution_count": 6,
199
+ "metadata": {},
200
+ "output_type": "execute_result"
201
+ }
202
+ ],
203
+ "source": [
204
+ "llm.invoke('hello')"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "markdown",
209
+ "metadata": {},
210
+ "source": [
211
+ "## Streaming"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 7,
217
+ "metadata": {},
218
+ "outputs": [
219
+ {
220
+ "name": "stdout",
221
+ "output_type": "stream",
222
+ "text": [
223
+ "\n",
224
+ "Here's a long story \n",
225
+ "for you:\n",
226
+ "\n",
227
+ "Once upon \n",
228
+ "a time, in a small village \n",
229
+ "nestled in the rolling hills of \n",
230
+ "rural France, there lived a \n",
231
+ "young girl named Sophie. Sophie \n",
232
+ "was a curious and adventurous \n",
233
+ "child, with a mop of curly \n",
234
+ "brown hair and a smile that \n",
235
+ "could light up the darkest \n",
236
+ "of rooms. She lived with \n",
237
+ "her parents, Pierre and \n",
238
+ "Colette, in a small stone cottage \n",
239
+ "on the outskirts of \n",
240
+ "the village.\n",
241
+ "\n",
242
+ "Sophie's village was \n",
243
+ "a charming \n",
244
+ "place, filled with narrow \n",
245
+ "cobblestone streets, quaint shops, \n",
246
+ "and \n",
247
+ "bustling cafes. The villagers \n",
248
+ "were a tight-knit \n",
249
+ "community, and everyone knew each \n",
250
+ "other's names and stories. Sophie \n",
251
+ "loved listening to the villagers' \n",
252
+ "tales of \n",
253
+ "old, which \n",
254
+ "often featured brave knights, \n",
255
+ "beautiful princesses, and \n",
256
+ "magical creatures.\n",
257
+ "\n",
258
+ "One day, while exploring \n",
259
+ "the village, Sophie stumbled upon \n",
260
+ "a small, mysterious shop tucked \n",
261
+ "away on a quiet street. \n",
262
+ "The sign above the door \n",
263
+ "read \"Curios \n",
264
+ "and Wonders,\" and the \n",
265
+ "windows were filled \n",
266
+ "with a dazzling array of strange \n",
267
+ "and exotic objects. Sophie's \n",
268
+ "curiosity was piqued, \n",
269
+ "and she pushed open the door \n",
270
+ "to venture inside.\n",
271
+ "\n",
272
+ "The shop \n",
273
+ "was dimly lit, and \n",
274
+ "the air was thick with the \n",
275
+ "scent of old books and \n",
276
+ "dust. Sophie's eyes \n",
277
+ "adjusted slowly, and she \n",
278
+ "saw that the shop was filled \n",
279
+ "with all manner of curious \n",
280
+ "objects: vintage \n",
281
+ "clocks, rare coins, \n",
282
+ "and even a \n",
283
+ "taxidermied owl perched on \n",
284
+ "a shelf. Behind the counter stood \n",
285
+ "an old man with a kind \n",
286
+ "face \n",
287
+ "and a twinkle in his eye.\n",
288
+ "\n",
289
+ "\n",
290
+ "\n",
291
+ "\"Bonjour, mademoiselle,\" he \n",
292
+ "said, his voice low and \n",
293
+ "soothing. \"Welcome to Curios \n",
294
+ "and Wonders. I \n",
295
+ "am Monsieur LaFleur, \n",
296
+ "the proprietor. How may I \n",
297
+ "assist you \n",
298
+ "today?\"\n",
299
+ "\n",
300
+ "Sophie wandered the aisles, \n",
301
+ "running her fingers over \n",
302
+ "the strange objects on \n",
303
+ "display. She picked up \n",
304
+ "a small, delicate music \n",
305
+ "box and wound \n",
306
+ "it up, listening \n",
307
+ "as it played \n",
308
+ "a soft, melancholy \n",
309
+ "tune. Monsieur LaFleur \n",
310
+ "smiled and nodded \n",
311
+ "in approval.\n",
312
+ "\n",
313
+ "\"Ah, you have a \n",
314
+ "good ear for \n",
315
+ "music, mademoiselle,\" he \n",
316
+ "said. \"That music box \n",
317
+ "is a \n",
318
+ "rare and precious item. It \n",
319
+ "was crafted by a skilled artisan \n",
320
+ "in the 18th century.\"\n",
321
+ "\n",
322
+ "\n",
323
+ "As Sophie continued to \n",
324
+ "explore the shop, \n",
325
+ "she stumbled upon \n",
326
+ "a large, leather-bound book \n",
327
+ "with strange symbols etched into \n",
328
+ "the cover. \n",
329
+ "Monsieur LaFleur noticed her interest and \n",
330
+ "approached \n",
331
+ "her.\n",
332
+ "\n",
333
+ "\"Ah, you've found \n",
334
+ "the infamous 'Livre \n",
335
+ "\n",
336
+ "des Secrets,'\" \n",
337
+ "he said, his \n",
338
+ "voice low and mysterious. \n",
339
+ "\"That book is said to contain \n",
340
+ "the secrets of the universe, \n",
341
+ "hidden within its pages. But \n",
342
+ "be \n",
343
+ "warned, mademoiselle, \n",
344
+ "the book is said to \n",
345
+ "be cursed. Many have attempted \n",
346
+ "to unlock its secrets, but \n",
347
+ "none have \n",
348
+ "succeeded.\"\n",
349
+ "\n",
350
+ "Sophie's eyes widened with \n",
351
+ "excitement as she carefully opened \n",
352
+ "the book. The pages \n",
353
+ "were yellowed and \n",
354
+ "crackling, and \n",
355
+ "the text was written in a \n",
356
+ "language she couldn't understand. \n",
357
+ "But as she turned the \n",
358
+ "pages, \n",
359
+ "she felt a strange sensation, \n",
360
+ "as if the book \n",
361
+ "was calling \n",
362
+ "to her.\n",
363
+ "\n",
364
+ "Monsieur \n",
365
+ "LaFleur smiled \n",
366
+ "and \n",
367
+ "nodded. \"I see you have a \n",
368
+ "connection to the \n",
369
+ "book, mademoiselle. Perhaps you \n",
370
+ "are the one who can unlock \n",
371
+ "its secrets.\"\n",
372
+ "\n",
373
+ "Over the next \n",
374
+ "few weeks, Sophie returned to \n",
375
+ "the shop again and again, \n",
376
+ "pouring over \n",
377
+ "the pages of the Livre \n",
378
+ "des Secrets. She spent hours \n",
379
+ "studying \n",
380
+ "the symbols and trying to decipher \n",
381
+ "the text. \n",
382
+ "Monsieur \n",
383
+ "LaFleur watched her with a \n",
384
+ "keen eye, offering guidance and encouragement \n",
385
+ "whenever she needed it.\n",
386
+ "\n",
387
+ "As \n",
388
+ "the days turned into weeks, \n",
389
+ "Sophie began to notice strange occurrences \n",
390
+ "happening around her. She would \n",
391
+ "find objects moved from their \n",
392
+ "usual places, and she would hear \n",
393
+ "whispers in the night. She \n",
394
+ "began \n",
395
+ "to feel as though the book \n",
396
+ "was exerting some kind of \n",
397
+ "influence over her, drawing her \n",
398
+ "deeper into \n",
399
+ "its secrets.\n",
400
+ "\n",
401
+ "One \n",
402
+ "night, Sophie had a vivid dream \n",
403
+ "in which \n",
404
+ "she saw herself standing in \n",
405
+ "a \n",
406
+ "grand, \n",
407
+ "candlelit hall. \n",
408
+ "The walls were lined with \n",
409
+ "ancient tapestries, and the \n",
410
+ "air was thick with the scent \n",
411
+ "of \n",
412
+ "incense. At the far end of \n",
413
+ "the hall, she saw a \n",
414
+ "figure cloaked in shadows.\n",
415
+ "\n",
416
+ "\n",
417
+ "As she approached \n",
418
+ "the figure, it stepped forward, \n",
419
+ "revealing a woman \n",
420
+ "with long, flowing hair and \n",
421
+ "piercing green eyes. The woman \n",
422
+ "spoke in a voice that was \n",
423
+ "both familiar and yet \n",
424
+ "completely alien.\n",
425
+ "\n",
426
+ "\"Sophie, you \n",
427
+ "have been chosen to unlock the \n",
428
+ "secrets of the Livre \n",
429
+ "des Secrets,\" she \n",
430
+ "said. \"But be warned, \n",
431
+ "the \n",
432
+ "journey will \n",
433
+ "be difficult, and the cost \n",
434
+ "will be high. Are you \n",
435
+ "prepared to pay \n",
436
+ "the price?\"\n",
437
+ "\n",
438
+ "Sophie woke up with \n",
439
+ "a start, her heart racing and \n",
440
+ "her mind reeling. She \n",
441
+ "knew that she had \n",
442
+ "to return to the shop and \n",
443
+ "confront Monsieur LaFleur \n",
444
+ "about the \n",
445
+ "strange \n",
446
+ "occurrences. But when she \n",
447
+ "arrived at the shop, she \n",
448
+ "found that it \n",
449
+ "was closed, \n",
450
+ "and \n",
451
+ "a sign on the door \n",
452
+ "read \"Gone on \n",
453
+ "a \n",
454
+ "journey. Will return \n",
455
+ "soon.\"\n",
456
+ "\n",
457
+ "Sophie \n",
458
+ "was devastated. \n",
459
+ "She felt as though she had \n",
460
+ "been abandoned, left \n",
461
+ "to navigate the mysteries of \n",
462
+ "the Livre des Secrets on \n",
463
+ "her own. But as \n",
464
+ "she turned to leave, she \n",
465
+ "noticed a\n"
466
+ ]
467
+ }
468
+ ],
469
+ "source": [
470
+ "for i in llm.stream('hello tell me a long story'):\n",
471
+ " print(i)"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "markdown",
476
+ "metadata": {},
477
+ "source": [
478
+ "# SambaNova Cloud Chat Model"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "markdown",
483
+ "metadata": {},
484
+ "source": [
485
+ "## Non Streaming"
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": 4,
491
+ "metadata": {},
492
+ "outputs": [],
493
+ "source": [
494
+ "llm = ChatSambaNovaCloud(\n",
495
+ " model= \"llama3-405b\",\n",
496
+ " max_tokens=1024,\n",
497
+ " temperature=0.7,\n",
498
+ " top_k=1,\n",
499
+ " top_p=0.01,\n",
500
+ " stream_options={'include_usage':True}\n",
501
+ " )"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": 5,
507
+ "metadata": {},
508
+ "outputs": [
509
+ {
510
+ "data": {
511
+ "text/plain": [
512
+ "AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 146.48573712341215, 'completion_tokens_after_first_per_sec_first_ten': 172.9005798161617, 'completion_tokens_per_sec': 81.99632208428116, 'end_time': 1726178488.071125, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726178487.3630672, 'time_to_first_token': 0.34624791145324707, 'total_latency': 0.658566123789007, 'total_tokens': 94, 'total_tokens_per_sec': 142.73433844300794}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726178487}, id='a5590b89-4853-4bd9-9fd8-83276b369278')"
513
+ ]
514
+ },
515
+ "execution_count": 5,
516
+ "metadata": {},
517
+ "output_type": "execute_result"
518
+ }
519
+ ],
520
+ "source": [
521
+ "llm.invoke(\"tell me a joke\")"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": 7,
527
+ "metadata": {},
528
+ "outputs": [
529
+ {
530
+ "data": {
531
+ "text/plain": [
532
+ "AIMessage(content=\"Yer lookin' fer a joke, eh? Alright then, matey! Here be one fer ye:\\n\\nWhy did the pirate quit his job?\\n\\n(pause fer dramatic effect)\\n\\nBecause he was sick o' all the arrrr-guments!\\n\\nYarrr, hope that made ye laugh, me hearty!\", response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.583333333333333, 'completion_tokens': 64, 'completion_tokens_after_first_per_sec': 120.91573778458478, 'completion_tokens_after_first_per_sec_first_ten': 140.3985499426452, 'completion_tokens_per_sec': 79.98855768735817, 'end_time': 1726065701.9732044, 'is_last_response': True, 'prompt_tokens': 48, 'start_time': 1726065701.107911, 'time_to_first_token': 0.3442692756652832, 'total_latency': 0.8001144394945743, 'total_tokens': 112, 'total_tokens_per_sec': 139.9799759528768}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065701}, id='7b0748bb-c5f7-4696-ae56-03b734b60fb9')"
533
+ ]
534
+ },
535
+ "execution_count": 7,
536
+ "metadata": {},
537
+ "output_type": "execute_result"
538
+ }
539
+ ],
540
+ "source": [
541
+ "messages = [\n",
542
+ " SystemMessage(content=\"You are a helpful assistant with pirate accent\"),\n",
543
+ " HumanMessage(content=\"tell me a joke\")\n",
544
+ " ]\n",
545
+ "llm.invoke(messages)"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": 8,
551
+ "metadata": {},
552
+ "outputs": [
553
+ {
554
+ "data": {
555
+ "text/plain": [
556
+ "AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 146.72813415408498, 'completion_tokens_after_first_per_sec_first_ten': 172.71830994351703, 'completion_tokens_per_sec': 82.34884281970663, 'end_time': 1726065746.6364844, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726065745.932173, 'time_to_first_token': 0.34309911727905273, 'total_latency': 0.6557469194585627, 'total_tokens': 94, 'total_tokens_per_sec': 143.34798564911895}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065745}, id='27e7d4fe-8e24-419a-b75b-51ea2519781b')"
557
+ ]
558
+ },
559
+ "execution_count": 8,
560
+ "metadata": {},
561
+ "output_type": "execute_result"
562
+ }
563
+ ],
564
+ "source": [
565
+ "future_response = llm.ainvoke(\"tell me a joke\")\n",
566
+ "await(future_response) "
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "markdown",
571
+ "metadata": {},
572
+ "source": [
573
+ "## Batching"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": 9,
579
+ "metadata": {},
580
+ "outputs": [],
581
+ "source": [
582
+ "llm = ChatSambaNovaCloud(\n",
583
+ " model= \"llama3-405b\",\n",
584
+ " streaming=False,\n",
585
+ " max_tokens=1024,\n",
586
+ " temperature=0.7,\n",
587
+ " top_k=1,\n",
588
+ " top_p=0.01,\n",
589
+ " stream_options={'include_usage':True}\n",
590
+ " )"
591
+ ]
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": 11,
596
+ "metadata": {},
597
+ "outputs": [
598
+ {
599
+ "data": {
600
+ "text/plain": [
601
+ "[AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 146.72232349940003, 'completion_tokens_after_first_per_sec_first_ten': 173.01988455676758, 'completion_tokens_per_sec': 82.21649876350362, 'end_time': 1726065879.4066722, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726065878.700746, 'time_to_first_token': 0.3446996212005615, 'total_latency': 0.656802476536144, 'total_tokens': 94, 'total_tokens_per_sec': 143.1176089586915}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065878}, id='28d3a38b-5dae-4d62-bf6c-cface081df34'),\n",
602
+ " AIMessage(content='The capital of the United Kingdom is London.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 13, 'completion_tokens': 10, 'completion_tokens_after_first_per_sec': 110.21174794386165, 'completion_tokens_after_first_per_sec_first_ten': 327.0275172132524, 'completion_tokens_per_sec': 26.88555788272027, 'end_time': 1726065879.138034, 'is_last_response': True, 'prompt_tokens': 43, 'start_time': 1726065878.7150047, 'time_to_first_token': 0.3413684368133545, 'total_latency': 0.37194690337547887, 'total_tokens': 53, 'total_tokens_per_sec': 142.49345677841742}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065878}, id='859a9e45-c0a5-44ec-bd53-686877c2cf89')]"
603
+ ]
604
+ },
605
+ "execution_count": 11,
606
+ "metadata": {},
607
+ "output_type": "execute_result"
608
+ }
609
+ ],
610
+ "source": [
611
+ "llm.batch([\"tell me a joke\",\"which is the capital of UK?\"])"
612
+ ]
613
+ },
614
+ {
615
+ "cell_type": "code",
616
+ "execution_count": 13,
617
+ "metadata": {},
618
+ "outputs": [
619
+ {
620
+ "name": "stderr",
621
+ "output_type": "stream",
622
+ "text": [
623
+ "/var/folders/p4/y0q2kh796nx_k_yzfhxs57f00000gp/T/ipykernel_33601/1543848179.py:1: RuntimeWarning: coroutine 'Runnable.abatch' was never awaited\n",
624
+ " future_responses = llm.abatch([\"tell me a joke\",\"which is the capital of UK?\"])\n",
625
+ "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n"
626
+ ]
627
+ },
628
+ {
629
+ "data": {
630
+ "text/plain": [
631
+ "[AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 120.34699641554552, 'completion_tokens_after_first_per_sec_first_ten': 141.51170437257693, 'completion_tokens_per_sec': 36.223157123884754, 'end_time': 1726065914.8678048, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726065913.3182464, 'time_to_first_token': 1.1091651916503906, 'total_latency': 1.4907590692693538, 'total_tokens': 94, 'total_tokens_per_sec': 63.05512536379939}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065913}, id='f279d0fb-70b5-428c-9283-457b9831b559'),\n",
632
+ " AIMessage(content='The capital of the United Kingdom is London.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 9.5, 'completion_tokens': 10, 'completion_tokens_after_first_per_sec': 60.73429985889864, 'completion_tokens_after_first_per_sec_first_ten': 195.5434460421063, 'completion_tokens_per_sec': 8.61842566880045, 'end_time': 1726065914.575598, 'is_last_response': True, 'prompt_tokens': 43, 'start_time': 1726065913.3182464, 'time_to_first_token': 1.1091651916503906, 'total_latency': 1.160304722033049, 'total_tokens': 53, 'total_tokens_per_sec': 45.67765604464238}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065913}, id='f279d0fb-70b5-428c-9283-457b9831b559')]"
633
+ ]
634
+ },
635
+ "execution_count": 13,
636
+ "metadata": {},
637
+ "output_type": "execute_result"
638
+ }
639
+ ],
640
+ "source": [
641
+ "future_responses = llm.abatch([\"tell me a joke\",\"which is the capital of UK?\"])\n",
642
+ "await(future_responses)"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "markdown",
647
+ "metadata": {},
648
+ "source": [
649
+ "## Streaming"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": 14,
655
+ "metadata": {},
656
+ "outputs": [],
657
+ "source": [
658
+ "llm = ChatSambaNovaCloud(\n",
659
+ " model= \"llama3-405b\",\n",
660
+ " streaming=True,\n",
661
+ " max_tokens=1024,\n",
662
+ " temperature=0.7,\n",
663
+ " top_k=1,\n",
664
+ " top_p=0.01,\n",
665
+ " stream_options={'include_usage':True}\n",
666
+ " )"
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": 15,
672
+ "metadata": {},
673
+ "outputs": [
674
+ {
675
+ "name": "stdout",
676
+ "output_type": "stream",
677
+ "text": [
678
+ "\n",
679
+ "A man walked into a \n",
680
+ "library and asked the \n",
681
+ "librarian, \"Do you have any books \n",
682
+ "on Pavlov's dogs \n",
683
+ "and Schrödinger's cat?\"\n",
684
+ "\n",
685
+ "\n",
686
+ "The librarian \n",
687
+ "replied, \"It rings a bell, \n",
688
+ "but I'm not sure \n",
689
+ "if it's here \n",
690
+ "or not.\"\n",
691
+ "\n",
692
+ "\n",
693
+ "\n"
694
+ ]
695
+ }
696
+ ],
697
+ "source": [
698
+ "for chunk in llm.stream(\"tell me a joke\"):\n",
699
+ " print(chunk.content)"
700
+ ]
701
+ },
702
+ {
703
+ "cell_type": "code",
704
+ "execution_count": 16,
705
+ "metadata": {},
706
+ "outputs": [
707
+ {
708
+ "name": "stdout",
709
+ "output_type": "stream",
710
+ "text": [
711
+ "\n",
712
+ "Yer lookin' \n",
713
+ "fer a joke, eh? \n",
714
+ "Alright then, matey! \n",
715
+ "Here be one fer \n",
716
+ "ye:\n",
717
+ "\n",
718
+ "Why did the pirate quit his job?\n",
719
+ "\n",
720
+ "\n",
721
+ "\n",
722
+ "(pause fer \n",
723
+ "dramatic effect)\n",
724
+ "\n",
725
+ "Because he was sick \n",
726
+ "o' all the arrrr-guments!\n",
727
+ "\n",
728
+ "\n",
729
+ "\n",
730
+ "\n",
731
+ "Yarrr, hope that made ye \n",
732
+ "laugh, \n",
733
+ "me hearty!\n",
734
+ "\n",
735
+ "\n",
736
+ "\n"
737
+ ]
738
+ }
739
+ ],
740
+ "source": [
741
+ "messages = [\n",
742
+ " SystemMessage(content=\"You are a helpful assistant with pirate accent\"),\n",
743
+ " HumanMessage(content=\"tell me a joke\")\n",
744
+ " ]\n",
745
+ "for chunk in llm.stream(messages):\n",
746
+ " print(chunk.content)"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": 17,
752
+ "metadata": {},
753
+ "outputs": [
754
+ {
755
+ "name": "stdout",
756
+ "output_type": "stream",
757
+ "text": [
758
+ "\n",
759
+ "A man walked into a \n",
760
+ "library and asked the \n",
761
+ "librarian, \"Do you have any books \n",
762
+ "on Pavlov's dogs \n",
763
+ "and Schrödinger's cat?\"\n",
764
+ "\n",
765
+ "\n",
766
+ "The librarian \n",
767
+ "replied, \"It rings a bell, \n",
768
+ "but I'm not sure \n",
769
+ "if it's here \n",
770
+ "or not.\"\n",
771
+ "\n",
772
+ "\n",
773
+ "\n"
774
+ ]
775
+ }
776
+ ],
777
+ "source": [
778
+ "async for chunk in llm.astream(\"tell me a joke\"):\n",
779
+ " print(chunk.content)"
780
+ ]
781
+ },
782
+ {
783
+ "cell_type": "markdown",
784
+ "metadata": {},
785
+ "source": [
786
+ "# Sambastudio Embeddings"
787
+ ]
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "execution_count": null,
792
+ "metadata": {},
793
+ "outputs": [],
794
+ "source": [
795
+ "embedding = SambaStudioEmbeddings(batch_size=1, model_kwargs={'select_expert': 'e5-mistral-7b-instruct'})\n",
796
+ "embedding.embed_documents(['tell me a 50 word tale', 'tell me a joke'])\n",
797
+ "embedding.embed_query('tell me a 50 word tale')"
798
+ ]
799
+ },
800
+ {
801
+ "cell_type": "code",
802
+ "execution_count": 13,
803
+ "metadata": {},
804
+ "outputs": [
805
+ {
806
+ "name": "stderr",
807
+ "output_type": "stream",
808
+ "text": [
809
+ "/Users/jorgep/Documents/ask_public_own/finetuning_env/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.\n",
810
+ " warn_deprecated(\n"
811
+ ]
812
+ },
813
+ {
814
+ "data": {
815
+ "text/plain": [
816
+ "[Document(page_content='tell me a 50 word tale'),\n",
817
+ " Document(page_content='tell me a joke'),\n",
818
+ " Document(page_content='give me 3 party activities'),\n",
819
+ " Document(page_content='give me three healty dishes')]"
820
+ ]
821
+ },
822
+ "execution_count": 13,
823
+ "metadata": {},
824
+ "output_type": "execute_result"
825
+ }
826
+ ],
827
+ "source": [
828
+ "from langchain.schema import Document\n",
829
+ "from langchain.vectorstores import Chroma\n",
830
+ "\n",
831
+ "docs = [\n",
832
+ " 'tell me a 50 word tale',\n",
833
+ " 'tell me a joke',\n",
834
+ " 'when was America discoverd?',\n",
835
+ " 'how to build an engine?',\n",
836
+ " 'give me 3 party activities',\n",
837
+ " 'give me three healty dishes',\n",
838
+ "]\n",
839
+ "docs = [Document(doc) for doc in docs]\n",
840
+ "\n",
841
+ "query = 'prompt for generating something fun'\n",
842
+ "\n",
843
+ "vectordb = Chroma.from_documents(docs, embedding)\n",
844
+ "retriever = vectordb.as_retriever()\n",
845
+ "\n",
846
+ "retriever.get_relevant_documents(query)"
847
+ ]
848
+ },
849
+ {
850
+ "cell_type": "code",
851
+ "execution_count": null,
852
+ "metadata": {},
853
+ "outputs": [],
854
+ "source": []
855
+ }
856
+ ],
857
+ "metadata": {
858
+ "kernelspec": {
859
+ "display_name": "peenv",
860
+ "language": "python",
861
+ "name": "python3"
862
+ },
863
+ "language_info": {
864
+ "codemirror_mode": {
865
+ "name": "ipython",
866
+ "version": 3
867
+ },
868
+ "file_extension": ".py",
869
+ "mimetype": "text/x-python",
870
+ "name": "python",
871
+ "nbconvert_exporter": "python",
872
+ "pygments_lexer": "ipython3",
873
+ "version": "3.10.11"
874
+ }
875
+ },
876
+ "nbformat": 4,
877
+ "nbformat_minor": 2
878
+ }
utils/parsing/README.md ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SambaParse
2
+
3
+ SambaParse is a Python library that simplifies the process of extracting and processing unstructured data using the Unstructured.io API. It provides a convenient wrapper around the Unstructured.io CLI tool, allowing you to ingest data from various sources, perform partitioning, chunking, embedding, and load the processed data into a vector database. It's designed to be used within AI Starter kits and SN Apps, unifying our data ingestion and document intelligence platform. This allows us to keep our code base centralized for data ingestion kits.
4
+
5
+ ## Prerequisites
6
+
7
+ Before using SambaParse, make sure you have the following:
8
+
9
+ - Docker installed on your machine (or access to another API server)
10
+ - An Unstructured.io API key
11
+
12
+ Then, configure your API key:
13
+
14
+ - Create a `.env` file in the ai-starter-kit root directory (not in the parsing folder root):
15
+
16
+ ```bash
17
+ UNSTRUCTURED_API_KEY=your_api_key_here
18
+ ```
19
+
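+ A quick way to confirm the key is picked up is shown below. This is only a minimal sketch, assuming python-dotenv is installed and that you run it from the ai-starter-kit root:
+
+ ```python
+ import os
+ from dotenv import load_dotenv
+
+ load_dotenv('.env')  # the .env file created above, in the ai-starter-kit root
+ assert os.getenv('UNSTRUCTURED_API_KEY'), 'UNSTRUCTURED_API_KEY not found in .env'
+ ```
+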
20
+ ## Setup
21
+
22
+ ### Pre Reqs
23
+
24
+ Using pyenv to manage virtual environments is recommended.
25
+ The macOS install command is shown below; see the pyenv-virtualenv repo for more detailed instructions.
26
+
27
+ ```bash
28
+ brew install pyenv-virtualenv
29
+ ```
30
+
31
+ - Create a Python virtual environment using Python 3.10.12:
32
+
33
+ ```bash
34
+ pyenv install 3.10.12
35
+ pyenv virtualenv 3.10.12 sambaparse
36
+ pyenv activate sambaparse
37
+ ```
38
+
39
+ - Clone the ai-starter-kit repo and cd into it:
40
+
41
+ ```bash
42
+ git clone https://github.com/sambanova/ai-starter-kit
+ cd ai-starter-kit
43
+ ```
44
+
45
+ - cd into `utils/parsing` and install the requirements:
46
+
47
+ ```bash
48
+ cd utils/parsing
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ - cd into the unstructured-api folder:
52
+
53
+ ```bash
54
+ cd unstructured-api
55
+ ```
56
+
57
+ - Run the Makefile install target:
58
+
59
+ ```bash
60
+ make install
61
+ ```
62
+
63
+ - Run the web server:
64
+
65
+ ```bash
66
+ make run-web-app
67
+ ```
68
+
69
+ This command starts the Unstructured API server using the specified API key and exposes it on port 8005.
70
+
71
+ - Alternatively, if you have another Unstructured API server running on a different instance, make sure to update the `partition_endpoint` and `unstructured_port` values in the YAML configuration file accordingly.
72
+
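+ Before running an ingest, it can help to confirm the server is reachable. The sketch below is a minimal check under two assumptions: the server listens on localhost:8005 and exposes the Unstructured API's `/healthcheck` route; adjust the host and port if your `partition_endpoint` and `unstructured_port` settings differ.
+
+ ```python
+ import requests
+
+ UNSTRUCTURED_URL = 'http://localhost:8005'  # assumed host/port; see config.yaml
+
+ resp = requests.get(f'{UNSTRUCTURED_URL}/healthcheck', timeout=10)
+ resp.raise_for_status()
+ print(resp.text)  # a small "everything OK" style payload is expected
+ ```
+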
73
+ ## Usage
74
+
75
+ 1. Import the `SambaParse` class from the `ai-starter-kit` library:
76
+
77
+ ```python
78
+ from utils.parsing.sambaparse import SambaParse
79
+ ```
80
+
81
+ 2. Create a YAML configuration file (e.g., `config.yaml`) to specify the desired settings for the ingestion process. Here's the configuration for use cases 1 and 2, i.e., local files and folders:
82
+
83
+ ```yaml
84
+ processor:
85
+ verbose: True
86
+ output_dir: './output'
87
+ num_processes: 2
88
+
89
+ sources:
90
+ local:
91
+ recursive: True
92
+ confluence:
93
+ api_token: 'your_confluence_api_token'
94
+ user_email: 'your_email@example.com'
95
+ url: 'https://your-confluence-url.atlassian.net'
96
+ github:
97
+ url: 'owner/repo'
98
+ branch: 'main'
99
+ google_drive:
100
+ service_account_key: 'path/to/service_account_key.json'
101
+ recursive: True
102
+ drive_id: 'your_drive_id'
103
+
104
+ partitioning:
105
+ pdf_infer_table_structure: True
106
+ skip_infer_table_types: []
107
+ strategy: 'auto'
108
+ hi_res_model_name: 'yolox'
109
+ ocr_languages: ['eng']
110
+ encoding: 'utf-8'
111
+ fields_include: ['element_id', 'text', 'type', 'metadata', 'embeddings']
112
+ flatten_metadata: False
113
+ metadata_exclude: []
114
+ metadata_include: []
115
+ partition_endpoint: 'http://localhost'
116
+ unstructured_port: 8005
117
+ partition_by_api: True
118
+
119
+ chunking:
120
+ enabled: True
121
+ strategy: 'basic'
122
+ chunk_max_characters: 1500
123
+ chunk_overlap: 300
124
+
125
+ embedding:
126
+ enabled: False
127
+ provider: 'langchain-huggingface'
128
+ model_name: 'intfloat/e5-large-v2'
129
+
130
+ destination_connectors:
131
+ enabled: False
132
+ type: 'chroma'
133
+ batch_size: 80
134
+ chroma:
135
+ host: 'localhost'
136
+ port: 8004
137
+ collection_name: 'snconf'
138
+ tenant: 'default_tenant'
139
+ database: 'default_database'
140
+ qdrant:
141
+ location: 'http://localhost:6333'
142
+ collection_name: 'test'
143
+
144
+ additional_processing:
145
+ enabled: True
146
+ extend_metadata: True
147
+ replace_table_text: True
148
+ table_text_key: 'text_as_html'
149
+ return_langchain_docs: True
150
+ convert_metadata_keys_to_string: True
151
+ ```
152
+
153
+ Make sure to place the `config.yaml` file in the desired folder.
154
+
155
+ 3. Create an instance of the `SambaParse` class, passing the path to the YAML configuration file:
156
+
157
+ ```python
158
+ sambaparse = SambaParse('path/to/config.yaml')
159
+ ```
160
+
161
+ 4. Use the `run_ingest` method to process your data:
162
+
163
+ - For a single file:
164
+
165
+ ```python
166
+ source_type = 'local'
167
+ input_path = 'path/to/your/file.pdf'
168
+ additional_metadata = {'key': 'value'}
169
+ texts, metadata_list, langchain_docs = sambaparse.run_ingest(source_type, input_path=input_path, additional_metadata=additional_metadata)
170
+ ```
171
+
172
+ - For a folder:
173
+
174
+ ```python
175
+ source_type = 'local'
176
+ input_path = 'path/to/your/folder'
177
+ additional_metadata = {'key': 'value'}
178
+ texts, metadata_list, langchain_docs = sambaparse.run_ingest(source_type, input_path=input_path, additional_metadata=additional_metadata)
179
+ ```
180
+
181
+ - For Confluence:
182
+
183
+ ```python
184
+ source_type = 'confluence'
185
+ additional_metadata = {'key': 'value'}
186
+ texts, metadata_list, langchain_docs = sambaparse.run_ingest(source_type, additional_metadata=additional_metadata)
187
+ ```
188
+
189
+ Note that for Confluence you must enable embedding and the destination connector (i.e., Chroma) and turn off additional processing (i.e., the LangChain post-processing); an example YAML that does this is shown below.
190
+
191
+ ```yaml
192
+ processor:
193
+ verbose: True
194
+ output_dir: './output'
195
+ num_processes: 2
196
+
197
+ sources:
198
+ local:
199
+ recursive: True
200
+ confluence:
201
+ api_token: 'your_confluence_api_token'
202
+ user_email: 'your_email@example.com'
203
+ url: 'https://your-confluence-url.atlassian.net'
204
+ github:
205
+ url: 'owner/repo'
206
+ branch: 'main'
207
+ google_drive:
208
+ service_account_key: 'path/to/service_account_key.json'
209
+ recursive: True
210
+ drive_id: 'your_drive_id'
211
+
212
+ partitioning:
213
+ pdf_infer_table_structure: True
214
+ skip_infer_table_types: []
215
+ strategy: 'auto'
216
+ hi_res_model_name: 'yolox'
217
+ ocr_languages: ['eng']
218
+ encoding: 'utf-8'
219
+ fields_include: ['element_id', 'text', 'type', 'metadata', 'embeddings']
220
+ flatten_metadata: False
221
+ metadata_exclude: []
222
+ metadata_include: []
223
+ partition_endpoint: 'http://localhost'
224
+ unstructured_port: 8005
225
+ partition_by_api: True
226
+
227
+ chunking:
228
+ enabled: True
229
+ strategy: 'basic'
230
+ chunk_max_characters: 1500
231
+ chunk_overlap: 300
232
+
233
+ embedding:
234
+ enabled: True
235
+ provider: 'langchain-huggingface'
236
+ model_name: 'intfloat/e5-large-v2'
237
+
238
+ destination_connectors:
239
+ enabled: True
240
+ type: 'chroma'
241
+ batch_size: 80
242
+ chroma:
243
+ host: 'localhost'
244
+ port: 8004
245
+ collection_name: 'snconf'
246
+ tenant: 'default_tenant'
247
+ database: 'default_database'
248
+ qdrant:
249
+ location: 'http://localhost:6333'
250
+ collection_name: 'test'
251
+
252
+ additional_processing:
253
+ enabled: False
254
+ extend_metadata: True
255
+ replace_table_text: True
256
+ table_text_key: 'text_as_html'
257
+ return_langchain_docs: True
258
+ convert_metadata_keys_to_string: True
259
+ ```
260
+
261
+ In addition, for Confluence you will need a Chroma server running on port 8004. You can start one with the Docker command below; a short sketch for checking the resulting collection follows the command.
262
+
263
+ ```bash
264
+ docker run -d --rm --name chromadb -v ./chroma:/chroma/chroma -e IS_PERSISTENT=TRUE -e ANONYMIZED_TELEMETRY=TRUE -p 8004:8000 chromadb/chroma:latest
265
+ ```
266
+
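+ The sketch below is one way to check the ingested collection. It assumes the chromadb Python client is installed and that the host, port, and collection name match the YAML above (localhost:8004, collection `snconf`):
+
+ ```python
+ import chromadb
+
+ # Assumes the Chroma server started with the docker command above.
+ client = chromadb.HttpClient(host='localhost', port=8004)
+ collection = client.get_collection('snconf')
+
+ print(f'Documents in collection: {collection.count()}')
+ print(collection.peek(2))  # sample a couple of stored records
+ ```
+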
267
+ The `run_ingest` method returns a tuple containing the extracted texts, metadata, and LangChain documents (if `return_langchain_docs` is set to `True` in the configuration).
268
+
269
+ 5. Process the returned data as needed (a minimal usage sketch follows this list):
270
+ - `texts`: A list of extracted text elements from the documents.
271
+ - `metadata_list`: A list of metadata dictionaries for each text element.
272
+ - `langchain_docs`: A list of LangChain `Document` objects, which combine the text and metadata.
273
+
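+ A minimal sketch of how the three return values line up (the folder path and metadata key here are placeholders, and `sambaparse` is the instance created in step 3):
+
+ ```python
+ texts, metadata_list, langchain_docs = sambaparse.run_ingest(
+     'local', input_path='path/to/your/folder', additional_metadata={'key': 'value'}
+ )
+
+ # Each text chunk lines up with one metadata dict.
+ for text, metadata in zip(texts[:3], metadata_list[:3]):
+     print(metadata.get('filename'), '->', text[:80])
+
+ # The LangChain documents carry the same content plus metadata, ready for a vector store.
+ print(langchain_docs[0].page_content[:80])
+ print(langchain_docs[0].metadata)
+ ```
+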
274
+ #### Configuration Options
275
+
276
+ The YAML configuration file allows you to customize various aspects of the ingestion process. Here are some of the key options:
277
+
278
+ - `processor`: Settings related to the processing of documents, such as the output directory and the number of processes to use.
279
+ - `sources`: Configuration for different data sources, including local files, Confluence, GitHub, and Google Drive.
280
+ - `partitioning`: Options for partitioning the documents, including the strategy, OCR languages, and API settings.
281
+ - `chunking`: Settings for chunking the documents, such as enabling chunking, specifying the chunking strategy, and setting the maximum chunk size and overlap.
282
+ - `embedding`: Options for embedding the documents, including enabling embedding, specifying the embedding provider, and setting the model name.
283
+ - `additional_processing`: Configuration for additional processing steps, such as extending metadata, replacing table text, and returning LangChain documents.
284
+
285
+ Make sure to review and modify the configuration file according to your specific requirements.
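+
+ If you need to vary settings per run, one option is to load the YAML, tweak it in code, and write a temporary copy before instantiating `SambaParse`. This is only a sketch, assuming PyYAML is available (SambaParse itself reads the file with `yaml.safe_load`); the keys shown mirror the example configuration above.
+
+ ```python
+ import tempfile
+
+ import yaml
+
+ from utils.parsing.sambaparse import SambaParse
+
+ with open('config.yaml') as f:
+     config = yaml.safe_load(f)
+
+ # Example tweak: switch to the hi_res strategy and smaller chunks for one run.
+ config['partitioning']['strategy'] = 'hi_res'
+ config['chunking']['chunk_max_characters'] = 800
+
+ # Write the modified config to a temporary file and point SambaParse at it.
+ with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as tmp:
+     yaml.safe_dump(config, tmp)
+     tmp_config_path = tmp.name
+
+ sambaparse = SambaParse(tmp_config_path)
+ ```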
utils/parsing/config.yaml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ processor:
2
+ verbose: True
3
+ output_dir: './output'
4
+ num_processes: 2
5
+ reprocess: False
6
+
7
+ sources:
8
+ local:
9
+ recursive: True
10
+ confluence:
11
+ api_token: 'your_confluence_api_token'
12
+ user_email: 'your_email@example.com'
13
+ url: 'https://your-confluence-url.atlassian.net'
14
+ github:
15
+ url: 'owner/repo'
16
+ branch: 'main'
17
+ google_drive:
18
+ service_account_key: 'path/to/service_account_key.json'
19
+ recursive: True
20
+ drive_id: 'your_drive_id'
21
+
22
+ partitioning:
23
+ skip_infer_table_types: []
24
+ strategy: 'auto'
25
+ hi_res_model_name: 'yolox'
26
+ ocr_languages: ['eng']
27
+ encoding: 'utf-8'
28
+ fields_include: ['element_id', 'text', 'type', 'metadata', 'embeddings']
29
+ flatten_metadata: False
30
+ metadata_exclude: []
31
+ metadata_include: []
32
+ partition_endpoint: 'http://localhost'
33
+ unstructured_port: 8005
34
+ partition_by_api: False # set as true if using API server
35
+ default_unstructured_api_key: 123456789abcde
36
+
37
+ chunking:
38
+ enabled: True
39
+ strategy: 'by_title'
40
+ chunk_max_characters: 1500
41
+ chunk_overlap: 300
42
+ combine_under_n_chars: 1500
43
+
44
+ embedding:
45
+ enabled: False
46
+ provider: 'langchain-huggingface'
47
+ model_name: 'intfloat/e5-large-v2'
48
+
49
+ destination_connectors:
50
+ enabled: False
51
+ type: 'chroma'
52
+ batch_size: 80
53
+ chroma:
54
+ host: 'localhost'
55
+ port: 8004
56
+ collection_name: 'snconf'
57
+ tenant: 'default_tenant'
58
+ database: 'default_database'
59
+ qdrant:
60
+ location: 'http://localhost:6333'
61
+ collection_name: 'test'
62
+
63
+ additional_processing:
64
+ enabled: True
65
+ extend_metadata: True
66
+ replace_table_text: True
67
+ table_text_key: 'text_as_html'
68
+ return_langchain_docs: True
69
+ convert_metadata_keys_to_string: True
utils/parsing/docker-compose.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: '3.9'
2
+
3
+ networks:
4
+ net:
5
+ driver: bridge
6
+
7
+ services:
8
+ unstructured-api:
9
+ image: downloads.unstructured.io/unstructured-io/unstructured-api:latest
10
+ command: --port 8000 --host 0.0.0.0
11
+ ports:
12
+ - "${UNSTRUCTURED_PORT:-8005}:8000"
13
+ env_file:
14
+ - ../../.env
15
+
16
+ networks:
17
+ - net
18
+
19
+ chromadb:
20
+ image: chromadb/chroma:latest
21
+ volumes:
22
+ - ./chromadb:/chroma/chroma
23
+ environment:
24
+ - IS_PERSISTENT=TRUE
25
+ - PERSIST_DIRECTORY=/chroma/chroma
26
+ - ANONYMIZED_TELEMETRY=${ANONYMIZED_TELEMETRY:-TRUE}
27
+ ports:
28
+ - "${CHROMA_PORT:-8004}:8000"
29
+ networks:
30
+ - net
utils/parsing/parse_usage.ipynb ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 15,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "This is the repo dir /Users/kwasia/Documents/Projects/ai-starter-kit\n"
13
+ ]
14
+ }
15
+ ],
16
+ "source": [
17
+ "import os\n",
18
+ "import sys\n",
19
+ "\n",
20
+ "current_dir = os.getcwd()\n",
21
+ "kit_dir = os.path.abspath(os.path.join(current_dir, '..'))\n",
22
+ "repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))\n",
23
+ "\n",
24
+ "sys.path.append(kit_dir)\n",
25
+ "sys.path.append(repo_dir)\n",
26
+ "\n",
27
+ "print(f'This is the repo dir {repo_dir}')"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 16,
33
+ "metadata": {},
34
+ "outputs": [
35
+ {
36
+ "data": {
37
+ "text/plain": [
38
+ "True"
39
+ ]
40
+ },
41
+ "execution_count": 16,
42
+ "metadata": {},
43
+ "output_type": "execute_result"
44
+ }
45
+ ],
46
+ "source": [
47
+ "# Load DotEnv\n",
48
+ "\n",
49
+ "from dotenv import load_dotenv\n",
50
+ "\n",
51
+ "load_dotenv('../../.env')"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 17,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "from utils.parsing.sambaparse import SambaParse"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "markdown",
65
+ "metadata": {},
66
+ "source": [
67
+ "# Use Case 1 - Process a Single File"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 19,
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "name": "stderr",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "2024-06-20 16:15:20,971 - INFO - Deleting contents of output directory: ./output\n"
80
+ ]
81
+ },
82
+ {
83
+ "name": "stderr",
84
+ "output_type": "stream",
85
+ "text": [
86
+ "2024-06-20 16:15:20,995 - INFO - Running command: unstructured-ingest local --output-dir ./output --num-processes 2 --strategy auto --ocr-languages eng --encoding utf-8 --fields-include element_id,text,type,metadata,embeddings --metadata-exclude --metadata-include --pdf-infer-table-structure --input-path \"./test_docs/samba_turbo.pdf\" --recursive --verbose --partition-by-api --api-key EA6ZX3037WEZUV8THwco --partition-endpoint http://localhost:8005 --pdf-infer-table-structure --chunking-strategy basic --chunk-max-characters 1500 --chunk-overlap 300\n",
87
+ "2024-06-20 16:15:20,996 - INFO - This may take some time depending on the size of your data. Please be patient...\n",
88
+ "2024-06-20 16:15:20,996 - INFO - This may take some time depending on the size of your data. Please be patient...\n",
89
+ "/Users/kwasia/.pyenv/versions/sambaparse/lib/python3.10/site-packages/dataclasses_json/core.py:201: RuntimeWarning: 'NoneType' object value of non-optional type additional_partition_args detected when decoding CliPartitionConfig.\n",
90
+ " warnings.warn(\n",
91
+ "2024-06-20 16:15:22,908 MainProcess INFO running pipeline: DocFactory -> Reader -> Partitioner -> Chunker -> Copier with config: {\"reprocess\": false, \"verbose\": true, \"work_dir\": \"/Users/kwasia/.cache/unstructured/ingest/pipeline\", \"output_dir\": \"./output\", \"num_processes\": 2, \"raise_on_error\": false}\n",
92
+ "2024-06-20 16:15:24,658 MainProcess INFO Running doc factory to generate ingest docs. Source connector: {\"processor_config\": {\"reprocess\": false, \"verbose\": true, \"work_dir\": \"/Users/kwasia/.cache/unstructured/ingest/pipeline\", \"output_dir\": \"./output\", \"num_processes\": 2, \"raise_on_error\": false}, \"read_config\": {\"download_dir\": null, \"re_download\": false, \"preserve_downloads\": false, \"download_only\": false, \"max_docs\": null}, \"connector_config\": {\"input_path\": \"./test_docs/samba_turbo.pdf\", \"recursive\": true, \"file_glob\": null}}\n",
93
+ "2024-06-20 16:15:24,661 MainProcess INFO processing 1 docs via 2 processes\n",
94
+ "2024-06-20 16:15:24,661 MainProcess INFO Calling Reader with 1 docs\n",
95
+ "2024-06-20 16:15:24,661 MainProcess INFO Running source node to download data associated with ingest docs\n",
96
+ "2024-06-20 16:15:26,511 SpawnPoolWorker-3 INFO File exists: test_docs/samba_turbo.pdf, skipping download\n",
97
+ "2024-06-20 16:15:26,522 MainProcess INFO Calling Partitioner with 1 docs\n",
98
+ "2024-06-20 16:15:26,523 MainProcess INFO Running partition node to extract content from json files. Config: {\"pdf_infer_table_structure\": true, \"strategy\": \"auto\", \"ocr_languages\": [\"eng\"], \"encoding\": \"utf-8\", \"additional_partition_args\": null, \"skip_infer_table_types\": null, \"fields_include\": [\"element_id\", \"text\", \"type\", \"metadata\", \"embeddings\"], \"flatten_metadata\": false, \"metadata_exclude\": [\"--metadata-include\"], \"metadata_include\": [], \"partition_endpoint\": \"http://localhost:8005\", \"partition_by_api\": true, \"api_key\": \"*******\", \"hi_res_model_name\": null}, partition kwargs: {}]\n",
99
+ "2024-06-20 16:15:26,523 MainProcess INFO Creating /Users/kwasia/.cache/unstructured/ingest/pipeline/partitioned\n",
100
+ "2024-06-20 16:15:28,387 SpawnPoolWorker-4 INFO Processing test_docs/samba_turbo.pdf\n",
101
+ "2024-06-20 16:15:29,836 SpawnPoolWorker-4 DEBUG Using remote partition (http://localhost:8005)\n",
102
+ "2024-06-20 16:15:40,244 SpawnPoolWorker-4 INFO writing partitioned content to /Users/kwasia/.cache/unstructured/ingest/pipeline/partitioned/eb87c25354d57b8c7434994ca9c3f796.json\n",
103
+ "2024-06-20 16:15:40,254 MainProcess INFO Calling Chunker with 1 docs\n",
104
+ "2024-06-20 16:15:40,255 MainProcess INFO Running chunking node. Chunking config: {\"chunking_strategy\": \"basic\", \"combine_text_under_n_chars\": null, \"include_orig_elements\": true, \"max_characters\": 1500, \"multipage_sections\": true, \"new_after_n_chars\": null, \"overlap\": 300, \"overlap_all\": false}]\n",
105
+ "2024-06-20 16:15:40,255 MainProcess INFO Creating /Users/kwasia/.cache/unstructured/ingest/pipeline/chunked\n",
106
+ "2024-06-20 16:15:42,318 SpawnPoolWorker-6 INFO writing chunking content to /Users/kwasia/.cache/unstructured/ingest/pipeline/chunked/df2636b5a36c11e91958dfd7ae81ddb1.json\n",
107
+ "2024-06-20 16:15:42,323 MainProcess INFO Calling Copier with 1 docs\n",
108
+ "2024-06-20 16:15:42,323 MainProcess INFO Running copy node to move content to desired output location\n",
109
+ "2024-06-20 16:15:44,114 SpawnPoolWorker-9 INFO Copying /Users/kwasia/.cache/unstructured/ingest/pipeline/chunked/df2636b5a36c11e91958dfd7ae81ddb1.json -> output/samba_turbo.pdf.json\n",
110
+ "2024-06-20 16:15:44,320 - INFO - Ingest process completed successfully!\n",
111
+ "2024-06-20 16:15:44,321 - INFO - Performing additional processing...\n",
112
+ "2024-06-20 16:15:44,324 - INFO - Additional processing completed.\n"
113
+ ]
114
+ }
115
+ ],
116
+ "source": [
117
+ "config_yaml = './config.yaml'\n",
118
+ "sambaparse = SambaParse(config_yaml)\n",
119
+ "\n",
120
+ "source_type = 'local'\n",
121
+ "input_path = './test_docs/samba_turbo.pdf'\n",
122
+ "additional_metadata = {'key': 'value'}\n",
123
+ "\n",
124
+ "texts, metadata_list, langchain_docs = sambaparse.run_ingest(\n",
125
+ " source_type, input_path=input_path, additional_metadata=additional_metadata\n",
126
+ ")"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 20,
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "name": "stdout",
136
+ "output_type": "stream",
137
+ "text": [
138
+ "This is the length of the lanchain docs 5\n",
139
+ "This is an example langcahin doc \n",
140
+ "\n",
141
+ " page_content=\"6/20/24, 3:23 PM\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nG\\\\SambaNovar\\n\\nEN\\n\\nBACK TO RESOURCES\\n\\n<\\n\\nPREVIOUS | NEXT\\n\\n>\\n\\nMay 29, 2024\\n\\njn\\n\\nNX\\n\\nfF\\n\\nBS\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nSambaNova is the clear winner of the latest large language model LLM benchmark by Artificial Analysis. Topping the Leaderboad at over 1000 tokens per second (t/s), Samba-1 Turbo sets a new record for Llama 3 8B performance on a single SN40L node and with full precision.\\n\\nWith speeds like this, enterprises can expect to accelerate an array of use cases and will enable innovation around unblocking agentic workflow, copilot, and synthetic data, to name a few. This breakthrough in AI technology is possible because the purpose-built SambaNova SN40L Reconfigurable Dataflow Unit RDU can hold hundreds of models at the same time and can switch between them in microseconds.\\n\\nSpeed for today and tomorrow\" metadata={'filename': 'samba_turbo.pdf', 'filetype': 'application/pdf', 'languages': 'eng', 'page_number': '1', 'orig_elements': 'eJzVl21v2zYQx7/KwW+2AV7DJ1FUMQxI22wrlqZFHrYCbVHw4WhzkSVBkut63b77jvYejCJF7BdDkleCyBN597v/Hak3nyZY4wKb8X0Kk8cwKWzl0TimbZBSVEWprZZRSx20YbJwkylMFjjaYEdL9p8mMdXY2AXmjwe7cPb9uOxd+6gLMdvm6XHdbaZt19XJ2zG1zdHf07VtZks7w4Hm30ywmU3e0WhHI++b5cJhT+P8Txoa8eOY19BHgh0JNQX5WEh49SIv8s/6P6EN9AWZfx5VjIWpKvRWRFlZ6dEJxr3SElEFr8xdR3WRtzhrP1iY2wFc315jA+McgTPGYDyiMdv3CfvHsJqvIY1fDWDBpRkEtDXEtgcKFvuuTwPC8fP9qJQlE8YV0ToeXamKqiAuKI3RJTNFFe6CymakPyBzuxh/fPv2X5L9LoPLNNZ4EwIUpQqxRKYL4RwtaQxHdMYrFSup7lzuu0E8X5DdTUGEGENRVbIqfCU1s0wFGazTVMGFJnZ3HcTJ2X56VFgq6jWoKMOlM0FykmMMEjk6LoW4D3q8lfVu4E+On/4Mly/h/OTi5dX505OLvTQZjUWunZeclwKDFJ4JpqMzUflKlXgfMNyaql0M3+2GfdXQ9jhr+/Q7hstscQMCxlUUUikUSjkjnePCCGFtYJE77x+eEl6dn/zy/OXVBfwBZyevL/fSQSWo+3tqzRELaniyUpoTFKkq5JWLd9KbPodwa6J2IXx/sA5Ky6pYeMZZITPlUAmnOTFBI6lC2IND8MKuQVRTEEyoXRpndLyTKx/wSyQsL5UpiAcLwUWLRisTSsmCcVYUgd8HEgdVxG/NXkVQBCtj0JJjLA1TvhROBaF9QR3IBV48uLjPXu8VNzeVQDR05SF5Oead08JjENFLr6R4ePmOP+wVty+rUvsouLHcM7qHlbzAwAXngU4Zey8Ov8PuAPsd+nQniq5iOieawi4dD5F5Q1WvqkKZ+ODi/j9/ab58dQqs0oXWBT0tp0opQwx0I6GyQS+re3Fe3JrpmymmYUPP12h7WKWmwR7auBmr6SAdRnr0s/yy9QkWbcAaTk9fgMPGzxe2vwa3huN+TDH5RIyPG1uvhzQ8gsu261Iz26x2urmfu9YGsCO0H2ifbcpyBgfo6H1A3zYBvqYsfjOFjY/fcrjM5GhuzIlscAU9mfVhk8rT2i4sSDBP8gI0srCNR2gbMh1o5xrh4kyxU2jIa7C0+CqNc4jLuoaO1kkDYX900ImpSAUycG1N6Zig25NkvIraMvqLlE7ci4o6SAu/ZiJDhxgGqNM1UrbSMN2pkgG8bQA/duhHShdY7wkJUcpEIfNaZ8ksqZ68zeZbzoSY4naUAlIVSS1HRdbtkmaXjatbf521QZE0Y/KwavvrWLerKfi2S3U7TjfrDOuG1JMNMtBp3j/DpPxGXJHCyFfqAmivxzktPZvTZlTVMKKfN23dztZZ4V07DCl74uiPP/uZBdktexrHb90y1SP8VxFbwZxnLcY0W/abEJ7R5tk7uGrSCOfPrjZM5m0dYE4B9RkeMdgUx5AFnncYsqNjWmyVlz8YSH5+Tm6MK9z2rUV2eJF8327VPxymRsOYDFHooL1WwmjDCxc8liW6KnJu74MaD+vvWYib2h7bQMLK5MZ20fZ9u7qhV7/7C2wCbXA=', 'key': 'value', 'type': 'CompositeElement', 'element_id': '34922f62e3c3e7600d32eb0627b79202', 'page': '1'}\n"
142
+ ]
143
+ }
144
+ ],
145
+ "source": [
146
+ "# Inspect the Output\n",
147
+ "\n",
148
+ "# 1. Number of Chunks\n",
149
+ "print(f'This is the length of the lanchain docs {len(langchain_docs)}')\n",
150
+ "\n",
151
+ "# 2. Example Chunk\n",
152
+ "print(f'This is an example langcahin doc \\n\\n {langchain_docs[0]}')"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "markdown",
157
+ "metadata": {},
158
+ "source": [
159
+ "# Use Case 2 - Process Whole Directory "
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "config_yaml = './config.yaml'\n",
169
+ "sambaparse = SambaParse(config_yaml)\n",
170
+ "\n",
171
+ "source_type = 'local'\n",
172
+ "input_path = './test_docs'\n",
173
+ "additional_metadata = {'key': 'value'}\n",
174
+ "\n",
175
+ "texts, metadata_list, langchain_docs = sambaparse.run_ingest(\n",
176
+ " source_type, input_path=input_path, additional_metadata=additional_metadata\n",
177
+ ")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": 22,
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "name": "stdout",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "This is the length of the lanchain docs 44\n",
190
+ "This is an example langcahin doc \n",
191
+ "\n",
192
+ " page_content=\"6/20/24, 3:23 PM\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nG\\\\SambaNovar\\n\\nEN\\n\\nBACK TO RESOURCES\\n\\n<\\n\\nPREVIOUS | NEXT\\n\\n>\\n\\nMay 29, 2024\\n\\njn\\n\\nNX\\n\\nfF\\n\\nBS\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nSambaNova is the clear winner of the latest large language model LLM benchmark by Artificial Analysis. Topping the Leaderboad at over 1000 tokens per second (t/s), Samba-1 Turbo sets a new record for Llama 3 8B performance on a single SN40L node and with full precision.\\n\\nWith speeds like this, enterprises can expect to accelerate an array of use cases and will enable innovation around unblocking agentic workflow, copilot, and synthetic data, to name a few. This breakthrough in AI technology is possible because the purpose-built SambaNova SN40L Reconfigurable Dataflow Unit RDU can hold hundreds of models at the same time and can switch between them in microseconds.\\n\\nSpeed for today and tomorrow\" metadata={'filename': 'samba_turbo.pdf', 'filetype': 'application/pdf', 'languages': 'eng', 'page_number': '1', 'orig_elements': 'eJzVl21v2zYQx7/KwW+2AV7DJ1FUMQxI22wrlqZFHrYCbVHw4WhzkSVBkut63b77jvYejCJF7BdDkleCyBN597v/Hak3nyZY4wKb8X0Kk8cwKWzl0TimbZBSVEWprZZRSx20YbJwkylMFjjaYEdL9p8mMdXY2AXmjwe7cPb9uOxd+6gLMdvm6XHdbaZt19XJ2zG1zdHf07VtZks7w4Hm30ywmU3e0WhHI++b5cJhT+P8Txoa8eOY19BHgh0JNQX5WEh49SIv8s/6P6EN9AWZfx5VjIWpKvRWRFlZ6dEJxr3SElEFr8xdR3WRtzhrP1iY2wFc315jA+McgTPGYDyiMdv3CfvHsJqvIY1fDWDBpRkEtDXEtgcKFvuuTwPC8fP9qJQlE8YV0ToeXamKqiAuKI3RJTNFFe6CymakPyBzuxh/fPv2X5L9LoPLNNZ4EwIUpQqxRKYL4RwtaQxHdMYrFSup7lzuu0E8X5DdTUGEGENRVbIqfCU1s0wFGazTVMGFJnZ3HcTJ2X56VFgq6jWoKMOlM0FykmMMEjk6LoW4D3q8lfVu4E+On/4Mly/h/OTi5dX505OLvTQZjUWunZeclwKDFJ4JpqMzUflKlXgfMNyaql0M3+2GfdXQ9jhr+/Q7hstscQMCxlUUUikUSjkjnePCCGFtYJE77x+eEl6dn/zy/OXVBfwBZyevL/fSQSWo+3tqzRELaniyUpoTFKkq5JWLd9KbPodwa6J2IXx/sA5Ky6pYeMZZITPlUAmnOTFBI6lC2IND8MKuQVRTEEyoXRpndLyTKx/wSyQsL5UpiAcLwUWLRisTSsmCcVYUgd8HEgdVxG/NXkVQBCtj0JJjLA1TvhROBaF9QR3IBV48uLjPXu8VNzeVQDR05SF5Oead08JjENFLr6R4ePmOP+wVty+rUvsouLHcM7qHlbzAwAXngU4Zey8Ov8PuAPsd+nQniq5iOieawi4dD5F5Q1WvqkKZ+ODi/j9/ab58dQqs0oXWBT0tp0opQwx0I6GyQS+re3Fe3JrpmymmYUPP12h7WKWmwR7auBmr6SAdRnr0s/yy9QkWbcAaTk9fgMPGzxe2vwa3huN+TDH5RIyPG1uvhzQ8gsu261Iz26x2urmfu9YGsCO0H2ifbcpyBgfo6H1A3zYBvqYsfjOFjY/fcrjM5GhuzIlscAU9mfVhk8rT2i4sSDBP8gI0srCNR2gbMh1o5xrh4kyxU2jIa7C0+CqNc4jLuoaO1kkDYX900ImpSAUycG1N6Zig25NkvIraMvqLlE7ci4o6SAu/ZiJDhxgGqNM1UrbSMN2pkgG8bQA/duhHShdY7wkJUcpEIfNaZ8ksqZ68zeZbzoSY4naUAlIVSS1HRdbtkmaXjatbf521QZE0Y/KwavvrWLerKfi2S3U7TjfrDOuG1JMNMtBp3j/DpPxGXJHCyFfqAmivxzktPZvTZlTVMKKfN23dztZZ4V07DCl74uiPP/uZBdktexrHb90y1SP8VxFbwZxnLcY0W/abEJ7R5tk7uGrSCOfPrjZM5m0dYE4B9RkeMdgUx5AFnncYsqNjWmyVlz8YSH5+Tm6MK9z2rUV2eJF8327VPxymRsOYDFHooL1WwmjDCxc8liW6KnJu74MaD+vvWYib2h7bQMLK5MZ20fZ9u7qhV7/7C2wCbXA=', 'key': 'value', 'type': 'CompositeElement', 'element_id': '34922f62e3c3e7600d32eb0627b79202', 'page': '1'}\n"
193
+ ]
194
+ }
195
+ ],
196
+ "source": [
197
+ "# Inspect the Output\n",
198
+ "\n",
199
+ "# 1. Number of Chunks\n",
200
+ "print(f'This is the length of the lanchain docs {len(langchain_docs)}')\n",
201
+ "\n",
202
+ "# 2. Example Chunk\n",
203
+ "print(f'This is an example langcahin doc \\n\\n {langchain_docs[0]}')"
204
+ ]
205
+ }
206
+ ],
207
+ "metadata": {
208
+ "kernelspec": {
209
+ "display_name": "aisk-fine-tune-embeddings",
210
+ "language": "python",
211
+ "name": "python3"
212
+ },
213
+ "language_info": {
214
+ "codemirror_mode": {
215
+ "name": "ipython",
216
+ "version": 3
217
+ },
218
+ "file_extension": ".py",
219
+ "mimetype": "text/x-python",
220
+ "name": "python",
221
+ "nbconvert_exporter": "python",
222
+ "pygments_lexer": "ipython3",
223
+ "version": "3.10.12"
224
+ }
225
+ },
226
+ "nbformat": 4,
227
+ "nbformat_minor": 2
228
+ }
utils/parsing/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ unstructured==0.13.6
2
+ unstructured-client==0.18.0
3
+ unstructured-inference==0.7.29
4
+ langchain==0.1.16
5
+ PyMuPDF==1.23.4
6
+ PyMuPDFb==1.23.3
utils/parsing/sambaparse.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import subprocess
4
+ import json
5
+ import logging
6
+ from typing import Dict, Optional, List, Tuple, Union, Any
7
+ from dotenv import load_dotenv
8
+ from langchain.docstore.document import Document
9
+ import shutil
10
+ from langchain_community.document_loaders import PyMuPDFLoader
11
+
12
+ load_dotenv()
13
+
14
+ logging.basicConfig(
15
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class SambaParse:
21
+ def __init__(self, config_path: str):
22
+ with open(config_path, "r") as file:
23
+ self.config = yaml.safe_load(file)
24
+
25
+ # Set the default Unstructured API key as an environment variable if not already set
26
+ if "UNSTRUCTURED_API_KEY" not in os.environ:
27
+ default_api_key = self.config.get("partitioning", {}).get("default_unstructured_api_key")
28
+ if default_api_key:
29
+ os.environ["UNSTRUCTURED_API_KEY"] = default_api_key
30
+
31
+
32
+ def run_ingest(
33
+ self,
34
+ source_type: str,
35
+ input_path: Optional[str] = None,
36
+ additional_metadata: Optional[Dict] = None,
37
+ ) -> Tuple[List[str], List[Dict], List[Document]]:
38
+ """
39
+ Runs the ingest process for the specified source type and input path.
40
+
41
+ Args:
42
+ source_type (str): The type of source to ingest (e.g., 'local', 'confluence', 'github', 'google-drive').
43
+ input_path (Optional[str]): The input path for the source (only required for 'local' source type).
44
+ additional_metadata (Optional[Dict]): Additional metadata to include in the processed documents.
45
+
46
+ Returns:
47
+ Tuple[List[str], List[Dict], List[Document]]: A tuple containing the extracted texts, metadata, and LangChain documents.
48
+ """
49
+ if not self.config["partitioning"]["partition_by_api"]:
50
+ return self._run_ingest_pymupdf(input_path, additional_metadata)
51
+
52
+ output_dir = self.config["processor"]["output_dir"]
53
+
54
+ # Create the output directory if it doesn't exist
55
+ os.makedirs(output_dir, exist_ok=True)
56
+
57
+ # Delete contents of the output directory using shell command
58
+ del_command = f"rm -rf {output_dir}/*"
59
+ logger.info(f"Deleting contents of output directory: {output_dir}")
60
+ subprocess.run(del_command, shell=True, check=True)
61
+
62
+ command = [
63
+ "unstructured-ingest",
64
+ source_type,
65
+ "--output-dir",
66
+ output_dir,
67
+ "--num-processes",
68
+ str(self.config["processor"]["num_processes"]),
69
+ ]
70
+
71
+ if self.config["processor"]["reprocess"] == True:
72
+ command.extend(["--reprocess"])
73
+
74
+ # Add partition arguments
75
+ command.extend(
76
+ [
77
+ "--strategy",
78
+ self.config["partitioning"]["strategy"],
79
+ "--ocr-languages",
80
+ ",".join(self.config["partitioning"]["ocr_languages"]),
81
+ "--encoding",
82
+ self.config["partitioning"]["encoding"],
83
+ "--fields-include",
84
+ ",".join(self.config["partitioning"]["fields_include"]),
85
+ "--metadata-exclude",
86
+ ",".join(self.config["partitioning"]["metadata_exclude"]),
87
+ "--metadata-include",
88
+ ",".join(self.config["partitioning"]["metadata_include"]),
89
+ ]
90
+ )
91
+
92
+ if self.config["partitioning"]["skip_infer_table_types"]:
93
+ command.extend(
94
+ [
95
+ "--skip-infer-table-types",
96
+ ",".join(self.config["partitioning"]["skip_infer_table_types"]),
97
+ ]
98
+ )
99
+
100
+ if self.config["partitioning"]["flatten_metadata"]:
101
+ command.append("--flatten-metadata")
102
+
103
+ if source_type == "local":
104
+ if input_path is None:
105
+ raise ValueError("Input path is required for local source type.")
106
+ command.extend(["--input-path", f'"{input_path}"'])
107
+
108
+ if self.config["sources"]["local"]["recursive"]:
109
+ command.append("--recursive")
110
+ elif source_type == "confluence":
111
+ command.extend(
112
+ [
113
+ "--url",
114
+ self.config["sources"]["confluence"]["url"],
115
+ "--user-email",
116
+ self.config["sources"]["confluence"]["user_email"],
117
+ "--api-token",
118
+ self.config["sources"]["confluence"]["api_token"],
119
+ ]
120
+ )
121
+ elif source_type == "github":
122
+ command.extend(
123
+ [
124
+ "--url",
125
+ self.config["sources"]["github"]["url"],
126
+ "--git-branch",
127
+ self.config["sources"]["github"]["branch"],
128
+ ]
129
+ )
130
+ elif source_type == "google-drive":
131
+ command.extend(
132
+ [
133
+ "--drive-id",
134
+ self.config["sources"]["google_drive"]["drive_id"],
135
+ "--service-account-key",
136
+ self.config["sources"]["google_drive"]["service_account_key"],
137
+ ]
138
+ )
139
+ if self.config["sources"]["google_drive"]["recursive"]:
140
+ command.append("--recursive")
141
+ else:
142
+ raise ValueError(f"Unsupported source type: {source_type}")
143
+
144
+ if self.config["processor"]["verbose"]:
145
+ command.append("--verbose")
146
+
147
+ if self.config["partitioning"]["partition_by_api"]:
148
+ api_key = os.getenv("UNSTRUCTURED_API_KEY")
149
+ partition_endpoint_url = f"{self.config['partitioning']['partition_endpoint']}:{self.config['partitioning']['unstructured_port']}"
150
+ if api_key:
151
+ command.extend(["--partition-by-api", "--api-key", api_key])
152
+ command.extend(["--partition-endpoint", partition_endpoint_url])
153
+ else:
154
+ logger.warning("No Unstructured API key available. Partitioning by API will be skipped.")
155
+
156
+ if self.config["partitioning"]["strategy"] == "hi_res":
157
+ if (
158
+ "hi_res_model_name" in self.config["partitioning"]
159
+ and self.config["partitioning"]["hi_res_model_name"]
160
+ ):
161
+ command.extend(
162
+ [
163
+ "--hi-res-model-name",
164
+ self.config["partitioning"]["hi_res_model_name"],
165
+ ]
166
+ )
167
+ logger.warning(
168
+ "You've chosen the high-resolution partitioning strategy. Grab a cup of coffee or tea while you wait, as this may take some time due to OCR and table detection."
169
+ )
170
+
171
+ if self.config["chunking"]["enabled"]:
172
+ command.extend(
173
+ [
174
+ "--chunking-strategy",
175
+ self.config["chunking"]["strategy"],
176
+ "--chunk-max-characters",
177
+ str(self.config["chunking"]["chunk_max_characters"]),
178
+ "--chunk-overlap",
179
+ str(self.config["chunking"]["chunk_overlap"]),
180
+ ]
181
+ )
182
+
183
+ if self.config["chunking"]["strategy"] == "by_title":
184
+ command.extend(
185
+ [
186
+ "--chunk-combine-text-under-n-chars",
187
+ str(self.config["chunking"]["combine_under_n_chars"]),
188
+ ]
189
+ )
190
+
191
+ if self.config["embedding"]["enabled"]:
192
+ command.extend(
193
+ [
194
+ "--embedding-provider",
195
+ self.config["embedding"]["provider"],
196
+ "--embedding-model-name",
197
+ self.config["embedding"]["model_name"],
198
+ ]
199
+ )
200
+
201
+ if self.config["destination_connectors"]["enabled"]:
202
+ destination_type = self.config["destination_connectors"]["type"]
203
+ if destination_type == "chroma":
204
+ command.extend(
205
+ [
206
+ "chroma",
207
+ "--host",
208
+ self.config["destination_connectors"]["chroma"]["host"],
209
+ "--port",
210
+ str(self.config["destination_connectors"]["chroma"]["port"]),
211
+ "--collection-name",
212
+ self.config["destination_connectors"]["chroma"][
213
+ "collection_name"
214
+ ],
215
+ "--tenant",
216
+ self.config["destination_connectors"]["chroma"]["tenant"],
217
+ "--database",
218
+ self.config["destination_connectors"]["chroma"]["database"],
219
+ "--batch-size",
220
+ str(self.config["destination_connectors"]["batch_size"]),
221
+ ]
222
+ )
223
+ elif destination_type == "qdrant":
224
+ command.extend(
225
+ [
226
+ "qdrant",
227
+ "--location",
228
+ self.config["destination_connectors"]["qdrant"]["location"],
229
+ "--collection-name",
230
+ self.config["destination_connectors"]["qdrant"][
231
+ "collection_name"
232
+ ],
233
+ "--batch-size",
234
+ str(self.config["destination_connectors"]["batch_size"]),
235
+ ]
236
+ )
237
+ else:
238
+ raise ValueError(
239
+ f"Unsupported destination connector type: {destination_type}"
240
+ )
241
+
242
+ command_str = " ".join(command)
243
+ logger.info(f"Running command: {command_str}")
244
+ logger.info(
245
+ "This may take some time depending on the size of your data. Please be patient..."
246
+ )
247
+
248
+ subprocess.run(command_str, shell=True, check=True)
249
+
250
+ logger.info("Ingest process completed successfully!")
251
+
252
+ # Call the additional processing function if enabled
253
+ if self.config["additional_processing"]["enabled"]:
254
+ logger.info("Performing additional processing...")
255
+ texts, metadata_list, langchain_docs = additional_processing(
256
+ directory=output_dir,
257
+ extend_metadata=self.config["additional_processing"]["extend_metadata"],
258
+ additional_metadata=additional_metadata,
259
+ replace_table_text=self.config["additional_processing"][
260
+ "replace_table_text"
261
+ ],
262
+ table_text_key=self.config["additional_processing"]["table_text_key"],
263
+ return_langchain_docs=self.config["additional_processing"][
264
+ "return_langchain_docs"
265
+ ],
266
+ convert_metadata_keys_to_string=self.config["additional_processing"][
267
+ "convert_metadata_keys_to_string"
268
+ ],
269
+ )
270
+ logger.info("Additional processing completed.")
271
+ return texts, metadata_list, langchain_docs
272
+
273
+ def _run_ingest_pymupdf(
274
+ self, input_path: str, additional_metadata: Optional[Dict] = None
275
+ ) -> Tuple[List[str], List[Dict], List[Document]]:
276
+ """
277
+ Runs the ingest process using PyMuPDF via LangChain.
278
+
279
+ Args:
280
+ input_path (str): The input path for the source.
281
+ additional_metadata (Optional[Dict]): Additional metadata to include in the processed documents.
282
+
283
+ Returns:
284
+ Tuple[List[str], List[Dict], List[Document]]: A tuple containing the extracted texts, metadata, and LangChain documents.
285
+ """
286
+ if not input_path:
287
+ raise ValueError("Input path is required for PyMuPDF processing.")
288
+
289
+ texts = []
290
+ metadata_list = []
291
+ langchain_docs = []
292
+
293
+ if os.path.isfile(input_path):
294
+ file_paths = [input_path]
295
+ else:
296
+ file_paths = [
297
+ os.path.join(input_path, f)
298
+ for f in os.listdir(input_path)
299
+ if f.lower().endswith('.pdf')
300
+ ]
301
+
302
+ for file_path in file_paths:
303
+ loader = PyMuPDFLoader(file_path)
304
+ docs = loader.load()
305
+
306
+ for doc in docs:
307
+ text = doc.page_content
308
+ metadata = doc.metadata
309
+
310
+ # Add 'filename' key to metadata
311
+ metadata['filename'] = os.path.basename(metadata['source'])
312
+
313
+ if additional_metadata:
314
+ metadata.update(additional_metadata)
315
+
316
+ texts.append(text)
317
+ metadata_list.append(metadata)
318
+ langchain_docs.append(doc)
319
+
320
+ return texts, metadata_list, langchain_docs
321
+
322
+
323
+ def convert_to_string(value: Union[List, Tuple, Dict, Any]) -> str:
324
+ """
325
+ Convert a value to its string representation.
326
+
327
+ Args:
328
+ value (Union[List, Tuple, Dict, Any]): The value to be converted to a string.
329
+
330
+ Returns:
331
+ str: The string representation of the value.
332
+ """
333
+ if isinstance(value, (list, tuple)):
334
+ return ", ".join(map(str, value))
335
+ elif isinstance(value, dict):
336
+ return json.dumps(value)
337
+ else:
338
+ return str(value)
339
+
340
+
341
+ def additional_processing(
342
+ directory: str,
343
+ extend_metadata: bool,
344
+ additional_metadata: Optional[Dict],
345
+ replace_table_text: bool,
346
+ table_text_key: str,
347
+ return_langchain_docs: bool,
348
+ convert_metadata_keys_to_string: bool,
349
+ ):
350
+ """
351
+ Performs additional processing on the extracted documents.
352
+
353
+ Args:
354
+ directory (str): The directory containing the extracted JSON files.
355
+ extend_metadata (bool): Whether to extend the metadata with additional metadata.
356
+ additional_metadata (Optional[Dict]): Additional metadata to include in the processed documents.
357
+ replace_table_text (bool): Whether to replace table text with the specified table text key.
358
+ table_text_key (str): The key to use for replacing table text.
359
+ return_langchain_docs (bool): Whether to return LangChain documents.
360
+ convert_metadata_keys_to_string (bool): Whether to convert non-string metadata keys to string.
361
+
362
+ Returns:
363
+ Tuple[List[str], List[Dict], List[Document]]: A tuple containing the extracted texts, metadata, and LangChain documents.
364
+ """
365
+ if os.path.isfile(directory):
366
+ file_paths = [directory]
367
+ else:
368
+ file_paths = [
369
+ os.path.join(directory, f)
370
+ for f in os.listdir(directory)
371
+ if f.endswith(".json")
372
+ ]
373
+
374
+ texts = []
375
+ metadata_list = []
376
+ langchain_docs = []
377
+
378
+ for file_path in file_paths:
379
+ with open(file_path, "r") as file:
380
+ data = json.load(file)
381
+
382
+ for element in data:
383
+ if extend_metadata and additional_metadata:
384
+ element["metadata"].update(additional_metadata)
385
+
386
+ if replace_table_text and element["type"] == "Table":
387
+ element["text"] = element["metadata"][table_text_key]
388
+
389
+ metadata = element["metadata"].copy()
390
+ if convert_metadata_keys_to_string:
391
+ metadata = {
392
+ str(key): convert_to_string(value)
393
+ for key, value in metadata.items()
394
+ }
395
+ for key in element:
396
+ if key not in ["text", "metadata", "embeddings"]:
397
+ metadata[key] = element[key]
398
+ if "page_number" in metadata:
399
+ metadata["page"] = metadata["page_number"]
400
+ else:
401
+ metadata["page"] = 1
402
+
403
+ metadata_list.append(metadata)
404
+ texts.append(element["text"])
405
+
406
+ if return_langchain_docs:
407
+ langchain_docs.extend(get_langchain_docs(texts, metadata_list))
408
+
409
+ with open(file_path, "w") as file:
410
+ json.dump(data, file, indent=2)
411
+
412
+ return texts, metadata_list, langchain_docs
413
+
414
+
415
+ def get_langchain_docs(texts: List[str], metadata_list: List[Dict]) -> List[Document]:
416
+ """
417
+ Creates LangChain documents from the extracted texts and metadata.
418
+
419
+ Args:
420
+ texts (List[str]): The extracted texts.
421
+ metadata_list (List[Dict]): The metadata associated with each text.
422
+
423
+ Returns:
424
+ List[Document]: A list of LangChain documents.
425
+ """
426
+ return [
427
+ Document(page_content=content, metadata=metadata)
428
+ for content, metadata in zip(texts, metadata_list)
429
+ ]
430
+
431
+
432
+ def parse_doc_universal(
433
+ doc: str, additional_metadata: Optional[Dict] = None, source_type: str = "local"
434
+ ) -> Tuple[List[str], List[Dict], List[Document]]:
435
+ """
436
+ Extract text, tables, images, and metadata from a document or a folder of documents.
437
+
438
+ Args:
439
+ doc (str): Path to the document or folder of documents.
440
+ additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
441
+ Defaults to an empty dictionary.
442
+ source_type (str, optional): The type of source to ingest. Defaults to 'local'.
443
+
444
+ Returns:
445
+ Tuple[List[str], List[Dict], List[Document]]: A tuple containing:
446
+ - A list of extracted text per page.
447
+ - A list of extracted metadata per page.
448
+ - A list of LangChain documents.
449
+ """
450
+ if additional_metadata is None:
451
+ additional_metadata = {}
452
+
453
+ # Get the directory of the current file
454
+ current_dir = os.path.dirname(os.path.abspath(__file__))
455
+
456
+ # Join the current directory with the relative path of the config file
457
+ config_path = os.path.join(current_dir, "config.yaml")
458
+
459
+ wrapper = SambaParse(config_path)
460
+
461
+ def process_file(file_path):
462
+ if file_path.lower().endswith('.pdf'):
463
+ return wrapper._run_ingest_pymupdf(file_path, additional_metadata)
464
+ else:
465
+ # Use the original method for non-PDF files
466
+ return wrapper.run_ingest(source_type, input_path=file_path, additional_metadata=additional_metadata)
467
+
468
+ if os.path.isfile(doc):
469
+ return process_file(doc)
470
+ else:
471
+ all_texts, all_metadata, all_docs = [], [], []
472
+ for root, _, files in os.walk(doc):
473
+ for file in files:
474
+ file_path = os.path.join(root, file)
475
+ texts, metadata_list, langchain_docs = process_file(file_path)
476
+ all_texts.extend(texts)
477
+ all_metadata.extend(metadata_list)
478
+ all_docs.extend(langchain_docs)
479
+ return all_texts, all_metadata, all_docs
480
+
481
+
482
+ def parse_doc_streamlit(docs: List,
483
+ kit_dir: str,
484
+ additional_metadata: Optional[Dict] = None,
485
+ ) -> List[Document]:
486
+ """
487
+ Parse the uploaded documents and return a list of LangChain documents.
488
+
489
+ Args:
490
+ docs (List[UploadFile]): A list of uploaded files.
491
+ kit_dir (str): The directory of the current kit.
492
+ additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
493
+ Defaults to an empty dictionary.
494
+
495
+ Returns:
496
+ List[Document]: A list of LangChain documents.
497
+ """
498
+ if additional_metadata is None:
499
+ additional_metadata = {}
500
+
501
+ # Create the data/tmp folder if it doesn't exist
502
+ temp_folder = os.path.join(kit_dir, "data/tmp")
503
+ if not os.path.exists(temp_folder):
504
+ os.makedirs(temp_folder)
505
+ else:
506
+ # If there are already files there, delete them
507
+ for filename in os.listdir(temp_folder):
508
+ file_path = os.path.join(temp_folder, filename)
509
+ try:
510
+ if os.path.isfile(file_path) or os.path.islink(file_path):
511
+ os.unlink(file_path)
512
+ elif os.path.isdir(file_path):
513
+ shutil.rmtree(file_path)
514
+ except Exception as e:
515
+ print(f'Failed to delete {file_path}. Reason: {e}')
516
+
517
+ # Save all selected files to the tmp dir with their file names
518
+ for doc in docs:
519
+ temp_file = os.path.join(temp_folder, doc.name)
520
+ with open(temp_file, "wb") as f:
521
+ f.write(doc.getvalue())
522
+
523
+ # Pass in the temp folder for processing into the parse_doc_universal function
524
+ _, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
525
+ return langchain_docs
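For orientation, here is a minimal usage sketch of parse_doc_universal as defined above. It assumes the module is importable as utils.parsing.sambaparse and that a valid config.yaml sits next to it; the file path and metadata values below are illustrative, not part of the kit.

    from utils.parsing.sambaparse import parse_doc_universal

    # Parse a single document (or a folder); returns texts, metadata, and LangChain docs.
    texts, metadata, docs = parse_doc_universal(
        doc="data/sample.pdf",                    # illustrative path
        additional_metadata={"source": "demo"},   # merged into each element's metadata
    )
    print(f"{len(texts)} text elements extracted")
    if docs:
        print(docs[0].metadata)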
utils/vectordb/create_vector_db.py ADDED
@@ -0,0 +1,141 @@
1
+ # Define the script's usage example
2
+ USAGE_EXAMPLE = """
3
+ Example usage:
4
+
5
+ To process input *.txt files at input_path and save the vector db output at output_db:
6
+ python create_vector_db.py input_path output_db --chunk_size 100 --chunk_overlap 10
7
+
8
+ Required arguments:
9
+ - input_path: Path to the input dir containing the .txt files
10
+ - output_db: Path to the output vector db.
11
+
12
+ Optional arguments:
13
+ - --chunk_size: Size of the chunks (default: 1000).
14
+ - --chunk_overlap: Overlap between chunks (default: 200).
15
+ """
16
+
17
+ import argparse
18
+ import logging
19
+ import os
20
+
21
+ from langchain.document_loaders import DirectoryLoader
22
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
23
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
24
+ from langchain.vectorstores import FAISS, Chroma, Qdrant
25
+
26
+ # Configure the logger
27
+ logging.basicConfig(
28
+ level=logging.INFO, # Set the logging level (e.g., INFO, DEBUG)
29
+ format="%(asctime)s [%(levelname)s] - %(message)s", # Define the log message format
30
+ handlers=[
31
+ logging.StreamHandler(), # Output logs to the console
32
+ logging.FileHandler("create_vector_db.log"),
33
+ ],
34
+ )
35
+
36
+ # Create a logger object
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ # Parse the arguments
41
+ def parse_arguments():
42
+ parser = argparse.ArgumentParser(description="Process command line arguments.")
43
+ parser.add_argument("-input_path", type=dir_path, help="path to input directory")
44
+ parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
45
+ parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
46
+ parser.add_argument("-output_path", type=dir_path, help="path to input directory")
47
+
48
+ return parser.parse_args()
49
+
50
+
51
+ # Check valid path
52
+ def dir_path(path):
53
+ if os.path.isdir(path):
54
+ return path
55
+ else:
56
+ raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
57
+
58
+
59
+ def main(input_path, output_db, chunk_size, chunk_overlap, db_type):
60
+ # Load files from input_location
61
+ loader = DirectoryLoader(input_path, glob="*.txt")
62
+ docs = loader.load()
63
+ logger.info(f"Total {len(docs)} files loaded")
64
+
65
+ # get the text chunks
66
+ text_splitter = RecursiveCharacterTextSplitter(
67
+ chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
68
+ )
69
+ chunks = text_splitter.split_documents(docs)
70
+ logger.info(f"Total {len(chunks)} chunks created")
71
+
72
+ # create vector store
73
+ encode_kwargs = {"normalize_embeddings": True}
74
+ embedding_model = "BAAI/bge-large-en"
75
+ embeddings = HuggingFaceInstructEmbeddings(
76
+ model_name=embedding_model,
77
+ embed_instruction="", # no instruction is needed for candidate passages
78
+ query_instruction="Represent this sentence for searching relevant passages: ",
79
+ encode_kwargs=encode_kwargs,
80
+ )
81
+ logger.info(
82
+ f"Processing embeddings using {embedding_model}. This could take time depending on the number of chunks ..."
83
+ )
84
+
85
+ if db_type == "faiss":
86
+ vectorstore = FAISS.from_documents(documents=chunks, embedding=embeddings)
87
+ # save vectorstore
88
+ vectorstore.save_local(output_db)
89
+ elif db_type == "chromadb":
90
+ vectorstore = Chroma.from_documents(
91
+ documents=chunks, embedding=embeddings, persist_directory=output_db
92
+ )
93
+ elif db_type == "qdrant":
94
+ vectorstore = Qdrant.from_documents(
95
+ documents=chunks,
96
+ embedding=embeddings,
97
+ path=output_db,
98
+ collection_name="test_collection",
99
+ )
100
+ elif db_type == "qdrant-server":
101
+ url = "http://localhost:6333/"
102
+ vectorstore = Qdrant.from_documents(
103
+ documents=chunks,
104
+ embedding=embeddings,
105
+ url=url,
106
+ prefer_grpc=True,
107
+ collection_name="anaconda",
108
+ )
109
+
110
+ logger.info(f"Vector store saved to {output_db}")
111
+
112
+
113
+ if __name__ == "__main__":
114
+ parser = argparse.ArgumentParser(description="Process data with optional chunking")
115
+
116
+ # Required arguments
117
+ parser.add_argument("input_path", type=str, help="Path to the input directory")
118
+ parser.add_argument("output_db", type=str, help="Path to the output vectordb")
119
+
120
+ # Optional arguments
121
+ parser.add_argument(
122
+ "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
123
+ )
124
+ parser.add_argument(
125
+ "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
126
+ )
127
+ parser.add_argument(
128
+ "--db_type",
129
+ type=str,
130
+ default="faiss",
131
+ help="Type of vectorstore (default: faiss)",
132
+ )
133
+
134
+ args = parser.parse_args()
135
+ main(
136
+ args.input_path,
137
+ args.output_db,
138
+ args.chunk_size,
139
+ args.chunk_overlap,
140
+ args.db_type,
141
+ )
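Putting the arguments above together, an invocation consistent with the script's argparse setup would look like this (paths are illustrative):

    # Build a Chroma store from the .txt files in ./docs using 1000-character chunks.
    python utils/vectordb/create_vector_db.py ./docs ./data/my-vector-db \
        --chunk_size 1000 --chunk_overlap 200 --db_type chromadb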
utils/vectordb/vector_db.py ADDED
@@ -0,0 +1,353 @@
1
+ # Define the script's usage example
2
+ USAGE_EXAMPLE = """
3
+ Example usage:
4
+
5
+ To process input *.txt files at input_path and save the vector db output at output_db:
6
+ python vector_db.py --input_path input_path --output_db output_db --chunk_size 100 --chunk_overlap 10
7
+
8
+ Required arguments:
9
+ - input_path: Path to the input dir containing the .txt files
10
+ - output_db: Path to the output vector db.
11
+
12
+ Optional arguments:
13
+ - --chunk_size: Size of the chunks (default: 1000).
14
+ - --chunk_overlap: Overlap between chunks (default: 200).
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import argparse
20
+ import logging
21
+
22
+ from langchain_community.document_loaders import DirectoryLoader, UnstructuredURLLoader
23
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
24
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
25
+ from langchain_community.vectorstores import FAISS, Chroma, Qdrant
26
+
27
+ vectordb_dir = os.path.dirname(os.path.abspath(__file__))
28
+ utils_dir = os.path.abspath(os.path.join(vectordb_dir, ".."))
29
+ repo_dir = os.path.abspath(os.path.join(utils_dir, ".."))
30
+
31
+ sys.path.append(repo_dir)
32
+ sys.path.append(utils_dir)
33
+
34
+ from utils.model_wrappers.api_gateway import APIGateway
35
+ import uuid
36
+ import streamlit as st
37
+
38
+ EMBEDDING_MODEL = "intfloat/e5-large-v2"
39
+ NORMALIZE_EMBEDDINGS = True
40
+ VECTORDB_LOG_FILE_NAME = "vector_db.log"
41
+
42
+ # Configure the logger
43
+ logging.basicConfig(
44
+ level=logging.INFO, # Set the logging level (e.g., INFO, DEBUG)
45
+ format="%(asctime)s [%(levelname)s] - %(message)s", # Define the log message format
46
+ handlers=[
47
+ logging.StreamHandler(), # Output logs to the console
48
+ logging.FileHandler(VECTORDB_LOG_FILE_NAME),
49
+ ],
50
+ )
51
+
52
+ # Create a logger object
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ class VectorDb():
57
+ """
58
+ A class for creating, updating, and loading FAISS or Chroma vector databases
59
+ for use in retrieval-augmented generation (RAG) tasks with LangChain.
60
+
61
+ Args:
62
+ None
63
+
64
+ Attributes:
65
+ None
66
+
67
+ Methods:
68
+ load_files: Load files from an input directory as langchain documents
69
+ get_text_chunks: Get text chunks from a list of documents
70
+ get_token_chunks: Get token chunks from a list of documents
71
+ create_vector_store: Create a vector store from chunks and an embedding model
72
+ load_vdb: load a previous stored vector database
73
+ update_vdb: Update an existing vector store with new chunks
74
+ create_vdb: Create a vector database from the raw files in a specific input directory
75
+ """
76
+ def __init__(self) -> None:
77
+ self.collection_id = str(uuid.uuid4())
78
+ self.vector_collections = set()
79
+
80
+ def load_files(self, input_path, recursive=False, load_txt=True, load_pdf=False, urls = None) -> list:
81
+ """Load files from input location
82
+
83
+ Args:
84
+ input_path : input location of files
85
+ recursive (bool, optional): flag to load files recursively. Defaults to False.
86
+ load_txt (bool, optional): flag to load txt files. Defaults to True.
87
+ load_pdf (bool, optional): flag to load pdf files. Defaults to False.
88
+ urls (list, optional): list of urls to load. Defaults to None.
89
+
90
+ Returns:
91
+ list: list of documents
92
+ """
93
+ docs=[]
94
+ text_loader_kwargs={'autodetect_encoding': True}
95
+ if input_path is not None:
96
+ if load_txt:
97
+ loader = DirectoryLoader(input_path, glob="*.txt", recursive=recursive, show_progress=True, loader_kwargs=text_loader_kwargs)
98
+ docs.extend(loader.load())
99
+ if load_pdf:
100
+ loader = DirectoryLoader(input_path, glob="*.pdf", recursive=recursive, show_progress=True, loader_kwargs=text_loader_kwargs)
101
+ docs.extend(loader.load())
102
+ if urls:
103
+ loader = UnstructuredURLLoader(urls=urls)
104
+ docs.extend(loader.load())
105
+
106
+ logger.info(f"Total {len(docs)} files loaded")
107
+
108
+ return docs
109
+
110
+ def get_text_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, meta_data: list = None) -> list:
111
+ """Gets text chunks. If metadata is not None, it will create chunks with metadata elements.
112
+
113
+ Args:
114
+ docs (list): list of documents or texts. If no metadata is passed, this parameter is a list of documents.
115
+ If metadata is passed, this parameter is a list of texts.
116
+ chunk_size (int): chunk size in number of characters
117
+ chunk_overlap (int): chunk overlap in number of characters
118
+ meta_data (list, optional): list of metadata dictionaries. Defaults to None.
119
+
120
+ Returns:
121
+ list: list of documents
122
+ """
123
+
124
+ text_splitter = RecursiveCharacterTextSplitter(
125
+ chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
126
+ )
127
+
128
+ if meta_data is None:
129
+ logger.info(f"Splitter: splitting documents")
130
+ chunks = text_splitter.split_documents(docs)
131
+ else:
132
+ logger.info(f"Splitter: creating documents with metadata")
133
+ chunks = text_splitter.create_documents(docs, meta_data)
134
+
135
+ logger.info(f"Total {len(chunks)} chunks created")
136
+
137
+ return chunks
138
+
139
+ def get_token_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, tokenizer) -> list:
140
+ """Gets token chunks. If metadata is not None, it will create chunks with metadata elements.
141
+
142
+ Args:
143
+ docs (list): list of documents or texts. If no metadata is passed, this parameter is a list of documents.
144
+ If metadata is passed, this parameter is a list of texts.
145
+ chunk_size (int): chunk size in number of tokens
146
+ chunk_overlap (int): chunk overlap in number of tokens
147
+
148
+ Returns:
149
+ list: list of documents
150
+ """
151
+
152
+ text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
153
+ tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
154
+ )
155
+
156
+ logger.info(f"Splitter: splitting documents")
157
+ chunks = text_splitter.split_documents(docs)
158
+
159
+ logger.info(f"Total {len(chunks)} chunks created")
160
+
161
+ return chunks
162
+
163
+
164
+
165
+ def create_vector_store(self, chunks: list, embeddings: HuggingFaceInstructEmbeddings, db_type: str,
166
+ output_db: str = None, collection_name: str = None):
167
+ """Creates a vector store
168
+
169
+ Args:
170
+ chunks (list): list of chunks
171
+ embeddings (HuggingFaceInstructEmbeddings): embedding model
172
+ db_type (str): vector db type
173
+ output_db (str, optional): output path to save the vector db. Defaults to None.
+ collection_name (str, optional): name of the collection to create. Defaults to None.
174
+ """
175
+ if collection_name is None:
176
+ collection_name = f"collection_{self.collection_id}"
177
+ logger.info(f'This is the collection name: {collection_name}')
178
+
179
+ if db_type == "faiss":
180
+ vector_store = FAISS.from_documents(
181
+ documents=chunks,
182
+ embedding=embeddings
183
+ )
184
+ if output_db:
185
+ vector_store.save_local(output_db)
186
+
187
+ elif db_type == "chroma":
188
+ if output_db:
189
+ vector_store = Chroma()
190
+ vector_store.delete_collection()
191
+ vector_store = Chroma.from_documents(
192
+ documents=chunks,
193
+ embedding=embeddings,
194
+ persist_directory=output_db,
195
+ collection_name=collection_name
196
+ )
197
+ else:
198
+ vector_store = Chroma()
199
+ vector_store.delete_collection()
200
+ vector_store = Chroma.from_documents(
201
+ documents=chunks,
202
+ embedding=embeddings,
203
+ collection_name=collection_name
204
+ )
205
+ self.vector_collections.add(collection_name)
206
+
207
+ elif db_type == "qdrant":
208
+ if output_db:
209
+ vector_store = Qdrant.from_documents(
210
+ documents=chunks,
211
+ embedding=embeddings,
212
+ path=output_db,
213
+ collection_name="test_collection",
214
+ )
215
+ else:
216
+ vector_store = Qdrant.from_documents(
217
+ documents=chunks,
218
+ embedding=embeddings,
219
+ collection_name="test_collection",
220
+ )
221
+
222
+ logger.info(f"Vector store saved to {output_db}")
223
+
224
+ return vector_store
225
+
226
+ def load_vdb(self, persist_directory, embedding_model, db_type="chroma", collection_name=None):
227
+ if db_type == "faiss":
228
+ vector_store = FAISS.load_local(persist_directory, embedding_model, allow_dangerous_deserialization=True)
229
+ elif db_type == "chroma":
230
+ if collection_name:
231
+ vector_store = Chroma(
232
+ persist_directory=persist_directory,
233
+ embedding_function=embedding_model,
234
+ collection_name=collection_name
235
+ )
236
+ else:
237
+ vector_store = Chroma(
238
+ persist_directory=persist_directory,
239
+ embedding_function=embedding_model
240
+ )
241
+ elif db_type == "qdrant":
242
+ # TODO: Implement Qdrant loading
243
+ pass
244
+ else:
245
+ raise ValueError(f"Unsupported database type: {db_type}")
246
+
247
+ return vector_store
248
+
249
+ def update_vdb(self, chunks: list, embeddings, db_type: str, input_db: str = None,
250
+ output_db: str = None):
251
+
252
+ if db_type == "faiss":
253
+ vector_store = FAISS.load_local(input_db, embeddings, allow_dangerous_deserialization=True)
254
+ new_vector_store = self.create_vector_store(chunks, embeddings, db_type, None)
255
+ vector_store.merge_from(new_vector_store)
256
+ if output_db:
257
+ vector_store.save_local(output_db)
258
+
259
+ elif db_type == "chroma":
260
+ # TODO implement update method for chroma
261
+ pass
262
+ elif db_type == "qdrant":
263
+ # TODO implement update method for qdrant
264
+ pass
265
+
266
+ return vector_store
267
+
268
+ def create_vdb(
269
+ self,
270
+ input_path,
271
+ chunk_size,
272
+ chunk_overlap,
273
+ db_type,
274
+ output_db=None,
275
+ recursive=False,
276
+ tokenizer=None,
277
+ load_txt=True,
278
+ load_pdf=False,
279
+ urls=None,
280
+ embedding_type="cpu",
281
+ batch_size= None,
282
+ coe = None,
283
+ select_expert = None
284
+ ):
285
+
286
+ docs = self.load_files(input_path, recursive=recursive, load_txt=load_txt, load_pdf=load_pdf, urls=urls)
287
+
288
+ if tokenizer is None:
289
+ chunks = self.get_text_chunks(docs, chunk_size, chunk_overlap)
290
+ else:
291
+ chunks = self.get_token_chunks(docs, chunk_size, chunk_overlap, tokenizer)
292
+
293
+ embeddings = APIGateway.load_embedding_model(
294
+ type=embedding_type,
295
+ batch_size=batch_size,
296
+ coe=coe,
297
+ select_expert=select_expert
298
+ )
299
+
300
+ vector_store = self.create_vector_store(chunks, embeddings, db_type, output_db)
301
+
302
+ return vector_store
303
+
304
+
305
+ def dir_path(path):
306
+ if os.path.isdir(path):
307
+ return path
308
+ else:
309
+ raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
310
+
311
+
312
+ # Parse the arguments
313
+ def parse_arguments():
314
+ parser = argparse.ArgumentParser(description="Process command line arguments.")
315
+ parser.add_argument("-input_path", type=dir_path, help="path to input directory")
316
+ parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
317
+ parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
318
+ parser.add_argument("-output_path", type=dir_path, help="path to input directory")
319
+
320
+ return parser.parse_args()
321
+
322
+
323
+ if __name__ == "__main__":
324
+ parser = argparse.ArgumentParser(description="Process data with optional chunking")
325
+
326
+ # Required arguments
327
+ parser.add_argument("--input_path", type=str, help="Path to the input directory")
328
+ parser.add_argument("--output_db", type=str, help="Path to the output vectordb")
329
+
330
+ # Optional arguments
331
+ parser.add_argument(
332
+ "--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
333
+ )
334
+ parser.add_argument(
335
+ "--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
336
+ )
337
+ parser.add_argument(
338
+ "--db_type",
339
+ type=str,
340
+ default="faiss",
341
+ help="Type of vector store (default: faiss)",
342
+ )
343
+ args = parser.parse_args()
344
+
345
+ vectordb = VectorDb()
346
+
347
+ vectordb.create_vdb(
348
+ args.input_path,
349
+ args.chunk_size,
350
+ args.chunk_overlap,
351
+ args.db_type,
352
+ args.output_db,
353
+ )
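As a rough programmatic sketch of the VectorDb class above (it assumes the embedding backend behind APIGateway.load_embedding_model is configured; paths and keyword values are illustrative):

    from utils.vectordb.vector_db import VectorDb

    vectordb = VectorDb()
    # Build a persisted Chroma store from local .txt and .pdf files.
    vector_store = vectordb.create_vdb(
        input_path="./data/docs",
        chunk_size=1000,
        chunk_overlap=200,
        db_type="chroma",
        output_db="./data/my-vector-db",
        load_pdf=True,
    )
    # Expose the store as a retriever for a RAG chain.
    retriever = vector_store.as_retriever(search_kwargs={"k": 4})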
utils/visual/env_utils.py ADDED
@@ -0,0 +1,95 @@
1
+ import netrc
2
+ import os
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import streamlit as st
6
+
7
+
8
+ def initialize_env_variables(prod_mode: bool = False, additional_env_vars: Optional[List[str]] = None) -> None:
9
+ if additional_env_vars is None:
10
+ additional_env_vars = []
11
+
12
+ if not prod_mode:
13
+ # In non-prod mode, prioritize environment variables
14
+ st.session_state.SAMBANOVA_API_KEY = os.environ.get(
15
+ 'SAMBANOVA_API_KEY', st.session_state.get('SAMBANOVA_API_KEY', '')
16
+ )
17
+ for var in additional_env_vars:
18
+ st.session_state[var] = os.environ.get(var, st.session_state.get(var, ''))
19
+ else:
20
+ # In prod mode, only use session state
21
+ if 'SAMBANOVA_API_KEY' not in st.session_state:
22
+ st.session_state.SAMBANOVA_API_KEY = ''
23
+ for var in additional_env_vars:
24
+ if var not in st.session_state:
25
+ st.session_state[var] = ''
26
+
27
+
28
+ def set_env_variables(api_key, additional_vars=None, prod_mode=False):
29
+ st.session_state.SAMBANOVA_API_KEY = api_key
30
+ if additional_vars:
31
+ for key, value in additional_vars.items():
32
+ st.session_state[key] = value
33
+ if not prod_mode:
34
+ # In non-prod mode, also set environment variables
35
+ os.environ['SAMBANOVA_API_KEY'] = api_key
36
+ if additional_vars:
37
+ for key, value in additional_vars.items():
38
+ os.environ[key] = value
39
+
40
+
41
+ def env_input_fields(additional_env_vars=None) -> Tuple[str, Dict[str, str]]:
42
+ if additional_env_vars is None:
43
+ additional_env_vars = []
44
+
45
+ api_key = st.text_input('Sambanova API Key', value=st.session_state.SAMBANOVA_API_KEY, type='password')
46
+
47
+ additional_vars = {}
48
+ for var in additional_env_vars:
49
+ additional_vars[var] = st.text_input(f'{var}', value=st.session_state.get(var, ''), type='password')
50
+
51
+ return api_key, additional_vars
52
+
53
+
54
+ def are_credentials_set(additional_env_vars=None) -> bool:
55
+ if additional_env_vars is None:
56
+ additional_env_vars = []
57
+
58
+ base_creds_set = bool(st.session_state.SAMBANOVA_API_KEY)
59
+ additional_creds_set = all(bool(st.session_state.get(var, '')) for var in additional_env_vars)
60
+
61
+ return base_creds_set and additional_creds_set
62
+
63
+
64
+ def save_credentials(api_key, additional_vars=None, prod_mode=False) -> str:
65
+ set_env_variables(api_key, additional_vars, prod_mode)
66
+ return 'Credentials saved successfully!'
67
+
68
+
69
+ def get_wandb_key():
70
+ # Check for WANDB_API_KEY in environment variables
71
+ env_wandb_api_key = os.getenv('WANDB_API_KEY')
72
+
73
+ # Check for WANDB_API_KEY in ~/.netrc
74
+ try:
75
+ netrc_path = os.path.expanduser('~/.netrc')
76
+ netrc_data = netrc.netrc(netrc_path)
77
+ netrc_wandb_api_key = netrc_data.authenticators('api.wandb.ai')
78
+ except (FileNotFoundError, netrc.NetrcParseError):
79
+ netrc_wandb_api_key = None
80
+
81
+ # If both are set, handle the conflict
82
+ if env_wandb_api_key and netrc_wandb_api_key:
83
+ print('WANDB_API_KEY is set in both the environment and ~/.netrc. Prioritizing the ~/.netrc entry.')
84
+ # Optionally, you can choose to remove one of them, here we remove the env variable
85
+ del os.environ['WANDB_API_KEY'] # Remove from environment to prioritize ~/.netrc
86
+ return netrc_wandb_api_key[2] if netrc_wandb_api_key else None # Return the key from .netrc
87
+
88
+ # Return the key from environment if available, otherwise from .netrc
89
+ if env_wandb_api_key:
90
+ return env_wandb_api_key
91
+ elif netrc_wandb_api_key:
92
+ return netrc_wandb_api_key[2] if netrc_wandb_api_key else None
93
+
94
+ # If neither is set, return None
95
+ return None
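A short sketch of how these helpers are typically wired into a Streamlit page (a hypothetical snippet; EXTRA_VARS and the widget layout are illustrative):

    import streamlit as st
    from utils.visual.env_utils import (
        initialize_env_variables,
        env_input_fields,
        are_credentials_set,
        save_credentials,
    )

    EXTRA_VARS = []  # names of any additional provider keys
    initialize_env_variables(prod_mode=False, additional_env_vars=EXTRA_VARS)

    if not are_credentials_set(EXTRA_VARS):
        api_key, additional_vars = env_input_fields(EXTRA_VARS)
        if st.button("Save credentials"):
            st.success(save_credentials(api_key, additional_vars, prod_mode=False))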