diff --git "a/src/gpt_langchain.py" "b/src/gpt_langchain.py"
new file mode 100644
--- /dev/null
+++ "b/src/gpt_langchain.py"
@@ -0,0 +1,5443 @@
+import ast
+import asyncio
+import copy
+import functools
+import glob
+import gzip
+import inspect
+import json
+import os
+import pathlib
+import pickle
+import shutil
+import subprocess
+import tempfile
+import time
+import traceback
+import types
+import typing
+import urllib.error
+import uuid
+import zipfile
+from collections import defaultdict
+from datetime import datetime
+from functools import reduce
+from operator import concat
+
+import filelock
+import tabulate
+import yaml
+
+from joblib import delayed
+from langchain.callbacks import streaming_stdout
+from langchain.embeddings import HuggingFaceInstructEmbeddings
+from langchain.llms.huggingface_pipeline import VALID_TASKS
+from langchain.llms.utils import enforce_stop_tokens
+from langchain.schema import LLMResult, Generation
+from langchain.tools import PythonREPLTool
+from langchain.tools.json.tool import JsonSpec
+from tqdm import tqdm
+
+from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
+    get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
+    have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
+    get_list_or_str, have_pillow, only_selenium, only_playwright, only_unstructured_urls, get_sha, get_short_name, \
+    get_accordion, have_jq, get_doc, get_source, have_chromamigdb, get_token_count, reverse_ucurve_list
+from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
+    LangChainAction, LangChainMode, DocumentChoice, LangChainTypes, font_size, head_acc, super_source_prefix, \
+    super_source_postfix, langchain_modes_intrinsic, get_langchain_prompts, LangChainAgent
+from evaluate_params import gen_hyper, gen_hyper0
+from gen import get_model, SEED, get_limited_prompt, get_docs_tokens
+from prompter import non_hf_types, PromptType, Prompter
+from src.serpapi import H2OSerpAPIWrapper
+from utils_langchain import StreamingGradioCallbackHandler, _chunk_sources, _add_meta, add_parser, fix_json_meta
+
+import_matplotlib()
+
+import numpy as np
+import pandas as pd
+import requests
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+# , GCSDirectoryLoader, GCSFileLoader
+# , OutlookMessageLoader  # GPL3
+# ImageCaptionLoader,  # use our own wrapper
+# ReadTheDocsLoader,  # no special file, some path, so have to give as special option
+from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
+    UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
+    EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
+    UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \
+    UnstructuredExcelLoader, JSONLoader
+from langchain.text_splitter import Language
+from langchain.chains.question_answering import load_qa_chain
+from langchain.docstore.document import Document
+from langchain import PromptTemplate, HuggingFaceTextGenInference, HuggingFacePipeline
+from langchain.vectorstores import Chroma
+from chromamig import ChromaMig
+
+
+def split_list(input_list, split_size):
+    for i in range(0, len(input_list), split_size):
+        yield input_list[i:i + split_size]
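+
+
+# For illustration: split_list is used by the chroma code paths below to keep
+# add/create batches within the client's max_batch_size, e.g.
+#   list(split_list([0, 1, 2, 3, 4], 2)) -> [[0, 1], [2, 3], [4]]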
+
+
+def get_db(sources, use_openai_embedding=False, db_type='faiss',
+           persist_directory=None, load_db_if_exists=True,
+           langchain_mode='notset',
+           langchain_mode_paths=None,
+           langchain_mode_types=None,
+           collection_name=None,
+           hf_embedding_model=None,
+           migrate_embedding_model=False,
+           auto_migrate_db=False,
+           n_jobs=-1):
+    if not sources:
+        return None
+    # avoid mutable default arguments
+    if langchain_mode_paths is None:
+        langchain_mode_paths = {}
+    if langchain_mode_types is None:
+        langchain_mode_types = {}
+    user_path = langchain_mode_paths.get(langchain_mode)
+    if persist_directory is None:
+        langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value)
+        persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type)
+        langchain_mode_types[langchain_mode] = langchain_type
+    assert hf_embedding_model is not None
+
+    # get freshly-determined embedding model
+    embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
+    assert collection_name is not None or langchain_mode != 'notset'
+    if collection_name is None:
+        collection_name = langchain_mode.replace(' ', '_')
+
+    # Create vector database
+    if db_type == 'faiss':
+        from langchain.vectorstores import FAISS
+        db = FAISS.from_documents(sources, embedding)
+    elif db_type == 'weaviate':
+        import weaviate
+        from weaviate.embedded import EmbeddedOptions
+        from langchain.vectorstores import Weaviate
+
+        if os.getenv('WEAVIATE_URL', None):
+            client = _create_local_weaviate_client()
+        else:
+            client = weaviate.Client(
+                embedded_options=EmbeddedOptions(persistence_data_path=persist_directory)
+            )
+        index_name = collection_name.capitalize()
+        db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
+                                     index_name=index_name)
+    elif db_type in ['chroma', 'chroma_old']:
+        assert persist_directory is not None
+        # use_base already handled when making persist_directory, unless was passed into get_db()
+        makedirs(persist_directory, exist_ok=True)
+
+        # see if already actually have persistent db, and deal with possible changes in embedding
+        db, use_openai_embedding, hf_embedding_model = \
+            get_existing_db(None, persist_directory, load_db_if_exists, db_type,
+                            use_openai_embedding,
+                            langchain_mode, langchain_mode_paths, langchain_mode_types,
+                            hf_embedding_model, migrate_embedding_model, auto_migrate_db,
+                            verbose=False,
+                            n_jobs=n_jobs)
+        if db is None:
+            import logging
+            logging.getLogger("chromadb").setLevel(logging.ERROR)
+            if db_type == 'chroma':
+                from chromadb.config import Settings
+                settings_extra_kwargs = dict(is_persistent=True)
+            else:
+                from chromamigdb.config import Settings
+                settings_extra_kwargs = dict(chroma_db_impl="duckdb+parquet")
+            client_settings = Settings(anonymized_telemetry=False,
+                                       persist_directory=persist_directory,
+                                       **settings_extra_kwargs)
+            if n_jobs in [None, -1]:
+                n_jobs = int(os.getenv('OMP_NUM_THREADS', str(os.cpu_count() // 2)))
+                num_threads = max(1, min(n_jobs, 8))
+            else:
+                num_threads = max(1, n_jobs)
+            collection_metadata = {"hnsw:num_threads": num_threads}
+            from_kwargs = dict(embedding=embedding,
+                               persist_directory=persist_directory,
+                               collection_name=collection_name,
+                               client_settings=client_settings,
+                               collection_metadata=collection_metadata)
+            if db_type == 'chroma':
+                import chromadb
+                api = chromadb.PersistentClient(path=persist_directory)
+                max_batch_size = api._producer.max_batch_size
+                sources_batches = split_list(sources, max_batch_size)
+                for sources_batch in sources_batches:
+                    db = Chroma.from_documents(documents=sources_batch, **from_kwargs)
+                    db.persist()
+            else:
+                db = ChromaMig.from_documents(documents=sources, **from_kwargs)
+            clear_embedding(db)
+            save_embed(db, use_openai_embedding, hf_embedding_model)
+        else:
+            # then just add
+            # doesn't check or change embedding, just saves it in case not saved yet, after persisting
+            db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
+                                                                  use_openai_embedding=use_openai_embedding,
+                                                                  hf_embedding_model=hf_embedding_model)
+    else:
+        raise RuntimeError("No such db_type=%s" % db_type)
+
+    # once here, db is not changing, and embedding choices in calling functions do not matter
+    return db
+
+
+def _get_unique_sources_in_weaviate(db):
+    batch_size = 100
+    id_source_list = []
+    result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)
+
+    while result['objects']:
+        id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
+        last_id = id_source_list[-1][0]
+        result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)
+
+    unique_sources = {source for _, source in id_source_list}
+    return unique_sources
+
+
+def del_from_db(db, sources, db_type=None):
+    if db_type in ['chroma', 'chroma_old'] and db is not None:
+        # sources should be list of x.metadata['source'] from document metadatas
+        if isinstance(sources, str):
+            sources = [sources]
+        else:
+            assert isinstance(sources, (list, tuple, types.GeneratorType))
+        metadatas = set(sources)
+        client_collection = db._client.get_collection(name=db._collection.name,
+                                                      embedding_function=db._collection._embedding_function)
+        for source in metadatas:
+            meta = dict(source=source)
+            try:
+                client_collection.delete(where=meta)
+            except KeyError:
+                pass
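+
+
+# Usage sketch for del_from_db (illustrative only; 'user_path/some.pdf' is a
+# made-up source path): drop every chunk that was ingested from one file.
+def _example_del_from_db(db):
+    del_from_db(db, ['user_path/some.pdf'], db_type='chroma')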
+
+
+def add_to_db(db, sources, db_type='faiss',
+              avoid_dup_by_file=False,
+              avoid_dup_by_content=True,
+              use_openai_embedding=False,
+              hf_embedding_model=None):
+    assert hf_embedding_model is not None
+    num_new_sources = len(sources)
+    if not sources:
+        return db, num_new_sources, []
+    if db_type == 'faiss':
+        db.add_documents(sources)
+    elif db_type == 'weaviate':
+        # FIXME: only control by file name, not hash yet
+        if avoid_dup_by_file or avoid_dup_by_content:
+            unique_sources = _get_unique_sources_in_weaviate(db)
+            sources = [x for x in sources if x.metadata['source'] not in unique_sources]
+        num_new_sources = len(sources)
+        if num_new_sources == 0:
+            return db, num_new_sources, []
+        db.add_documents(documents=sources)
+    elif db_type in ['chroma', 'chroma_old']:
+        collection = get_documents(db)
+        # files we already have:
+        metadata_files = set([x['source'] for x in collection['metadatas']])
+        if avoid_dup_by_file:
+            # Too weak in case file changed content; assume parent shouldn't pass True for this for now
+            raise RuntimeError("Not desired code path")
+        if avoid_dup_by_content:
+            # look at hash, instead of page_content
+            # migration: If no hash previously, avoid updating,
+            # since don't know if need to update and may be expensive to redo all unhashed files
+            metadata_hash_ids = set(
+                [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
+            # avoid sources with same hash
+            sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
+            num_nohash = len([x for x in sources if not x.metadata.get('hashid')])
+            print("Found %s new sources (%d have no hash in original source,"
+                  " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
+            # get new file names that match existing file names; delete existing files we are overriding
+            dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
+            print("Removing %s duplicate files from db because ingesting those as new documents" % len(
+                dup_metadata_files), flush=True)
+            client_collection = db._client.get_collection(name=db._collection.name,
+                                                          embedding_function=db._collection._embedding_function)
+            for dup_file in dup_metadata_files:
+                dup_file_meta = dict(source=dup_file)
+                try:
+                    client_collection.delete(where=dup_file_meta)
+                except KeyError:
+                    pass
+        num_new_sources = len(sources)
+        if num_new_sources == 0:
+            return db, num_new_sources, []
+        if hasattr(db, '_persist_directory'):
+            print("Existing db, adding to %s" % db._persist_directory, flush=True)
+            # chroma only
+            lock_file = get_db_lock_file(db)
+            context = filelock.FileLock
+        else:
+            lock_file = None
+            context = NullContext
+        with context(lock_file):
+            # this is the place where we add to the db; others may be accessing the db, so lock access,
+            # else can hit RuntimeError: Index seems to be corrupted or unsupported
+            import chromadb
+            api = chromadb.PersistentClient(path=db._persist_directory)
+            max_batch_size = api._producer.max_batch_size
+            sources_batches = split_list(sources, max_batch_size)
+            for sources_batch in sources_batches:
+                db.add_documents(documents=sources_batch)
+                db.persist()
+            clear_embedding(db)
+            # save here is for migration, in case old db directory without embedding saved
+            save_embed(db, use_openai_embedding, hf_embedding_model)
+    else:
+        raise RuntimeError("No such db_type=%s" % db_type)
+
+    new_sources_metadata = [x.metadata for x in sources]
+
+    return db, num_new_sources, new_sources_metadata
+
+
+def create_or_update_db(db_type, persist_directory, collection_name,
+                        user_path, langchain_type,
+                        sources, use_openai_embedding, add_if_exists, verbose,
+                        hf_embedding_model, migrate_embedding_model, auto_migrate_db,
+                        n_jobs=-1):
+    if not os.path.isdir(persist_directory) or not add_if_exists:
+        if os.path.isdir(persist_directory):
+            if verbose:
+                print("Removing %s" % persist_directory, flush=True)
+            remove(persist_directory)
+        if verbose:
+            print("Generating db", flush=True)
+    if db_type == 'weaviate':
+        import weaviate
+        from weaviate.embedded import EmbeddedOptions
+
+        if os.getenv('WEAVIATE_URL', None):
+            client = _create_local_weaviate_client()
+        else:
+            client = weaviate.Client(
+                embedded_options=EmbeddedOptions(persistence_data_path=persist_directory)
+            )
+
+        index_name = collection_name.replace(' ', '_').capitalize()
+        if client.schema.exists(index_name) and not add_if_exists:
+            client.schema.delete_class(index_name)
+            if verbose:
+                print("Removing %s" % index_name, flush=True)
+    elif db_type in ['chroma', 'chroma_old']:
+        pass
+
+    if not add_if_exists:
+        if verbose:
+            print("Generating db", flush=True)
+    else:
+        if verbose:
+            print("Loading and updating db", flush=True)
+
+    db = get_db(sources,
+                use_openai_embedding=use_openai_embedding,
+                db_type=db_type,
+                persist_directory=persist_directory,
+                langchain_mode=collection_name,
+                langchain_mode_paths={collection_name: user_path},
+                langchain_mode_types={collection_name: langchain_type},
+                hf_embedding_model=hf_embedding_model,
+                migrate_embedding_model=migrate_embedding_model,
+                auto_migrate_db=auto_migrate_db,
+                n_jobs=n_jobs)
+
+    return db
+
+
+from langchain.embeddings import FakeEmbeddings
+
+
+class H2OFakeEmbeddings(FakeEmbeddings):
+    """Fake embedding model, but constant instead of random"""
+
+    size: int
+    """The size of the embedding 
vector.""" + + def _get_embedding(self) -> typing.List[float]: + return [1] * self.size + + def embed_documents(self, texts: typing.List[str]) -> typing.List[typing.List[float]]: + return [self._get_embedding() for _ in texts] + + def embed_query(self, text: str) -> typing.List[float]: + return self._get_embedding() + + +def get_embedding(use_openai_embedding, hf_embedding_model=None, preload=False): + assert hf_embedding_model is not None + # Get embedding model + if use_openai_embedding: + assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY" + from langchain.embeddings import OpenAIEmbeddings + embedding = OpenAIEmbeddings(disallowed_special=()) + elif hf_embedding_model == 'fake': + embedding = H2OFakeEmbeddings(size=1) + else: + if isinstance(hf_embedding_model, str): + pass + elif isinstance(hf_embedding_model, dict): + # embedding itself preloaded globally + return hf_embedding_model['model'] + else: + # object + return hf_embedding_model + # to ensure can fork without deadlock + from langchain.embeddings import HuggingFaceEmbeddings + + device, torch_dtype, context_class = get_device_dtype() + model_kwargs = dict(device=device) + if 'instructor' in hf_embedding_model: + encode_kwargs = {'normalize_embeddings': True} + embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs) + else: + embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs) + embedding.client.preload = preload + return embedding + + +def get_answer_from_sources(chain, sources, question): + return chain( + { + "input_documents": sources, + "question": question, + }, + return_only_outputs=True, + )["output_text"] + + +"""Wrapper around Huggingface text generation inference API.""" +from functools import partial +from typing import Any, Dict, List, Optional, Set, Iterable + +from pydantic import Extra, Field, root_validator + +from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun +from langchain.llms.base import LLM + + +class GradioInference(LLM): + """ + Gradio generation inference API. + """ + inference_server_url: str = "" + + temperature: float = 0.8 + top_p: Optional[float] = 0.95 + top_k: Optional[int] = None + num_beams: Optional[int] = 1 + max_new_tokens: int = 512 + min_new_tokens: int = 1 + early_stopping: bool = False + max_time: int = 180 + repetition_penalty: Optional[float] = None + num_return_sequences: Optional[int] = 1 + do_sample: bool = False + chat_client: bool = False + + return_full_text: bool = False + stream_output: bool = False + sanitize_bot_response: bool = False + + prompter: Any = None + context: Any = '' + iinput: Any = '' + client: Any = None + tokenizer: Any = None + + system_prompt: Any = None + visible_models: Any = None + h2ogpt_key: Any = None + + count_input_tokens: Any = 0 + count_output_tokens: Any = 0 + + min_max_new_tokens: Any = 256 + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that python package exists in environment.""" + + try: + if values['client'] is None: + import gradio_client + values["client"] = gradio_client.Client( + values["inference_server_url"] + ) + except ImportError: + raise ImportError( + "Could not import gradio_client python package. " + "Please install it with `pip install gradio_client`." 
+ ) + return values + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "gradio_inference" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection, + # so server should get prompt_type or '', not plain + # This is good, so gradio server can also handle stopping.py conditions + # this is different than TGI server that uses prompter to inject prompt_type prompting + stream_output = self.stream_output + gr_client = self.client + client_langchain_mode = 'Disabled' + client_add_chat_history_to_context = True + client_add_search_to_context = False + client_chat_conversation = [] + client_langchain_action = LangChainAction.QUERY.value + client_langchain_agents = [] + top_k_docs = 1 + chunk = True + chunk_size = 512 + client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True + iinput=self.iinput if self.chat_client else '', # only for chat=True + context=self.context, + # streaming output is supported, loops over and outputs each generation in streaming mode + # but leave stream_output=False for simple input/output mode + stream_output=stream_output, + prompt_type=self.prompter.prompt_type, + prompt_dict='', + + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + num_beams=self.num_beams, + max_new_tokens=self.max_new_tokens, + min_new_tokens=self.min_new_tokens, + early_stopping=self.early_stopping, + max_time=self.max_time, + repetition_penalty=self.repetition_penalty, + num_return_sequences=self.num_return_sequences, + do_sample=self.do_sample, + chat=self.chat_client, + + instruction_nochat=prompt if not self.chat_client else '', + iinput_nochat=self.iinput if not self.chat_client else '', + langchain_mode=client_langchain_mode, + add_chat_history_to_context=client_add_chat_history_to_context, + langchain_action=client_langchain_action, + langchain_agents=client_langchain_agents, + top_k_docs=top_k_docs, + chunk=chunk, + chunk_size=chunk_size, + document_subset=DocumentSubset.Relevant.name, + document_choice=[DocumentChoice.ALL.value], + pre_prompt_query=None, + prompt_query=None, + pre_prompt_summary=None, + prompt_summary=None, + system_prompt=self.system_prompt, + image_loaders=None, # don't need to further do doc specific things + pdf_loaders=None, # don't need to further do doc specific things + url_loaders=None, # don't need to further do doc specific things + jq_schema=None, # don't need to further do doc specific things + visible_models=self.visible_models, + h2ogpt_key=self.h2ogpt_key, + add_search_to_context=client_add_search_to_context, + chat_conversation=client_chat_conversation, + text_context_list=None, + docs_ordering_type=None, + min_max_new_tokens=self.min_max_new_tokens, + ) + api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing + self.count_input_tokens += self.get_num_tokens(prompt) + + if not stream_output: + res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name) + res_dict = ast.literal_eval(res) + text = res_dict['response'] + ret = self.prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=self.sanitize_bot_response) + self.count_output_tokens += self.get_num_tokens(ret) + return ret + else: + text_callback = None + if run_manager: + text_callback = partial( + run_manager.on_llm_new_token, verbose=self.verbose + ) + + job = 
gr_client.submit(str(dict(client_kwargs)), api_name=api_name) + text0 = '' + while not job.done(): + if job.communicator.job.latest_status.code.name == 'FINISHED': + break + e = job.future._exception + if e is not None: + break + outputs_list = job.communicator.job.outputs + if outputs_list: + res = job.communicator.job.outputs[-1] + res_dict = ast.literal_eval(res) + text = res_dict['response'] + text = self.prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=self.sanitize_bot_response) + # FIXME: derive chunk from full for now + text_chunk = text[len(text0):] + if not text_chunk: + continue + # save old + text0 = text + + if text_callback: + text_callback(text_chunk) + + time.sleep(0.01) + + # ensure get last output to avoid race + res_all = job.outputs() + if len(res_all) > 0: + res = res_all[-1] + res_dict = ast.literal_eval(res) + text = res_dict['response'] + # FIXME: derive chunk from full for now + else: + # go with old if failure + text = text0 + text_chunk = text[len(text0):] + if text_callback: + text_callback(text_chunk) + ret = self.prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=self.sanitize_bot_response) + self.count_output_tokens += self.get_num_tokens(ret) + return ret + + def get_token_ids(self, text: str) -> List[int]: + return self.tokenizer.encode(text) + # avoid base method that is not aware of how to properly tokenize (uses GPT2) + # return _get_token_ids_default_method(text) + + +class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference): + max_new_tokens: int = 512 + do_sample: bool = False + top_k: Optional[int] = None + top_p: Optional[float] = 0.95 + typical_p: Optional[float] = 0.95 + temperature: float = 0.8 + repetition_penalty: Optional[float] = None + return_full_text: bool = False + stop_sequences: List[str] = Field(default_factory=list) + seed: Optional[int] = None + inference_server_url: str = "" + timeout: int = 300 + headers: dict = None + stream_output: bool = False + sanitize_bot_response: bool = False + prompter: Any = None + context: Any = '' + iinput: Any = '' + tokenizer: Any = None + async_sem: Any = None + count_input_tokens: Any = 0 + count_output_tokens: Any = 0 + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + if stop is None: + stop = self.stop_sequences.copy() + else: + stop += self.stop_sequences.copy() + stop_tmp = stop.copy() + stop = [] + [stop.append(x) for x in stop_tmp if x not in stop] + + # HF inference server needs control over input tokens + assert self.tokenizer is not None + from h2oai_pipeline import H2OTextGenerationPipeline + prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) + + # NOTE: TGI server does not add prompting, so must do here + data_point = dict(context=self.context, instruction=prompt, input=self.iinput) + prompt = self.prompter.generate_prompt(data_point) + self.count_input_tokens += self.get_num_tokens(prompt) + + gen_server_kwargs = dict(do_sample=self.do_sample, + stop_sequences=stop, + max_new_tokens=self.max_new_tokens, + top_k=self.top_k, + top_p=self.top_p, + typical_p=self.typical_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + return_full_text=self.return_full_text, + seed=self.seed, + ) + gen_server_kwargs.update(kwargs) + + # lower bound because client is re-used if multi-threading + self.client.timeout = max(300, self.timeout) + + if not 
self.stream_output: + res = self.client.generate( + prompt, + **gen_server_kwargs, + ) + if self.return_full_text: + gen_text = res.generated_text[len(prompt):] + else: + gen_text = res.generated_text + # remove stop sequences from the end of the generated text + for stop_seq in stop: + if stop_seq in gen_text: + gen_text = gen_text[:gen_text.index(stop_seq)] + text = prompt + gen_text + text = self.prompter.get_response(text, prompt=prompt, + sanitize_bot_response=self.sanitize_bot_response) + else: + text_callback = None + if run_manager: + text_callback = partial( + run_manager.on_llm_new_token, verbose=self.verbose + ) + # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter + if text_callback: + text_callback(prompt) + text = "" + # Note: Streaming ignores return_full_text=True + for response in self.client.generate_stream(prompt, **gen_server_kwargs): + text_chunk = response.token.text + text += text_chunk + text = self.prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=self.sanitize_bot_response) + # stream part + is_stop = False + for stop_seq in stop: + if stop_seq in text_chunk: + is_stop = True + break + if is_stop: + break + if not response.token.special: + if text_callback: + text_callback(text_chunk) + self.count_output_tokens += self.get_num_tokens(text) + return text + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + # print("acall", flush=True) + if stop is None: + stop = self.stop_sequences.copy() + else: + stop += self.stop_sequences.copy() + stop_tmp = stop.copy() + stop = [] + [stop.append(x) for x in stop_tmp if x not in stop] + + # HF inference server needs control over input tokens + assert self.tokenizer is not None + from h2oai_pipeline import H2OTextGenerationPipeline + prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) + + # NOTE: TGI server does not add prompting, so must do here + data_point = dict(context=self.context, instruction=prompt, input=self.iinput) + prompt = self.prompter.generate_prompt(data_point) + + gen_text = await super()._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) + + # remove stop sequences from the end of the generated text + for stop_seq in stop: + if stop_seq in gen_text: + gen_text = gen_text[:gen_text.index(stop_seq)] + text = prompt + gen_text + text = self.prompter.get_response(text, prompt=prompt, + sanitize_bot_response=self.sanitize_bot_response) + # print("acall done", flush=True) + return text + + async def _agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + generations = [] + new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") + self.count_input_tokens += sum([self.get_num_tokens(prompt) for prompt in prompts]) + tasks = [ + asyncio.ensure_future(self._agenerate_one(prompt, stop=stop, run_manager=run_manager, + new_arg_supported=new_arg_supported, **kwargs)) + for prompt in prompts + ] + texts = await asyncio.gather(*tasks) + self.count_output_tokens += sum([self.get_num_tokens(text) for text in texts]) + [generations.append([Generation(text=text)]) for text in texts] + return LLMResult(generations=generations) + + async def _agenerate_one( + self, + prompt: str, + stop: 
Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + new_arg_supported=None, + **kwargs: Any, + ) -> str: + async with self.async_sem: # semaphore limits num of simultaneous downloads + return await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) \ + if new_arg_supported else \ + await self._acall(prompt, stop=stop, **kwargs) + + def get_token_ids(self, text: str) -> List[int]: + return self.tokenizer.encode(text) + # avoid base method that is not aware of how to properly tokenize (uses GPT2) + # return _get_token_ids_default_method(text) + + +from langchain.chat_models import ChatOpenAI, AzureChatOpenAI +from langchain.llms import OpenAI, AzureOpenAI, Replicate +from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \ + update_token_usage + + +class H2OOpenAI(OpenAI): + """ + New class to handle vLLM's use of OpenAI, no vllm_chat supported, so only need here + Handles prompting that OpenAI doesn't need, stopping as well + """ + stop_sequences: Any = None + sanitize_bot_response: bool = False + prompter: Any = None + context: Any = '' + iinput: Any = '' + tokenizer: Any = None + + @classmethod + def _all_required_field_names(cls) -> Set: + _all_required_field_names = super(OpenAI, cls)._all_required_field_names() + _all_required_field_names.update( + {'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter', + 'tokenizer', 'logit_bias'}) + return _all_required_field_names + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + stop_tmp = self.stop_sequences if not stop else self.stop_sequences + stop + stop = [] + [stop.append(x) for x in stop_tmp if x not in stop] + + # HF inference server needs control over input tokens + assert self.tokenizer is not None + from h2oai_pipeline import H2OTextGenerationPipeline + for prompti, prompt in enumerate(prompts): + prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) + # NOTE: OpenAI/vLLM server does not add prompting, so must do here + data_point = dict(context=self.context, instruction=prompt, input=self.iinput) + prompt = self.prompter.generate_prompt(data_point) + prompts[prompti] = prompt + + params = self._invocation_params + params = {**params, **kwargs} + sub_prompts = self.get_sub_prompts(params, prompts, stop) + choices = [] + token_usage: Dict[str, int] = {} + # Get the token usage from the response. + # Includes prompt, completion, and total tokens used. 
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} + text = '' + for _prompts in sub_prompts: + if self.streaming: + text_with_prompt = "" + prompt = _prompts[0] + if len(_prompts) > 1: + raise ValueError("Cannot stream results with multiple prompts.") + params["stream"] = True + response = _streaming_response_template() + first = True + for stream_resp in completion_with_retry( + self, prompt=_prompts, **params + ): + if first: + stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"] + first = False + text_chunk = stream_resp["choices"][0]["text"] + text_with_prompt += text_chunk + text = self.prompter.get_response(text_with_prompt, prompt=prompt, + sanitize_bot_response=self.sanitize_bot_response) + if run_manager: + run_manager.on_llm_new_token( + text_chunk, + verbose=self.verbose, + logprobs=stream_resp["choices"][0]["logprobs"], + ) + _update_response(response, stream_resp) + choices.extend(response["choices"]) + else: + response = completion_with_retry(self, prompt=_prompts, **params) + choices.extend(response["choices"]) + if not self.streaming: + # Can't update token usage if streaming + update_token_usage(_keys, response, token_usage) + if self.streaming: + choices[0]['text'] = text + return self.create_llm_result(choices, prompts, token_usage) + + def get_token_ids(self, text: str) -> List[int]: + if self.tokenizer is not None: + return self.tokenizer.encode(text) + else: + # OpenAI uses tiktoken + return super().get_token_ids(text) + + +class H2OReplicate(Replicate): + stop_sequences: Any = None + sanitize_bot_response: bool = False + prompter: Any = None + context: Any = '' + iinput: Any = '' + tokenizer: Any = None + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call to replicate endpoint.""" + stop_tmp = self.stop_sequences if not stop else self.stop_sequences + stop + stop = [] + [stop.append(x) for x in stop_tmp if x not in stop] + + # HF inference server needs control over input tokens + assert self.tokenizer is not None + from h2oai_pipeline import H2OTextGenerationPipeline + prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) + # Note Replicate handles the prompting of the specific model + return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs) + + def get_token_ids(self, text: str) -> List[int]: + return self.tokenizer.encode(text) + # avoid base method that is not aware of how to properly tokenize (uses GPT2) + # return _get_token_ids_default_method(text) + + +class H2OChatOpenAI(ChatOpenAI): + @classmethod + def _all_required_field_names(cls) -> Set: + _all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names() + _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'}) + return _all_required_field_names + + +class H2OAzureChatOpenAI(AzureChatOpenAI): + @classmethod + def _all_required_field_names(cls) -> Set: + _all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names() + _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'}) + return _all_required_field_names + + +class H2OAzureOpenAI(AzureOpenAI): + @classmethod + def _all_required_field_names(cls) -> Set: + _all_required_field_names = super(AzureOpenAI, cls)._all_required_field_names() + _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 
'logit_bias'})
+        return _all_required_field_names
+
+
+class H2OHuggingFacePipeline(HuggingFacePipeline):
+    def _call(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> str:
+        response = self.pipeline(prompt, stop=stop)
+        if self.pipeline.task == "text-generation":
+            # Text generation return includes the starter text.
+            text = response[0]["generated_text"][len(prompt):]
+        elif self.pipeline.task == "text2text-generation":
+            text = response[0]["generated_text"]
+        elif self.pipeline.task == "summarization":
+            text = response[0]["summary_text"]
+        else:
+            raise ValueError(
+                f"Got invalid task {self.pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        if stop:
+            # This is a bit hacky, but I can't figure out a better way to enforce
+            # stop tokens when making calls to huggingface_hub.
+            text = enforce_stop_tokens(text, stop)
+        return text
+
+
+def get_llm(use_openai_model=False,
+            model_name=None,
+            model=None,
+            tokenizer=None,
+            inference_server=None,
+            langchain_only_model=None,
+            stream_output=False,
+            async_output=True,
+            num_async=3,
+            do_sample=False,
+            temperature=0.1,
+            top_k=40,
+            top_p=0.7,
+            num_beams=1,
+            max_new_tokens=512,
+            min_new_tokens=1,
+            early_stopping=False,
+            max_time=180,
+            repetition_penalty=1.0,
+            num_return_sequences=1,
+            prompt_type=None,
+            prompt_dict=None,
+            prompter=None,
+            context=None,
+            iinput=None,
+            sanitize_bot_response=False,
+            system_prompt='',
+            visible_models=0,
+            h2ogpt_key=None,
+            min_max_new_tokens=None,
+            n_jobs=None,
+            cli=False,
+            llamacpp_dict=None,
+            verbose=False,
+            ):
+    # currently all but the h2oai_pipeline case return prompt + new text, but that could change
+    only_new_text = False
+
+    if n_jobs in [None, -1]:
+        n_jobs = int(os.getenv('OMP_NUM_THREADS', str(os.cpu_count() // 2)))
+    if inference_server is None:
+        inference_server = ''
+    if inference_server.startswith('replicate'):
+        model_string = ':'.join(inference_server.split(':')[1:])
+        if 'meta/llama' in model_string:
+            temperature = max(0.01, temperature if do_sample else 0)
+        else:
+            temperature = temperature if do_sample else 0
+        gen_kwargs = dict(temperature=temperature,
+                          seed=1234,
+                          max_length=max_new_tokens,  # langchain
+                          max_new_tokens=max_new_tokens,  # replicate docs
+                          top_p=top_p if do_sample else 1,
+                          top_k=top_k,  # not always supported
+                          repetition_penalty=repetition_penalty)
+        if system_prompt in [None, 'None', 'auto']:
+            if prompter.system_prompt:
+                system_prompt = prompter.system_prompt
+            else:
+                system_prompt = ''
+        if system_prompt:
+            gen_kwargs.update(dict(system_prompt=system_prompt))
+
+        # replicate handles prompting, so avoid get_response() filter
+        prompter.prompt_type = 'plain'
+        if stream_output:
+            callbacks = [StreamingGradioCallbackHandler()]
+            streamer = callbacks[0] if stream_output else None
+            llm = H2OReplicate(
+                streaming=True,
+                callbacks=callbacks,
+                model=model_string,
+                input=gen_kwargs,
+                stop=prompter.stop_sequences,
+                stop_sequences=prompter.stop_sequences,
+                sanitize_bot_response=sanitize_bot_response,
+                prompter=prompter,
+                context=context,
+                iinput=iinput,
+                tokenizer=tokenizer,
+            )
+        else:
+            streamer = None
+            llm = H2OReplicate(
+                model=model_string,
+                input=gen_kwargs,
+                stop=prompter.stop_sequences,
+                stop_sequences=prompter.stop_sequences,
+                sanitize_bot_response=sanitize_bot_response,
+                prompter=prompter,
+                context=context,
+                iinput=iinput,
+                tokenizer=tokenizer,
+            )
+    elif use_openai_model or inference_server.startswith('openai') or 
inference_server.startswith('vllm'): + if use_openai_model and model_name is None: + model_name = "gpt-3.5-turbo" + # FIXME: Will later import be ignored? I think so, so should be fine + openai, inf_type, deployment_name, base_url, api_version = set_openai(inference_server) + kwargs_extra = {} + if inf_type == 'openai_chat' or inf_type == 'vllm_chat': + cls = H2OChatOpenAI + # FIXME: Support context, iinput + # if inf_type == 'vllm_chat': + # kwargs_extra.update(dict(tokenizer=tokenizer)) + openai_api_key = openai.api_key + elif inf_type == 'openai_azure_chat': + cls = H2OAzureChatOpenAI + kwargs_extra.update(dict(openai_api_type='azure')) + # FIXME: Support context, iinput + if os.getenv('OPENAI_AZURE_KEY') is not None: + openai_api_key = os.getenv('OPENAI_AZURE_KEY') + else: + openai_api_key = openai.api_key + elif inf_type == 'openai_azure': + cls = H2OAzureOpenAI + kwargs_extra.update(dict(openai_api_type='azure')) + # FIXME: Support context, iinput + if os.getenv('OPENAI_AZURE_KEY') is not None: + openai_api_key = os.getenv('OPENAI_AZURE_KEY') + else: + openai_api_key = openai.api_key + else: + cls = H2OOpenAI + if inf_type == 'vllm': + kwargs_extra.update(dict(stop_sequences=prompter.stop_sequences, + sanitize_bot_response=sanitize_bot_response, + prompter=prompter, + context=context, + iinput=iinput, + tokenizer=tokenizer, + openai_api_base=openai.api_base, + client=None)) + else: + assert inf_type == 'openai' or use_openai_model + openai_api_key = openai.api_key + + if deployment_name: + kwargs_extra.update(dict(deployment_name=deployment_name)) + if api_version: + kwargs_extra.update(dict(openai_api_version=api_version)) + elif openai.api_version: + kwargs_extra.update(dict(openai_api_version=openai.api_version)) + elif inf_type in ['openai_azure', 'openai_azure_chat']: + kwargs_extra.update(dict(openai_api_version="2023-05-15")) + if base_url: + kwargs_extra.update(dict(openai_api_base=base_url)) + else: + kwargs_extra.update(dict(openai_api_base=openai.api_base)) + + callbacks = [StreamingGradioCallbackHandler()] + llm = cls(model_name=model_name, + temperature=temperature if do_sample else 0, + # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py + max_tokens=max_new_tokens, + top_p=top_p if do_sample else 1, + frequency_penalty=0, + presence_penalty=1.07 - repetition_penalty + 0.6, # so good default + callbacks=callbacks if stream_output else None, + openai_api_key=openai_api_key, + logit_bias=None if inf_type == 'vllm' else {}, + max_retries=6, + streaming=stream_output, + **kwargs_extra + ) + streamer = callbacks[0] if stream_output else None + if inf_type in ['openai', 'openai_chat', 'openai_azure', 'openai_azure_chat']: + prompt_type = inference_server + else: + # vllm goes here + prompt_type = prompt_type or 'plain' + elif inference_server and inference_server.startswith('sagemaker'): + callbacks = [StreamingGradioCallbackHandler()] # FIXME + streamer = None + + endpoint_name = ':'.join(inference_server.split(':')[1:2]) + region_name = ':'.join(inference_server.split(':')[2:]) + + from sagemaker import H2OSagemakerEndpoint, ChatContentHandler, BaseContentHandler + if inference_server.startswith('sagemaker_chat'): + content_handler = ChatContentHandler() + else: + content_handler = BaseContentHandler() + model_kwargs = dict(temperature=temperature if do_sample else 1E-10, + return_full_text=False, top_p=top_p, max_new_tokens=max_new_tokens) + llm = H2OSagemakerEndpoint( + endpoint_name=endpoint_name, + region_name=region_name, + 
aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'), + aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'), + model_kwargs=model_kwargs, + content_handler=content_handler, + endpoint_kwargs={'CustomAttributes': 'accept_eula=true'}, + ) + elif inference_server: + assert inference_server.startswith( + 'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server + + from gradio_utils.grclient import GradioClient + from text_generation import Client as HFClient + if isinstance(model, GradioClient): + gr_client = model + hf_client = None + else: + gr_client = None + hf_client = model + assert isinstance(hf_client, HFClient) + + inference_server, headers = get_hf_server(inference_server) + + # quick sanity check to avoid long timeouts, just see if can reach server + requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10'))) + callbacks = [StreamingGradioCallbackHandler()] + + if gr_client: + async_output = False # FIXME: not implemented yet + chat_client = False + llm = GradioInference( + inference_server_url=inference_server, + return_full_text=False, + + temperature=temperature, + top_p=top_p, + top_k=top_k, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + early_stopping=early_stopping, + max_time=max_time, + repetition_penalty=repetition_penalty, + num_return_sequences=num_return_sequences, + do_sample=do_sample, + chat_client=chat_client, + + callbacks=callbacks if stream_output else None, + stream_output=stream_output, + prompter=prompter, + context=context, + iinput=iinput, + client=gr_client, + sanitize_bot_response=sanitize_bot_response, + tokenizer=tokenizer, + system_prompt=system_prompt, + visible_models=visible_models, + h2ogpt_key=h2ogpt_key, + min_max_new_tokens=min_max_new_tokens, + ) + elif hf_client: + # no need to pass original client, no state and fast, so can use same validate_environment from base class + async_sem = asyncio.Semaphore(num_async) if async_output else NullContext() + llm = H2OHuggingFaceTextGenInference( + inference_server_url=inference_server, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=False, # this only controls internal behavior, still returns processed text + seed=SEED, + + stop_sequences=prompter.stop_sequences, + temperature=temperature, + top_k=top_k, + top_p=top_p, + # typical_p=top_p, + callbacks=callbacks if stream_output else None, + stream_output=stream_output, + prompter=prompter, + context=context, + iinput=iinput, + tokenizer=tokenizer, + timeout=max_time, + sanitize_bot_response=sanitize_bot_response, + async_sem=async_sem, + ) + else: + raise RuntimeError("No defined client") + streamer = callbacks[0] if stream_output else None + elif model_name in non_hf_types: + async_output = False # FIXME: not implemented yet + assert langchain_only_model + if model_name == 'llama': + callbacks = [StreamingGradioCallbackHandler()] + streamer = callbacks[0] if stream_output else None + else: + # stream_output = False + # doesn't stream properly as generator, but at least + callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()] + streamer = None + if prompter: + prompt_type = prompter.prompt_type + else: + prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output) + pass # assume inputted prompt_type is correct + from gpt4all_llm import get_llm_gpt4all + max_max_tokens = tokenizer.model_max_length + llm = 
get_llm_gpt4all(model_name,
+                              model=model,
+                              max_new_tokens=max_new_tokens,
+                              temperature=temperature,
+                              repetition_penalty=repetition_penalty,
+                              top_k=top_k,
+                              top_p=top_p,
+                              callbacks=callbacks,
+                              n_jobs=n_jobs,
+                              verbose=verbose,
+                              streaming=stream_output,
+                              prompter=prompter,
+                              context=context,
+                              iinput=iinput,
+                              max_seq_len=max_max_tokens,
+                              llamacpp_dict=llamacpp_dict,
+                              )
+    elif hasattr(model, 'is_exlama') and model.is_exlama():
+        async_output = False  # FIXME: not implemented yet
+        assert langchain_only_model
+        callbacks = [StreamingGradioCallbackHandler()]
+        streamer = callbacks[0] if stream_output else None
+        max_max_tokens = tokenizer.model_max_length
+
+        from src.llm_exllama import Exllama
+        llm = Exllama(streaming=stream_output,
+                      model_path=None,
+                      model=model,
+                      lora_path=None,
+                      temperature=temperature,
+                      top_k=top_k,
+                      top_p=top_p,
+                      typical=0.7,
+                      beams=1,
+                      # beam_length = 40,
+                      stop_sequences=prompter.stop_sequences,
+                      callbacks=callbacks,
+                      verbose=verbose,
+                      max_seq_len=max_max_tokens,
+                      fused_attn=False,
+                      # alpha_value = 1.0,  # for use with any models
+                      # compress_pos_emb = 4.0,  # for use with superhot
+                      # set_auto_map = "3, 2",  # GPU split: this will split 3 GB / 2 GB
+                      prompter=prompter,
+                      context=context,
+                      iinput=iinput,
+                      )
+    else:
+        async_output = False  # FIXME: not implemented yet
+        if model is None:
+            # only used if didn't pass model in
+            assert tokenizer is None
+            prompt_type = 'human_bot'
+            if model_name is None:
+                model_name = 'h2oai/h2ogpt-oasst1-512-12b'
+                # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
+                # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
+            inference_server = ''
+            model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
+                                                 inference_server=inference_server, gpu_id=0)
+
+        max_max_tokens = tokenizer.model_max_length
+        only_new_text = True
+        gen_kwargs = dict(do_sample=do_sample,
+                          num_beams=num_beams,
+                          max_new_tokens=max_new_tokens,
+                          min_new_tokens=min_new_tokens,
+                          early_stopping=early_stopping,
+                          max_time=max_time,
+                          repetition_penalty=repetition_penalty,
+                          num_return_sequences=num_return_sequences,
+                          return_full_text=not only_new_text,
+                          handle_long_generation=None)
+        if do_sample:
+            gen_kwargs.update(dict(temperature=temperature,
+                                   top_k=top_k,
+                                   top_p=top_p))
+            assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
+        else:
+            assert len(set(gen_hyper0).difference(gen_kwargs.keys())) == 0
+
+        if stream_output:
+            skip_prompt = only_new_text
+            from gen import H2OTextIteratorStreamer
+            decoder_kwargs = {}
+            streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
+            gen_kwargs.update(dict(streamer=streamer))
+        else:
+            streamer = None
+
+        from h2oai_pipeline import H2OTextGenerationPipeline
+        pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
+                                         prompter=prompter,
+                                         context=context,
+                                         iinput=iinput,
+                                         prompt_type=prompt_type,
+                                         prompt_dict=prompt_dict,
+                                         sanitize_bot_response=sanitize_bot_response,
+                                         chat=False, stream_output=stream_output,
+                                         tokenizer=tokenizer,
+                                         # leave some room for 1 paragraph, even if min_new_tokens=0
+                                         max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
+                                         base_model=model_name,
+                                         **gen_kwargs)
+        # pipe.task = "text-generation"
+        # below makes it listen only to our prompt removal,
+        # not the built-in prompt removal that is less general and not specific to our model
+        pipe.task = "text2text-generation"
+
+        llm = H2OHuggingFacePipeline(pipeline=pipe)
+    return llm, model_name, streamer, prompt_type, async_output, only_new_text
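+
+
+# Usage sketch (illustrative, not called anywhere in this file): wire an LLM
+# from get_llm() into langchain's QA chain via the helpers in this module.
+# The model / tokenizer / prompter arguments are assumed to be prepared by the
+# caller, and the model_name value is just an example, not a project default.
+def _example_qa_with_get_llm(model, tokenizer, prompter, sources, question):
+    llm, _, _, _, _, _ = get_llm(model=model, tokenizer=tokenizer,
+                                 model_name='h2oai/h2ogpt-oasst1-512-12b',
+                                 prompter=prompter, stream_output=False,
+                                 sanitize_bot_response=True)
+    chain = load_qa_chain(llm, chain_type="stuff")
+    return get_answer_from_sources(chain, sources, question)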
+
+
+def get_device_dtype():
+    # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
+    import torch
+    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+    device = 'cpu' if n_gpus == 0 else 'cuda'
+    # from utils import NullContext
+    # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
+    context_class = torch.device
+    torch_dtype = torch.float16 if device == 'cuda' else torch.float32
+    return device, torch_dtype, context_class
+
+
+def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
+    """
+    Get wikipedia data from online
+    :param title:
+    :param first_paragraph_only:
+    :param text_limit:
+    :param take_head:
+    :return:
+    """
+    filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
+    url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
+    if first_paragraph_only:
+        url += "&exintro=1"
+    import json
+    if not os.path.isfile(filename):
+        data = requests.get(url).json()
+        json.dump(data, open(filename, 'wt'))
+    else:
+        data = json.load(open(filename, "rt"))
+    page_content = list(data["query"]["pages"].values())[0]["extract"]
+    if take_head is not None and text_limit is not None:
+        page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
+    title_url = str(title).replace(' ', '_')
+    return Document(
+        page_content=str(page_content),
+        metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
+    )
+
+
+def get_wiki_sources(first_para=True, text_limit=None):
+    """
+    Get specific named sources from wikipedia
+    :param first_para:
+    :param text_limit:
+    :return:
+    """
+    default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
+    wiki_sources = os.getenv('WIKI_SOURCES', default_wiki_sources)
+    if isinstance(wiki_sources, str):
+        # env var arrives as a string like "['Unix', 'Linux']", not a list,
+        # and list() on a string would split it into characters
+        wiki_sources = ast.literal_eval(wiki_sources)
+    return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
+
+
+def get_github_docs(repo_owner, repo_name):
+    """
+    Access github from specific repo
+    :param repo_owner:
+    :param repo_name:
+    :return:
+    """
+    with tempfile.TemporaryDirectory() as d:
+        subprocess.check_call(
+            f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
+            cwd=d,
+            shell=True,
+        )
+        git_sha = (
+            subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
+            .decode("utf-8")
+            .strip()
+        )
+        repo_path = pathlib.Path(d)
+        markdown_files = list(repo_path.glob("*/*.md")) + list(
+            repo_path.glob("*/*.mdx")
+        )
+        for markdown_file in markdown_files:
+            with open(markdown_file, "r") as f:
+                relative_path = markdown_file.relative_to(repo_path)
+                github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
+                yield Document(page_content=str(f.read()), metadata={"source": github_url})
+
+
+def get_dai_pickle(dest="."):
+    from huggingface_hub import hf_hub_download
+    # True for case when locally already logged in with correct token, so don't have to set key
+    token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
+    path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
+    shutil.copy(path_to_zip_file, dest)
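+
+
+# Sketch (illustrative only): build a small FAISS db from the wikipedia
+# helpers above. The embedding model name here is an arbitrary example, not a
+# project default.
+def _example_wiki_db():
+    sources = get_wiki_sources(first_para=True, text_limit=1000)
+    return get_db(sources, db_type='faiss', langchain_mode='wiki_test',
+                  hf_embedding_model='sentence-transformers/all-MiniLM-L6-v2')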
"working_dir_docs" + if not os.path.isfile(dai_store): + from create_data import setup_dai_docs + dst = setup_dai_docs(dst=dst, from_hf=from_hf) + + import glob + files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True)) + + basedir = os.path.abspath(os.getcwd()) + from create_data import rst_to_outputs + new_outputs = rst_to_outputs(files) + os.chdir(basedir) + + pickle.dump(new_outputs, open(dai_store, 'wb')) + else: + new_outputs = pickle.load(open(dai_store, 'rb')) + + sources = [] + for line, file in new_outputs: + # gradio requires any linked file to be with app.py + sym_src = os.path.abspath(os.path.join(dst, file)) + sym_dst = os.path.abspath(os.path.join(os.getcwd(), file)) + if os.path.lexists(sym_dst): + os.remove(sym_dst) + os.symlink(sym_src, sym_dst) + itm = Document(page_content=str(line), metadata={"source": file}) + # NOTE: yield has issues when going into db, loses metadata + # yield itm + sources.append(itm) + return sources + + +def get_supported_types(): + non_image_types0 = ["pdf", "txt", "csv", "toml", "py", "rst", "xml", "rtf", + "md", + "html", "mhtml", "htm", + "enex", "eml", "epub", "odt", "pptx", "ppt", + "zip", + "gz", + "gzip", + "urls", + ] + # "msg", GPL3 + + video_types0 = ['WEBM', + 'MPG', 'MP2', 'MPEG', 'MPE', '.PV', + 'OGG', + 'MP4', 'M4P', 'M4V', + 'AVI', 'WMV', + 'MOV', 'QT', + 'FLV', 'SWF', + 'AVCHD'] + video_types0 = [x.lower() for x in video_types0] + if have_pillow: + from PIL import Image + exts = Image.registered_extensions() + image_types0 = {ex for ex, f in exts.items() if f in Image.OPEN if ex not in video_types0 + non_image_types0} + image_types0 = sorted(image_types0) + image_types0 = [x[1:] if x.startswith('.') else x for x in image_types0] + else: + image_types0 = [] + return non_image_types0, image_types0, video_types0 + + +non_image_types, image_types, video_types = get_supported_types() +set_image_types = set(image_types) + +if have_libreoffice or True: + # or True so it tries to load, e.g. on MAC/Windows, even if don't have libreoffice since works without that + non_image_types.extend(["docx", "doc", "xls", "xlsx"]) +if have_jq: + non_image_types.extend(["json", "jsonl"]) + +file_types = non_image_types + image_types + + +def try_as_html(file): + # try treating as html as occurs when scraping websites + from bs4 import BeautifulSoup + with open(file, "rt") as f: + try: + is_html = bool(BeautifulSoup(f.read(), "html.parser").find()) + except: # FIXME + is_html = False + if is_html: + file_url = 'file://' + file + doc1 = UnstructuredURLLoader(urls=[file_url]).load() + doc1 = [x for x in doc1 if x.page_content] + else: + doc1 = [] + return doc1 + + +def json_metadata_func(record: dict, metadata: dict) -> dict: + # Define the metadata extraction function. 
+ + if isinstance(record, dict): + metadata["sender_name"] = record.get("sender_name") + metadata["timestamp_ms"] = record.get("timestamp_ms") + + if "source" in metadata: + metadata["source_json"] = metadata['source'] + if "seq_num" in metadata: + metadata["seq_num_json"] = metadata['seq_num'] + + return metadata + + +def file_to_doc(file, + filei=0, + base_path=None, verbose=False, fail_any_exception=False, + chunk=True, chunk_size=512, n_jobs=-1, + is_url=False, is_txt=False, + + # urls + use_unstructured=True, + use_playwright=False, + use_selenium=False, + + # pdfs + use_pymupdf='auto', + use_unstructured_pdf='auto', + use_pypdf='auto', + enable_pdf_ocr='auto', + try_pdf_as_html='auto', + enable_pdf_doctr='auto', + + # images + enable_ocr=False, + enable_doctr=False, + enable_pix2struct=False, + enable_captions=True, + captions_model=None, + model_loaders=None, + + # json + jq_schema='.[]', + + headsize=50, # see also H2OSerpAPIWrapper + db_type=None, + selected_file_types=None): + assert isinstance(model_loaders, dict) + if selected_file_types is not None: + set_image_types1 = set_image_types.intersection(set(selected_file_types)) + else: + set_image_types1 = set_image_types + + assert db_type is not None + chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type) + add_meta = functools.partial(_add_meta, headsize=headsize, filei=filei) + # FIXME: if zip, file index order will not be correct if other files involved + path_to_docs_func = functools.partial(path_to_docs, + verbose=verbose, + fail_any_exception=fail_any_exception, + n_jobs=n_jobs, + chunk=chunk, chunk_size=chunk_size, + # url=file if is_url else None, + # text=file if is_txt else None, + + # urls + use_unstructured=use_unstructured, + use_playwright=use_playwright, + use_selenium=use_selenium, + + # pdfs + use_pymupdf=use_pymupdf, + use_unstructured_pdf=use_unstructured_pdf, + use_pypdf=use_pypdf, + enable_pdf_ocr=enable_pdf_ocr, + enable_pdf_doctr=enable_pdf_doctr, + try_pdf_as_html=try_pdf_as_html, + + # images + enable_ocr=enable_ocr, + enable_doctr=enable_doctr, + enable_pix2struct=enable_pix2struct, + enable_captions=enable_captions, + captions_model=captions_model, + + caption_loader=model_loaders['caption'], + doctr_loader=model_loaders['doctr'], + pix2struct_loader=model_loaders['pix2struct'], + + # json + jq_schema=jq_schema, + + db_type=db_type, + ) + + if file is None: + if fail_any_exception: + raise RuntimeError("Unexpected None file") + else: + return [] + doc1 = [] # in case no support, or disabled support + if base_path is None and not is_txt and not is_url: + # then assume want to persist but don't care which path used + # can't be in base_path + dir_name = os.path.dirname(file) + base_name = os.path.basename(file) + # if from gradio, will have its own temp uuid too, but that's ok + base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10] + base_path = os.path.join(dir_name, base_name) + if is_url: + file = file.strip() # in case accidental spaces in front or at end + file_lower = file.lower() + case1 = file_lower.startswith('arxiv:') and len(file_lower.split('arxiv:')) == 2 + case2 = file_lower.startswith('https://arxiv.org/abs') and len(file_lower.split('https://arxiv.org/abs')) == 2 + case3 = file_lower.startswith('http://arxiv.org/abs') and len(file_lower.split('http://arxiv.org/abs')) == 2 + case4 = file_lower.startswith('arxiv.org/abs/') and len(file_lower.split('arxiv.org/abs/')) == 2 + if case1 or case2 or case3 or case4: + if case1: + 
query = file.lower().split('arxiv:')[1].strip()
+            elif case2:
+                query = file.lower().split('https://arxiv.org/abs/')[1].strip()
+            elif case3:
+                query = file.lower().split('http://arxiv.org/abs/')[1].strip()
+            elif case4:
+                query = file.lower().split('arxiv.org/abs/')[1].strip()
+            else:
+                raise RuntimeError("Unexpected arxiv error for %s" % file)
+            if have_arxiv:
+                trials = 3
+                docs1 = []
+                for trial in range(trials):
+                    try:
+                        docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
+                        break
+                    except urllib.error.URLError:
+                        pass
+                if not docs1:
+                    print("Failed to get arxiv %s" % query, flush=True)
+                # ensure string, sometimes None
+                [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
+                query_url = f"https://arxiv.org/abs/{query}"
+                [x.metadata.update(
+                    dict(source=x.metadata.get('entry_id', query_url), query=query_url,
+                         input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now()))) for x in
+                 docs1]
+            else:
+                docs1 = []
+        else:
+            if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
+                file = 'http://' + file
+            docs1 = []
+            do_unstructured = only_unstructured_urls or use_unstructured
+            if only_selenium or only_playwright:
+                do_unstructured = False
+            do_playwright = have_playwright and (use_playwright or only_playwright)
+            if only_unstructured_urls or only_selenium:
+                do_playwright = False
+            do_selenium = have_selenium and (use_selenium or only_selenium)
+            if only_unstructured_urls or only_playwright:
+                do_selenium = False
+            if do_unstructured:
+                docs1a = UnstructuredURLLoader(urls=[file]).load()
+                docs1a = [x for x in docs1a if x.page_content]
+                add_parser(docs1a, 'UnstructuredURLLoader')
+                docs1.extend(docs1a)
+            if (len(docs1) == 0 and have_playwright) or do_playwright:
+                # then something went wrong, try another loader:
+                from langchain.document_loaders import PlaywrightURLLoader
+                docs1a = asyncio.run(PlaywrightURLLoader(urls=[file]).aload())
+                # docs1 = PlaywrightURLLoader(urls=[file]).load()
+                docs1a = [x for x in docs1a if x.page_content]
+                add_parser(docs1a, 'PlaywrightURLLoader')
+                docs1.extend(docs1a)
+            if (len(docs1) == 0 and have_selenium) or do_selenium:
+                # then something went wrong, try another loader:
+                # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException:
+                # Message: unknown error: cannot find Chrome binary
+                from langchain.document_loaders import SeleniumURLLoader
+                from selenium.common.exceptions import WebDriverException
+                try:
+                    docs1a = SeleniumURLLoader(urls=[file]).load()
+                    docs1a = [x for x in docs1a if x.page_content]
+                    add_parser(docs1a, 'SeleniumURLLoader')
+                    docs1.extend(docs1a)
+                except WebDriverException as e:
+                    print("No web driver: %s" % str(e), flush=True)
+        [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
+        add_meta(docs1, file, parser="is_url")
+        docs1 = clean_doc(docs1)
+        doc1 = chunk_sources(docs1)
+    elif is_txt:
+        base_path = "user_paste"
+        base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
+        source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
+        with open(source_file, "wt") as f:
+            f.write(file)
+        metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
+        doc1 = Document(page_content=str(file), metadata=metadata)
+        add_meta(doc1, file, parser="f.write")
+        # Bit odd to change if it was original text
+        # doc1 = clean_doc(doc1)
+    elif file.lower().endswith('.html') or file.lower().endswith('.mhtml') or file.lower().endswith('.htm'):
+        docs1 = UnstructuredHTMLLoader(file_path=file).load()
+        add_meta(docs1, file, parser='UnstructuredHTMLLoader')
+        docs1 = clean_doc(docs1)
+        doc1 = chunk_sources(docs1, language=Language.HTML)
+    elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True):
+        docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
+        add_meta(docs1, file, parser='UnstructuredWordDocumentLoader')
+        doc1 = chunk_sources(docs1)
+    elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True):
+        docs1 = UnstructuredExcelLoader(file_path=file).load()
+        add_meta(docs1, file, parser='UnstructuredExcelLoader')
+        doc1 = chunk_sources(docs1)
+    elif file.lower().endswith('.odt'):
+        docs1 = UnstructuredODTLoader(file_path=file).load()
+        add_meta(docs1, file, parser='UnstructuredODTLoader')
+        doc1 = chunk_sources(docs1)
+    elif file.lower().endswith('.pptx') or file.lower().endswith('.ppt'):
+        docs1 = UnstructuredPowerPointLoader(file_path=file).load()
+        add_meta(docs1, file, parser='UnstructuredPowerPointLoader')
+        docs1 = clean_doc(docs1)
+        doc1 = chunk_sources(docs1)
+    elif file.lower().endswith('.txt'):
+        # use UnstructuredFileLoader ?
+        docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
+        # makes just one, but big one
+        doc1 = chunk_sources(docs1)
+        # Bit odd to change if it was original text
+        # doc1 = clean_doc(doc1)
+        add_meta(doc1, file, parser='TextLoader')
+    elif file.lower().endswith('.rtf'):
+        docs1 = UnstructuredRTFLoader(file).load()
+        add_meta(docs1, file, parser='UnstructuredRTFLoader')
+        doc1 = chunk_sources(docs1)
+    elif file.lower().endswith('.md'):
+        docs1 = UnstructuredMarkdownLoader(file).load()
+        add_meta(docs1, file, parser='UnstructuredMarkdownLoader')
+        docs1 = clean_doc(docs1)
+        doc1 = chunk_sources(docs1, language=Language.MARKDOWN)
+    elif file.lower().endswith('.enex'):
+        docs1 = EverNoteLoader(file).load()
+        add_meta(docs1, file, parser='EverNoteLoader')
+        doc1 = chunk_sources(docs1)
+    elif file.lower().endswith('.epub'):
+        docs1 = UnstructuredEPubLoader(file).load()
+        add_meta(docs1, file, parser='UnstructuredEPubLoader')
+        doc1 = chunk_sources(docs1)
+    elif any(file.lower().endswith(x) for x in set_image_types1):
+        docs1 = []
+        if verbose:
+            print("BEGIN: Tesseract", flush=True)
+        if have_tesseract and enable_ocr:
+            # OCR, somewhat works, but not great
+            docs1a = UnstructuredImageLoader(file, strategy='ocr_only').load()
+            # docs1a = UnstructuredImageLoader(file, strategy='hi_res').load()
+            docs1a = [x for x in docs1a if x.page_content]
+            add_meta(docs1a, file, parser='UnstructuredImageLoader')
+            docs1.extend(docs1a)
+        if verbose:
+            print("END: Tesseract", flush=True)
+        if have_doctr and enable_doctr:
+            if verbose:
+                print("BEGIN: DocTR", flush=True)
+            if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)):
+                if verbose:
+                    print("Reuse DocTR", flush=True)
+                model_loaders['doctr'].load_model()
+            else:
+                if verbose:
+                    print("Fresh DocTR", flush=True)
+                from image_doctr import H2OOCRLoader
+                model_loaders['doctr'] = H2OOCRLoader()
+            model_loaders['doctr'].set_document_paths([file])
+            docs1c = model_loaders['doctr'].load()
+            docs1c = [x for x in docs1c if x.page_content]
+            add_meta(docs1c, file, parser='H2OOCRLoader: %s' % 'DocTR')
+            # DocTR didn't set source, so fix-up meta
+            for doci in docs1c:
+                doci.metadata['source'] = doci.metadata.get('document_path', file)
+                doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+            docs1.extend(docs1c)
+            if verbose:
+                print("END: DocTR", flush=True)
+        if enable_captions:
+            # BLIP
+            if verbose:
+                print("BEGIN: BLIP", flush=True)
+            if model_loaders['caption'] is not None and not isinstance(model_loaders['caption'], (str, bool)):
+                # assumes didn't fork into this process with joblib, else can deadlock
+                if verbose:
+                    print("Reuse BLIP", flush=True)
+                model_loaders['caption'].load_model()
+            else:
+                if verbose:
+                    print("Fresh BLIP", flush=True)
+                from image_captions import H2OImageCaptionLoader
+                model_loaders['caption'] = H2OImageCaptionLoader(caption_gpu=model_loaders['caption'] == 'gpu',
+                                                                 blip_model=captions_model,
+                                                                 blip_processor=captions_model)
+            model_loaders['caption'].set_image_paths([file])
+            docs1c = model_loaders['caption'].load()
+            docs1c = [x for x in docs1c if x.page_content]
+            add_meta(docs1c, file, parser='H2OImageCaptionLoader: %s' % captions_model)
+            # caption didn't set source, so fix-up meta
+            for doci in docs1c:
+                doci.metadata['source'] = doci.metadata.get('image_path', file)
+                doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+            docs1.extend(docs1c)
+
+            if verbose:
+                print("END: BLIP", flush=True)
+        if enable_pix2struct:
+            # Pix2Struct
+            if verbose:
+                print("BEGIN: Pix2Struct", flush=True)
+            if model_loaders['pix2struct'] is not None and not isinstance(model_loaders['pix2struct'], (str, bool)):
+                if verbose:
+                    print("Reuse pix2struct", flush=True)
+                model_loaders['pix2struct'].load_model()
+            else:
+                if verbose:
+                    print("Fresh pix2struct", flush=True)
+                from image_pix2struct import H2OPix2StructLoader
+                model_loaders['pix2struct'] = H2OPix2StructLoader()
+            model_loaders['pix2struct'].set_image_paths([file])
+            docs1c = model_loaders['pix2struct'].load()
+            docs1c = [x for x in docs1c if x.page_content]
+            add_meta(docs1c, file, parser='H2OPix2StructLoader: %s' % model_loaders['pix2struct'])
+            # pix2struct didn't set source, so fix-up meta
+            for doci in docs1c:
+                doci.metadata['source'] = doci.metadata.get('image_path', file)
+                doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+            docs1.extend(docs1c)
+            if verbose:
+                print("END: Pix2Struct", flush=True)
+        doc1 = chunk_sources(docs1)
+    elif file.lower().endswith('.msg'):
+        raise RuntimeError("Not supported, GPL3 license")
+        # docs1 = OutlookMessageLoader(file).load()
+        # docs1[0].metadata['source'] = file
+    elif file.lower().endswith('.eml'):
+        try:
+            docs1 = UnstructuredEmailLoader(file).load()
+            add_meta(docs1, file, parser='UnstructuredEmailLoader')
+            doc1 = chunk_sources(docs1)
+        except ValueError as e:
+            if 'text/html content not found in email' in str(e):
+                pass
+            else:
+                raise
+        doc1 = [x for x in doc1 if x.page_content]
+        if len(doc1) == 0:
+            # e.g. a text/plain part exists, but no text/html part in the email
+            # doc1 = TextLoader(file, encoding="utf8").load()
+            docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
+            docs1 = [x for x in docs1 if x.page_content]
+            add_meta(docs1, file, parser='UnstructuredEmailLoader text/plain')
+            doc1 = chunk_sources(docs1)
+    # elif file.lower().endswith('.gcsdir'):
+    #    doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
+    # elif file.lower().endswith('.gcsfile'):
+    #    doc1 = GCSFileLoader(project_name, bucket, blob).load()
+    elif file.lower().endswith('.rst'):
+        with open(file, "r") as f:
+            doc1 = Document(page_content=str(f.read()), metadata={"source": file})
+        add_meta(doc1, file, parser='f.read()')
+        doc1 = chunk_sources(doc1, language=Language.RST)
+    elif file.lower().endswith('.json'):
+        # 10k rows, 100 columns-like parts, 4 bytes each
+        JSON_SIZE_LIMIT = int(os.getenv('JSON_SIZE_LIMIT', str(10 * 10 * 1024 * 10 * 4)))
+        if os.path.getsize(file) > JSON_SIZE_LIMIT:
+            raise ValueError(
+                "JSON file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % JSON_SIZE_LIMIT)
+        loader = JSONLoader(
+            file_path=file,
+            # jq_schema='.messages[].content',
+            jq_schema=jq_schema,
+            text_content=False,
+            metadata_func=json_metadata_func)
+        doc1 = loader.load()
+        add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema)
+        fix_json_meta(doc1)
+    elif file.lower().endswith('.jsonl'):
+        loader = JSONLoader(
+            file_path=file,
+            # jq_schema='.messages[].content',
+            jq_schema=jq_schema,
+            json_lines=True,
+            text_content=False,
+            metadata_func=json_metadata_func)
+        doc1 = loader.load()
+        add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema)
+        fix_json_meta(doc1)
+    elif file.lower().endswith('.pdf'):
+        # migration of legacy boolean flags to the 'off'/'on'/'auto' string scheme
+        if isinstance(use_pymupdf, bool):
+            use_pymupdf = 'on' if use_pymupdf else 'off'
+        if isinstance(use_unstructured_pdf, bool):
+            use_unstructured_pdf = 'on' if use_unstructured_pdf else 'off'
+        if isinstance(use_pypdf, bool):
+            use_pypdf = 'on' if use_pypdf else 'off'
+        if isinstance(enable_pdf_ocr, bool):
+            enable_pdf_ocr = 'on' if enable_pdf_ocr else 'off'
+        if isinstance(try_pdf_as_html, bool):
+            try_pdf_as_html = 'on' if try_pdf_as_html else 'off'
+
+        doc1 = []
+        tried_others = False
+        handled = False
+        did_pymupdf = False
+        did_unstructured = False
+        e = None
+        if have_pymupdf and ((len(doc1) == 0 and use_pymupdf == 'auto') or use_pymupdf == 'on'):
+            # GPL, only use if installed
+            from langchain.document_loaders import PyMuPDFLoader
+            # load() still chunks by pages, but every page has title at start to help
+            try:
+                doc1a = PyMuPDFLoader(file).load()
+                did_pymupdf = True
+            except BaseException as e0:
+                doc1a = []
+                print("PyMuPDFLoader: %s" % str(e0), flush=True)
+                e = e0
+            # remove empty documents
+            handled |= len(doc1a) > 0
+            doc1a = [x for x in doc1a if x.page_content]
+            doc1a = clean_doc(doc1a)
+            add_parser(doc1a, 'PyMuPDFLoader')
+            doc1.extend(doc1a)
+        if (len(doc1) == 0 and use_unstructured_pdf == 'auto') or use_unstructured_pdf == 'on':
+            tried_others = True
+            try:
+                doc1a = UnstructuredPDFLoader(file).load()
+                did_unstructured = True
+            except BaseException as e0:
+                doc1a = []
+                print("UnstructuredPDFLoader: %s" % str(e0), flush=True)
+                e = e0
+            handled |= len(doc1a) > 0
+            # remove empty documents
+            doc1a = [x for x in doc1a if x.page_content]
+            add_parser(doc1a, 'UnstructuredPDFLoader')
+            # seems to not need cleaning in most cases
+            doc1.extend(doc1a)
+        if (len(doc1) == 0 and use_pypdf == 'auto') or use_pypdf == 'on':
+            tried_others = True
+            # open-source fallback
+            # load() still chunks by pages, but every page has title at start to help
+            try:
+                doc1a = PyPDFLoader(file).load()
+            except BaseException as e0:
+                doc1a = []
+                print("PyPDFLoader: %s" % str(e0), flush=True)
+                e = e0
+            handled |= len(doc1a) > 0
+            # remove empty documents
+            doc1a = [x for x in doc1a if x.page_content]
+            doc1a = clean_doc(doc1a)
+            add_parser(doc1a, 'PyPDFLoader')
+            doc1.extend(doc1a)
+        if not did_pymupdf and ((have_pymupdf and len(doc1) == 0) and tried_others):
+            # try again in case only others used, but only if didn't already try (2nd part of and)
+            # GPL, only use if installed
+            from langchain.document_loaders import PyMuPDFLoader
+            # load() still chunks by pages, but every page has title at start to help
+            try:
+                doc1a = PyMuPDFLoader(file).load()
+            except BaseException as e0:
+                doc1a = []
+                print("PyMuPDFLoader: %s" % str(e0), flush=True)
+                e = e0
+            handled |= len(doc1a) > 0
+            # remove empty documents
+            doc1a = [x for x in doc1a if x.page_content]
+            doc1a = clean_doc(doc1a)
+            add_parser(doc1a, 'PyMuPDFLoader2')
+            doc1.extend(doc1a)
+        did_pdf_ocr = False
+        if (len(doc1) == 0 and enable_pdf_ocr == 'auto' and enable_pdf_doctr != 'on') or enable_pdf_ocr == 'on':
+            did_pdf_ocr = True
+            # no did_unstructured condition here because here we do OCR, and before we did not
+            # try OCR at the end since slowest, but works well on pure-image pages
+            doc1a = UnstructuredPDFLoader(file, strategy='ocr_only').load()
+            handled |= len(doc1a) > 0
+            # remove empty documents
+            doc1a = [x for x in doc1a if x.page_content]
+            add_parser(doc1a, 'UnstructuredPDFLoader ocr_only')
+            # seems to not need cleaning in most cases
+            doc1.extend(doc1a)
+        # Some PDFs return nothing or junk from the above loaders
+        if (len(doc1) == 0 and enable_pdf_doctr == 'auto') or enable_pdf_doctr == 'on':
+            if verbose:
+                print("BEGIN: DocTR", flush=True)
+            if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)):
+                model_loaders['doctr'].load_model()
+            else:
+                from image_doctr import H2OOCRLoader
+                model_loaders['doctr'] = H2OOCRLoader()
+            model_loaders['doctr'].set_document_paths([file])
+            doc1a = model_loaders['doctr'].load()
+            doc1a = [x for x in doc1a if x.page_content]
+            add_meta(doc1a, file, parser='H2OOCRLoader: %s' % 'DocTR')
+            handled |= len(doc1a) > 0
+            # DocTR didn't set source, so fix-up meta
+            for doci in doc1a:
+                doci.metadata['source'] = doci.metadata.get('document_path', file)
+                doci.metadata['hashid'] = hash_file(doci.metadata['source'])
+            doc1.extend(doc1a)
+            if verbose:
+                print("END: DocTR", flush=True)
+        if try_pdf_as_html in ['auto', 'on']:
+            doc1a = try_as_html(file)
+            add_parser(doc1a, 'try_as_html')
+            doc1.extend(doc1a)
+
+        if len(doc1) == 0:
+            # if literally nothing, raise so the user knows parsing failed, since it's unlikely a PDF truly has no content
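+            # Loader cascade attempted above, in order: PyMuPDF -> UnstructuredPDFLoader
+            # -> PyPDF -> PyMuPDF retry -> OCR (strategy='ocr_only') -> DocTR -> HTML
+            # fallback; handled=True means some loader returned documents (even if
+            # only metadata, no usable text), and e holds the last loader exception.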
+            if handled:
+                raise ValueError("%s had no valid text, but meta data was parsed" % file)
+            else:
+                raise ValueError("%s had no valid text and no meta data was parsed: %s" % (file, str(e)))
+        add_meta(doc1, file, parser='pdf')
+        doc1 = chunk_sources(doc1)
+    elif file.lower().endswith('.csv'):
+        CSV_SIZE_LIMIT = int(os.getenv('CSV_SIZE_LIMIT', str(10 * 1024 * 10 * 4)))
+        if os.path.getsize(file) > CSV_SIZE_LIMIT:
+            raise ValueError(
+                "CSV file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % CSV_SIZE_LIMIT)
+        doc1 = CSVLoader(file).load()
+        add_meta(doc1, file, parser='CSVLoader')
+        if isinstance(doc1, list):
+            # each row is a Document, identify
+            [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(doc1)]
+            if db_type in ['chroma', 'chroma_old']:
+                # then separate summarize list
+                sdoc1 = clone_documents(doc1)
+                [x.metadata.update(dict(chunk_id=-1)) for x in sdoc1]
+                doc1 = sdoc1 + doc1
+    elif file.lower().endswith('.py'):
+        doc1 = PythonLoader(file).load()
+        add_meta(doc1, file, parser='PythonLoader')
+        doc1 = chunk_sources(doc1, language=Language.PYTHON)
+    elif file.lower().endswith('.toml'):
+        doc1 = TomlLoader(file).load()
+        add_meta(doc1, file, parser='TomlLoader')
+        doc1 = chunk_sources(doc1)
+    elif file.lower().endswith('.xml'):
+        from langchain.document_loaders import UnstructuredXMLLoader
+        loader = UnstructuredXMLLoader(file_path=file)
+        doc1 = loader.load()
+        add_meta(doc1, file, parser='UnstructuredXMLLoader')
+    elif file.lower().endswith('.urls'):
+        with open(file, "r") as f:
+            urls = f.readlines()
+            # recurse
+            doc1 = path_to_docs_func(None, url=urls)
+    elif file.lower().endswith('.zip'):
+        with zipfile.ZipFile(file, 'r') as zip_ref:
+            # don't put into temporary path, since want to keep references to docs inside zip
+            # so just extract into base_path next to the original
+            zip_ref.extractall(base_path)
+            # recurse
+            doc1 = path_to_docs_func(base_path)
+    elif file.lower().endswith('.gz') or file.lower().endswith('.gzip'):
+        if file.lower().endswith('.gz'):
+            de_file = file[:-len('.gz')]
+        else:
+            de_file = file[:-len('.gzip')]
+        with gzip.open(file, 'rb') as f_in:
+            with open(de_file, 'wb') as f_out:
+                shutil.copyfileobj(f_in, f_out)
+        # recurse
+        doc1 = file_to_doc(de_file,
+                           filei=filei,  # single file, same file index as outside caller
+                           base_path=base_path, verbose=verbose, fail_any_exception=fail_any_exception,
+                           chunk=chunk, chunk_size=chunk_size, n_jobs=n_jobs,
+                           is_url=is_url, is_txt=is_txt,
+
+                           # urls
+                           use_unstructured=use_unstructured,
+                           use_playwright=use_playwright,
+                           use_selenium=use_selenium,
+
+                           # pdfs
+                           use_pymupdf=use_pymupdf,
+                           use_unstructured_pdf=use_unstructured_pdf,
+                           use_pypdf=use_pypdf,
+                           enable_pdf_ocr=enable_pdf_ocr,
+                           enable_pdf_doctr=enable_pdf_doctr,
+                           try_pdf_as_html=try_pdf_as_html,
+
+                           # images
+                           enable_ocr=enable_ocr,
+                           enable_doctr=enable_doctr,
+                           enable_pix2struct=enable_pix2struct,
+                           enable_captions=enable_captions,
+                           captions_model=captions_model,
+                           model_loaders=model_loaders,
+
+                           # json
+                           jq_schema=jq_schema,
+
+                           headsize=headsize,
+                           db_type=db_type,
+                           selected_file_types=selected_file_types)
+    else:
+        raise RuntimeError("No file handler for %s" % os.path.basename(file))
+
+    # allow doc1 to be list or not.
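+    # Chunking policy sketch: a bare Document is wrapped and chunked here; a
+    # single-element list is defensively re-chunked (chunk_id stays correct even
+    # if chunking repeats); longer lists are assumed chunked by their branch above.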
+ if not isinstance(doc1, list): + # If not list, did not chunk yet, so chunk now + docs = chunk_sources([doc1]) + elif isinstance(doc1, list) and len(doc1) == 1: + # if list of length one, don't trust and chunk it, chunk_id's will still be correct if repeat + docs = chunk_sources(doc1) + else: + docs = doc1 + + assert isinstance(docs, list) + return docs + + +def path_to_doc1(file, + filei=0, + verbose=False, fail_any_exception=False, return_file=True, + chunk=True, chunk_size=512, + n_jobs=-1, + is_url=False, is_txt=False, + + # urls + use_unstructured=True, + use_playwright=False, + use_selenium=False, + + # pdfs + use_pymupdf='auto', + use_unstructured_pdf='auto', + use_pypdf='auto', + enable_pdf_ocr='auto', + enable_pdf_doctr='auto', + try_pdf_as_html='auto', + + # images + enable_ocr=False, + enable_doctr=False, + enable_pix2struct=False, + enable_captions=True, + captions_model=None, + model_loaders=None, + + # json + jq_schema='.[]', + + db_type=None, + selected_file_types=None): + assert db_type is not None + if verbose: + if is_url: + print("Ingesting URL: %s" % file, flush=True) + elif is_txt: + print("Ingesting Text: %s" % file, flush=True) + else: + print("Ingesting file: %s" % file, flush=True) + res = None + try: + # don't pass base_path=path, would infinitely recurse + res = file_to_doc(file, + filei=filei, + base_path=None, verbose=verbose, fail_any_exception=fail_any_exception, + chunk=chunk, chunk_size=chunk_size, + n_jobs=n_jobs, + is_url=is_url, is_txt=is_txt, + + # urls + use_unstructured=use_unstructured, + use_playwright=use_playwright, + use_selenium=use_selenium, + + # pdfs + use_pymupdf=use_pymupdf, + use_unstructured_pdf=use_unstructured_pdf, + use_pypdf=use_pypdf, + enable_pdf_ocr=enable_pdf_ocr, + enable_pdf_doctr=enable_pdf_doctr, + try_pdf_as_html=try_pdf_as_html, + + # images + enable_ocr=enable_ocr, + enable_doctr=enable_doctr, + enable_pix2struct=enable_pix2struct, + enable_captions=enable_captions, + captions_model=captions_model, + model_loaders=model_loaders, + + # json + jq_schema=jq_schema, + + db_type=db_type, + selected_file_types=selected_file_types) + except BaseException as e: + print("Failed to ingest %s due to %s" % (file, traceback.format_exc())) + if fail_any_exception: + raise + else: + exception_doc = Document( + page_content='', + metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)), + "traceback": traceback.format_exc()}) + res = [exception_doc] + if verbose: + if is_url: + print("DONE Ingesting URL: %s" % file, flush=True) + elif is_txt: + print("DONE Ingesting Text: %s" % file, flush=True) + else: + print("DONE Ingesting file: %s" % file, flush=True) + if return_file: + base_tmp = "temp_path_to_doc1" + if not os.path.isdir(base_tmp): + base_tmp = makedirs(base_tmp, exist_ok=True, tmp_ok=True, use_base=True) + filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle") + with open(filename, 'wb') as f: + pickle.dump(res, f) + return filename + return res + + +def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1, + chunk=True, chunk_size=512, + url=None, text=None, + + # urls + use_unstructured=True, + use_playwright=False, + use_selenium=False, + + # pdfs + use_pymupdf='auto', + use_unstructured_pdf='auto', + use_pypdf='auto', + enable_pdf_ocr='auto', + enable_pdf_doctr='auto', + try_pdf_as_html='auto', + + # images + enable_ocr=False, + enable_doctr=False, + enable_pix2struct=False, + enable_captions=True, + captions_model=None, + + caption_loader=None, + doctr_loader=None, + 
pix2struct_loader=None, + + # json + jq_schema='.[]', + + existing_files=[], + existing_hash_ids={}, + db_type=None, + selected_file_types=None, + ): + if verbose: + print("BEGIN Consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True) + if selected_file_types is not None: + non_image_types1 = [x for x in non_image_types if x in selected_file_types] + image_types1 = [x for x in image_types if x in selected_file_types] + else: + non_image_types1 = non_image_types.copy() + image_types1 = image_types.copy() + + assert db_type is not None + # path_or_paths could be str, list, tuple, generator + globs_image_types = [] + globs_non_image_types = [] + if not path_or_paths and not url and not text: + return [] + elif url: + url = get_list_or_str(url) + globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url] + elif text: + globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text] + elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths): + # single path, only consume allowed files + path = path_or_paths + # Below globs should match patterns in file_to_doc() + [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True)) + for ftype in image_types1] + globs_image_types = [os.path.normpath(x) for x in globs_image_types] + [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True)) + for ftype in non_image_types1] + globs_non_image_types = [os.path.normpath(x) for x in globs_non_image_types] + else: + if isinstance(path_or_paths, str): + if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths): + path_or_paths = [path_or_paths] + else: + # path was deleted etc. + return [] + # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows) + assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \ + "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths)) + # reform out of allowed types + globs_image_types.extend( + flatten_list([[os.path.normpath(x) for x in path_or_paths if x.endswith(y)] for y in image_types1])) + # could do below: + # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types1]) + # But instead, allow fail so can collect unsupported too + set_globs_image_types = set(globs_image_types) + globs_non_image_types.extend([os.path.normpath(x) for x in path_or_paths if x not in set_globs_image_types]) + + # filter out any files to skip (e.g. if already processed them) + # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[] + assert not existing_files, "DEV: assume not using this approach" + if existing_files: + set_skip_files = set(existing_files) + globs_image_types = [x for x in globs_image_types if x not in set_skip_files] + globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files] + if existing_hash_ids: + # assume consistent with add_meta() use of hash_file(file) + # also assume consistent with get_existing_hash_ids for dict creation + # assume hashable values + existing_hash_ids_set = set(existing_hash_ids.items()) + hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items()) + hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items()) + # don't use symmetric diff. 
If file is gone, ignore and don't remove or something + # just consider existing files (key) having new hash or not (value) + new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys()) + new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys()) + globs_image_types = [x for x in globs_image_types if x in new_files_image] + globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image] + + # could use generator, but messes up metadata handling in recursive case + if caption_loader and not isinstance(caption_loader, (bool, str)) and caption_loader.device != 'cpu' or \ + get_device() == 'cuda': + # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context + # get_device() == 'cuda' because presume faster to process image from (temporarily) preloaded model + n_jobs_image = 1 + else: + n_jobs_image = n_jobs + if enable_doctr or enable_pdf_doctr in [True, 'auto', 'on']: + if doctr_loader and not isinstance(doctr_loader, (bool, str)) and doctr_loader.device != 'cpu': + # can't fork cuda context + n_jobs = 1 + + return_file = True # local choice + is_url = url is not None + is_txt = text is not None + model_loaders = dict(caption=caption_loader, + doctr=doctr_loader, + pix2struct=pix2struct_loader) + model_loaders0 = model_loaders.copy() + kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception, + return_file=return_file, + chunk=chunk, chunk_size=chunk_size, + n_jobs=n_jobs, + is_url=is_url, + is_txt=is_txt, + + # urls + use_unstructured=use_unstructured, + use_playwright=use_playwright, + use_selenium=use_selenium, + + # pdfs + use_pymupdf=use_pymupdf, + use_unstructured_pdf=use_unstructured_pdf, + use_pypdf=use_pypdf, + enable_pdf_ocr=enable_pdf_ocr, + enable_pdf_doctr=enable_pdf_doctr, + try_pdf_as_html=try_pdf_as_html, + + # images + enable_ocr=enable_ocr, + enable_doctr=enable_doctr, + enable_pix2struct=enable_pix2struct, + enable_captions=enable_captions, + captions_model=captions_model, + model_loaders=model_loaders, + + # json + jq_schema=jq_schema, + + db_type=db_type, + selected_file_types=selected_file_types, + ) + if n_jobs != 1 and len(globs_non_image_types) > 1: + # avoid nesting, e.g. upload 1 zip and then inside many files + # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib + documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')( + delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_non_image_types) + ) + else: + documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in + enumerate(tqdm(globs_non_image_types))] + + # do images separately since can't fork after cuda in parent, so can't be parallel + if n_jobs_image != 1 and len(globs_image_types) > 1: + # avoid nesting, e.g. 
upload 1 zip and then inside many files
+        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
+        image_documents = ProgressParallel(n_jobs=n_jobs_image, verbose=10 if verbose else 0, backend='multiprocessing')(
+            delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_image_types)
+        )
+    else:
+        image_documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in
+                           enumerate(tqdm(globs_image_types))]
+
+    # unload loaders (image loaders, includes enable_pdf_doctr that uses same loader)
+    for name, loader in model_loaders.items():
+        loader0 = model_loaders0[name]
+        real_model_initial = loader0 is not None and not isinstance(loader0, (str, bool))
+        real_model_final = model_loaders[name] is not None and not isinstance(model_loaders[name], (str, bool))
+        if not real_model_initial and real_model_final:
+            # clear off GPU newly added model
+            model_loaders[name].unload_model()
+
+    # add image docs in
+    documents += image_documents
+
+    if return_file:
+        # then documents really are files
+        files = documents.copy()
+        documents = []
+        for fil in files:
+            with open(fil, 'rb') as f:
+                documents.extend(pickle.load(f))
+            # remove temp pickle
+            remove(fil)
+    else:
+        documents = reduce(concat, documents)
+
+    if verbose:
+        print("END consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True)
+    return documents
+
+
+def prep_langchain(persist_directory,
+                   load_db_if_exists,
+                   db_type, use_openai_embedding,
+                   langchain_mode, langchain_mode_paths, langchain_mode_types,
+                   hf_embedding_model,
+                   migrate_embedding_model,
+                   auto_migrate_db,
+                   n_jobs=-1, kwargs_make_db={},
+                   verbose=False):
+    """
+    do prep first time, involving downloads
+    # FIXME: Add github caching then add here
+    :return:
+    """
+    if os.getenv("HARD_ASSERTS"):
+        assert langchain_mode not in ['MyData'], "Should not prep scratch/personal data"
+
+    if langchain_mode in langchain_modes_intrinsic:
+        return None
+
+    db_dir_exists = os.path.isdir(persist_directory)
+    user_path = langchain_mode_paths.get(langchain_mode)
+
+    if db_dir_exists and user_path is None:
+        if verbose:
+            print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
+        db, use_openai_embedding, hf_embedding_model = \
+            get_existing_db(None, persist_directory, load_db_if_exists,
+                            db_type, use_openai_embedding,
+                            langchain_mode, langchain_mode_paths, langchain_mode_types,
+                            hf_embedding_model, migrate_embedding_model, auto_migrate_db,
+                            n_jobs=n_jobs)
+    else:
+        if db_dir_exists and user_path is not None:
+            if verbose:
+                print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
+                    persist_directory, user_path), flush=True)
+        elif not db_dir_exists:
+            if verbose:
+                print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
+        db = None
+        if langchain_mode in ['DriverlessAI docs']:
+            # FIXME: Could also just use dai_docs.pickle directly and upload that
+            get_dai_docs(from_hf=True)
+
+        if langchain_mode in ['wiki']:
+            get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])
+
+        langchain_kwargs = kwargs_make_db.copy()
+        langchain_kwargs.update(locals())
+        db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)
+
+    return db
+
+
+import posthog
+
+posthog.disabled = True
+
+
+class FakeConsumer(object):
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def run(self):
+        pass
+
+    def pause(self):
+        pass
+
+    def upload(self):
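+        # no-op: telemetry upload is intentionally disabled (posthog stubbed out)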
pass + + def next(self): + pass + + def request(self, batch): + pass + + +posthog.Consumer = FakeConsumer + + +def check_update_chroma_embedding(db, + db_type, + use_openai_embedding, + hf_embedding_model, migrate_embedding_model, auto_migrate_db, + langchain_mode, langchain_mode_paths, langchain_mode_types, + n_jobs=-1): + changed_db = False + embed_tuple = load_embed(db=db) + if embed_tuple not in [(True, use_openai_embedding, hf_embedding_model), + (False, use_openai_embedding, hf_embedding_model)]: + print("Detected new embedding %s vs. %s %s, updating db: %s" % ( + use_openai_embedding, hf_embedding_model, embed_tuple, langchain_mode), flush=True) + # handle embedding changes + db_get = get_documents(db) + sources = [Document(page_content=result[0], metadata=result[1] or {}) + for result in zip(db_get['documents'], db_get['metadatas'])] + # delete index, has to be redone + persist_directory = db._persist_directory + shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak") + assert db_type in ['chroma', 'chroma_old'] + load_db_if_exists = False + db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type, + persist_directory=persist_directory, load_db_if_exists=load_db_if_exists, + langchain_mode=langchain_mode, + langchain_mode_paths=langchain_mode_paths, + langchain_mode_types=langchain_mode_types, + collection_name=None, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + n_jobs=n_jobs, + ) + changed_db = True + print("Done updating db for new embedding: %s" % langchain_mode, flush=True) + + return db, changed_db + + +def migrate_meta_func(db, langchain_mode): + changed_db = False + db_get = get_documents(db) + # just check one doc + if len(db_get['metadatas']) > 0 and 'chunk_id' not in db_get['metadatas'][0]: + print("Detected old metadata, adding additional information", flush=True) + t0 = time.time() + # handle meta changes + [x.update(dict(chunk_id=x.get('chunk_id', 0))) for x in db_get['metadatas']] + client_collection = db._client.get_collection(name=db._collection.name, + embedding_function=db._collection._embedding_function) + client_collection.update(ids=db_get['ids'], metadatas=db_get['metadatas']) + # check + db_get = get_documents(db) + assert 'chunk_id' in db_get['metadatas'][0], "Failed to add meta" + changed_db = True + print("Done updating db for new meta: %s in %s seconds" % (langchain_mode, time.time() - t0), flush=True) + + return db, changed_db + + +def get_existing_db(db, persist_directory, + load_db_if_exists, db_type, use_openai_embedding, + langchain_mode, langchain_mode_paths, langchain_mode_types, + hf_embedding_model, + migrate_embedding_model, + auto_migrate_db=False, + verbose=False, check_embedding=True, migrate_meta=True, + n_jobs=-1): + if load_db_if_exists and db_type in ['chroma', 'chroma_old'] and os.path.isdir(persist_directory): + if os.path.isfile(os.path.join(persist_directory, 'chroma.sqlite3')): + must_migrate = False + elif os.path.isdir(os.path.join(persist_directory, 'index')): + must_migrate = True + else: + return db, use_openai_embedding, hf_embedding_model + chroma_settings = dict(is_persistent=True) + use_chromamigdb = False + if must_migrate: + if auto_migrate_db: + print("Detected chromadb<0.4 database, require migration, doing now....", flush=True) + from chroma_migrate.import_duckdb import migrate_from_duckdb + import chromadb + api = chromadb.PersistentClient(path=persist_directory) + did_migration = 
migrate_from_duckdb(api, persist_directory) + assert did_migration, "Failed to migrate chroma collection at %s, see https://docs.trychroma.com/migration for CLI tool" % persist_directory + elif have_chromamigdb: + print( + "Detected chroma<0.4 database but --auto_migrate_db=False, but detected chromamigdb package, so using old database that still requires duckdb", + flush=True) + chroma_settings = dict(chroma_db_impl="duckdb+parquet") + use_chromamigdb = True + else: + raise ValueError( + "Detected chromadb<0.4 database, require migration, but did not detect chromamigdb package or did not choose auto_migrate_db=False (see FAQ.md)") + + if db is None: + if verbose: + print("DO Loading db: %s" % langchain_mode, flush=True) + got_embedding, use_openai_embedding0, hf_embedding_model0 = load_embed(persist_directory=persist_directory) + if got_embedding: + use_openai_embedding, hf_embedding_model = use_openai_embedding0, hf_embedding_model0 + embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model) + import logging + logging.getLogger("chromadb").setLevel(logging.ERROR) + if use_chromamigdb: + from chromamigdb.config import Settings + chroma_class = ChromaMig + else: + from chromadb.config import Settings + chroma_class = Chroma + client_settings = Settings(anonymized_telemetry=False, + **chroma_settings, + persist_directory=persist_directory) + db = chroma_class(persist_directory=persist_directory, embedding_function=embedding, + collection_name=langchain_mode.replace(' ', '_'), + client_settings=client_settings) + try: + db.similarity_search('') + except BaseException as e: + # migration when no embed_info + if 'Dimensionality of (768) does not match index dimensionality (384)' in str(e) or \ + 'Embedding dimension 768 does not match collection dimensionality 384' in str(e): + hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" + embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model) + db = chroma_class(persist_directory=persist_directory, embedding_function=embedding, + collection_name=langchain_mode.replace(' ', '_'), + client_settings=client_settings) + # should work now, let fail if not + db.similarity_search('') + save_embed(db, use_openai_embedding, hf_embedding_model) + else: + raise + + if verbose: + print("DONE Loading db: %s" % langchain_mode, flush=True) + else: + if not migrate_embedding_model: + # OVERRIDE embedding choices if could load embedding info when not migrating + got_embedding, use_openai_embedding, hf_embedding_model = load_embed(db=db) + if verbose: + print("USING already-loaded db: %s" % langchain_mode, flush=True) + if check_embedding: + db_trial, changed_db = check_update_chroma_embedding(db, + db_type, + use_openai_embedding, + hf_embedding_model, + migrate_embedding_model, + auto_migrate_db, + langchain_mode, + langchain_mode_paths, + langchain_mode_types, + n_jobs=n_jobs) + if changed_db: + db = db_trial + # only call persist if really changed db, else takes too long for large db + if db is not None: + db.persist() + clear_embedding(db) + save_embed(db, use_openai_embedding, hf_embedding_model) + if migrate_meta and db is not None: + db_trial, changed_db = migrate_meta_func(db, langchain_mode) + if changed_db: + db = db_trial + return db, use_openai_embedding, hf_embedding_model + return db, use_openai_embedding, hf_embedding_model + + +def clear_embedding(db): + if db is None: + return + # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed + try: + 
+        if hasattr(db._embedding_function, 'client') and hasattr(db._embedding_function.client, 'cpu'):
+            # only push back to CPU if each db/user has own embedding model, else if shared keep on GPU
+            if hasattr(db._embedding_function.client, 'preload') and not db._embedding_function.client.preload:
+                db._embedding_function.client.cpu()
+                clear_torch_cache()
+    except RuntimeError as e:
+        print("clear_embedding error: %s" % ''.join(traceback.format_tb(e.__traceback__)), flush=True)
+
+
+def make_db(**langchain_kwargs):
+    func_names = list(inspect.signature(_make_db).parameters)
+    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
+    defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
+    for k in missing_kwargs:
+        if k in defaults_db:
+            langchain_kwargs[k] = defaults_db[k]
+    # final check for missing
+    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
+    assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
+    # only keep actually used
+    langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
+    return _make_db(**langchain_kwargs)
+
+
+embed_lock_name = 'embed.lock'
+
+
+def get_embed_lock_file(db, persist_directory=None):
+    if hasattr(db, '_persist_directory') or persist_directory:
+        if persist_directory is None:
+            persist_directory = db._persist_directory
+        check_persist_directory(persist_directory)
+        base_path = os.path.join('locks', persist_directory)
+        base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
+        lock_file = os.path.join(base_path, embed_lock_name)
+        makedirs(os.path.dirname(lock_file))
+        return lock_file
+    return None
+
+
+def save_embed(db, use_openai_embedding, hf_embedding_model):
+    if hasattr(db, '_persist_directory'):
+        persist_directory = db._persist_directory
+        lock_file = get_embed_lock_file(db)
+        with filelock.FileLock(lock_file):
+            embed_info_file = os.path.join(persist_directory, 'embed_info')
+            with open(embed_info_file, 'wb') as f:
+                if isinstance(hf_embedding_model, str):
+                    hf_embedding_model_save = hf_embedding_model
+                elif hasattr(hf_embedding_model, 'model_name'):
+                    hf_embedding_model_save = hf_embedding_model.model_name
+                elif isinstance(hf_embedding_model, dict) and 'name' in hf_embedding_model:
+                    hf_embedding_model_save = hf_embedding_model['name']
+                elif isinstance(hf_embedding_model, dict):
+                    if os.getenv('HARD_ASSERTS'):
+                        # unexpected in testing or normally
+                        raise RuntimeError("HERE")
+                    hf_embedding_model_save = 'hkunlp/instructor-large'
+                pickle.dump((use_openai_embedding, hf_embedding_model_save), f)
+    return use_openai_embedding, hf_embedding_model
+
+
+def load_embed(db=None, persist_directory=None):
+    if hasattr(db, 'embeddings') and hasattr(db.embeddings, 'model_name'):
+        hf_embedding_model = db.embeddings.model_name if 'openai' not in db.embeddings.model_name.lower() else None
+        use_openai_embedding = hf_embedding_model is None
+        save_embed(db, use_openai_embedding, hf_embedding_model)
+        return True, use_openai_embedding, hf_embedding_model
+    if persist_directory is None:
+        persist_directory = db._persist_directory
+    embed_info_file = os.path.join(persist_directory, 'embed_info')
+    if os.path.isfile(embed_info_file):
+        lock_file = get_embed_lock_file(db, persist_directory=persist_directory)
+        with filelock.FileLock(lock_file):
+            with open(embed_info_file, 'rb') as f:
+                try:
+                    use_openai_embedding, hf_embedding_model = pickle.load(f)
+                    if not isinstance(hf_embedding_model,
str): + # work-around bug introduced here: https://github.com/h2oai/h2ogpt/commit/54c4414f1ce3b5b7c938def651c0f6af081c66de + hf_embedding_model = 'hkunlp/instructor-large' + # fix file + save_embed(db, use_openai_embedding, hf_embedding_model) + got_embedding = True + except EOFError: + use_openai_embedding, hf_embedding_model = False, 'hkunlp/instructor-large' + got_embedding = False + if os.getenv('HARD_ASSERTS'): + # unexpected in testing or normally + raise + else: + # migration, assume defaults + use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2" + got_embedding = False + assert isinstance(hf_embedding_model, str) + return got_embedding, use_openai_embedding, hf_embedding_model + + +def get_persist_directory(langchain_mode, langchain_type=None, db1s=None, dbs=None): + if langchain_mode in [LangChainMode.DISABLED.value, LangChainMode.LLM.value]: + # not None so join works but will fail to find db + return '', langchain_type + + userid = get_userid_direct(db1s) + username = get_username_direct(db1s) + + # sanity for bad code + assert userid != 'None' + assert username != 'None' + + dirid = username or userid + if langchain_type == LangChainTypes.SHARED.value and not dirid: + dirid = './' # just to avoid error + if langchain_type == LangChainTypes.PERSONAL.value and not dirid: + # e.g. from client when doing transient calls with MyData + if db1s is None: + # just trick to get filled locally + db1s = {LangChainMode.MY_DATA.value: [None, None, None]} + set_userid_direct(db1s, str(uuid.uuid4()), str(uuid.uuid4())) + userid = get_userid_direct(db1s) + username = get_username_direct(db1s) + dirid = username or userid + langchain_type = LangChainTypes.PERSONAL.value + + # deal with existing locations + user_base_dir = os.getenv('USERS_BASE_DIR', 'users') + persist_directory = os.path.join(user_base_dir, dirid, 'db_dir_%s' % langchain_mode) + if userid and \ + (os.path.isdir(persist_directory) or + db1s is not None and langchain_mode in db1s or + langchain_type == LangChainTypes.PERSONAL.value): + langchain_type = LangChainTypes.PERSONAL.value + persist_directory = makedirs(persist_directory, use_base=True) + check_persist_directory(persist_directory) + return persist_directory, langchain_type + + persist_directory = 'db_dir_%s' % langchain_mode + if (os.path.isdir(persist_directory) or + dbs is not None and langchain_mode in dbs or + langchain_type == LangChainTypes.SHARED.value): + # ensure consistent + langchain_type = LangChainTypes.SHARED.value + persist_directory = makedirs(persist_directory, use_base=True) + check_persist_directory(persist_directory) + return persist_directory, langchain_type + + # dummy return for prep_langchain() or full personal space + base_others = 'db_nonusers' + persist_directory = os.path.join(base_others, 'db_dir_%s' % str(uuid.uuid4())) + persist_directory = makedirs(persist_directory, use_base=True) + langchain_type = LangChainTypes.PERSONAL.value + + check_persist_directory(persist_directory) + return persist_directory, langchain_type + + +def check_persist_directory(persist_directory): + # deal with some cases when see intrinsic names being used as shared + for langchain_mode in langchain_modes_intrinsic: + if persist_directory == 'db_dir_%s' % langchain_mode: + raise RuntimeError("Illegal access to %s" % persist_directory) + + +def _make_db(use_openai_embedding=False, + hf_embedding_model=None, + migrate_embedding_model=False, + auto_migrate_db=False, + first_para=False, text_limit=None, + chunk=True, chunk_size=512, 
+ + # urls + use_unstructured=True, + use_playwright=False, + use_selenium=False, + + # pdfs + use_pymupdf='auto', + use_unstructured_pdf='auto', + use_pypdf='auto', + enable_pdf_ocr='auto', + enable_pdf_doctr='auto', + try_pdf_as_html='auto', + + # images + enable_ocr=False, + enable_doctr=False, + enable_pix2struct=False, + enable_captions=True, + captions_model=None, + caption_loader=None, + doctr_loader=None, + pix2struct_loader=None, + + # json + jq_schema='.[]', + + langchain_mode=None, + langchain_mode_paths=None, + langchain_mode_types=None, + db_type='faiss', + load_db_if_exists=True, + db=None, + n_jobs=-1, + verbose=False): + assert hf_embedding_model is not None + user_path = langchain_mode_paths.get(langchain_mode) + langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value) + persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type) + langchain_mode_types[langchain_mode] = langchain_type + # see if can get persistent chroma db + db_trial, use_openai_embedding, hf_embedding_model = \ + get_existing_db(db, persist_directory, load_db_if_exists, db_type, + use_openai_embedding, + langchain_mode, langchain_mode_paths, langchain_mode_types, + hf_embedding_model, migrate_embedding_model, auto_migrate_db, verbose=verbose, + n_jobs=n_jobs) + if db_trial is not None: + db = db_trial + + sources = [] + if not db: + chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type) + if langchain_mode in ['wiki_full']: + from read_wiki_full import get_all_documents + small_test = None + print("Generating new wiki", flush=True) + sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2) + print("Got new wiki", flush=True) + sources1 = chunk_sources(sources1, chunk=chunk) + print("Chunked new wiki", flush=True) + sources.extend(sources1) + elif langchain_mode in ['wiki']: + sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit) + sources1 = chunk_sources(sources1, chunk=chunk) + sources.extend(sources1) + elif langchain_mode in ['github h2oGPT']: + # sources = get_github_docs("dagster-io", "dagster") + sources1 = get_github_docs("h2oai", "h2ogpt") + # FIXME: always chunk for now + sources1 = chunk_sources(sources1) + sources.extend(sources1) + elif langchain_mode in ['DriverlessAI docs']: + sources1 = get_dai_docs(from_hf=True) + # FIXME: DAI docs are already chunked well, should only chunk more if over limit + sources1 = chunk_sources(sources1, chunk=False) + sources.extend(sources1) + if user_path: + # UserData or custom, which has to be from user's disk + if db is not None: + # NOTE: Ignore file names for now, only go by hash ids + # existing_files = get_existing_files(db) + existing_files = [] + existing_hash_ids = get_existing_hash_ids(db) + else: + # pretend no existing files so won't filter + existing_files = [] + existing_hash_ids = [] + # chunk internally for speed over multiple docs + # FIXME: If first had old Hash=None and switch embeddings, + # then re-embed, and then hit here and reload so have hash, and then re-embed. 
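+            # existing_hash_ids (built by get_existing_hash_ids below) maps normalized
+            # source path -> hash_file() digest, e.g. {'user_path/doc.pdf': '<digest>'}
+            # ('<digest>' illustrative), letting path_to_docs() skip unchanged files.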
+ sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size, + # urls + use_unstructured=use_unstructured, + use_playwright=use_playwright, + use_selenium=use_selenium, + + # pdfs + use_pymupdf=use_pymupdf, + use_unstructured_pdf=use_unstructured_pdf, + use_pypdf=use_pypdf, + enable_pdf_ocr=enable_pdf_ocr, + enable_pdf_doctr=enable_pdf_doctr, + try_pdf_as_html=try_pdf_as_html, + + # images + enable_ocr=enable_ocr, + enable_doctr=enable_doctr, + enable_pix2struct=enable_pix2struct, + enable_captions=enable_captions, + captions_model=captions_model, + caption_loader=caption_loader, + doctr_loader=doctr_loader, + pix2struct_loader=pix2struct_loader, + + # json + jq_schema=jq_schema, + + existing_files=existing_files, existing_hash_ids=existing_hash_ids, + db_type=db_type) + new_metadata_sources = set([x.metadata['source'] for x in sources1]) + if new_metadata_sources: + if os.getenv('NO_NEW_FILES') is not None: + raise RuntimeError("Expected no new files! %s" % new_metadata_sources) + print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode), + flush=True) + if verbose: + print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True) + sources.extend(sources1) + if len(sources) > 0 and os.getenv('NO_NEW_FILES') is not None: + raise RuntimeError("Expected no new files! %s" % langchain_mode) + if len(sources) == 0 and os.getenv('SHOULD_NEW_FILES') is not None: + raise RuntimeError("Expected new files! %s" % langchain_mode) + print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True) + + # see if got sources + if not sources: + if verbose: + if db is not None: + print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True) + else: + print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True) + return db, 0, [] + if verbose: + if db is not None: + print("Generating db", flush=True) + else: + print("Adding to db", flush=True) + if not db: + if sources: + db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type, + persist_directory=persist_directory, + langchain_mode=langchain_mode, + langchain_mode_paths=langchain_mode_paths, + langchain_mode_types=langchain_mode_types, + hf_embedding_model=hf_embedding_model, + migrate_embedding_model=migrate_embedding_model, + auto_migrate_db=auto_migrate_db, + n_jobs=n_jobs) + if verbose: + print("Generated db", flush=True) + elif langchain_mode not in langchain_modes_intrinsic: + print("Did not generate db for %s since no sources" % langchain_mode, flush=True) + new_sources_metadata = [x.metadata for x in sources] + elif user_path is not None: + print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True) + db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type, + use_openai_embedding=use_openai_embedding, + hf_embedding_model=hf_embedding_model) + print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True) + else: + new_sources_metadata = [x.metadata for x in sources] + + return db, len(new_sources_metadata), new_sources_metadata + + +def get_metadatas(db): + metadatas = [] + from langchain.vectorstores import FAISS + if isinstance(db, FAISS): + metadatas = [v.metadata for k, v in db.docstore._dict.items()] + elif isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db): + metadatas = 
get_documents(db)['metadatas'] + elif db is not None: + # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 + # seems no way to get all metadata, so need to avoid this approach for weaviate + metadatas = [x.metadata for x in db.similarity_search("", k=10000)] + return metadatas + + +def get_db_lock_file(db, lock_type='getdb'): + if hasattr(db, '_persist_directory'): + persist_directory = db._persist_directory + check_persist_directory(persist_directory) + base_path = os.path.join('locks', persist_directory) + base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True) + lock_file = os.path.join(base_path, "%s.lock" % lock_type) + makedirs(os.path.dirname(lock_file)) # ensure made + return lock_file + return None + + +def get_documents(db): + if hasattr(db, '_persist_directory'): + lock_file = get_db_lock_file(db) + with filelock.FileLock(lock_file): + # get segfaults and other errors when multiple threads access this + return _get_documents(db) + else: + return _get_documents(db) + + +def _get_documents(db): + from langchain.vectorstores import FAISS + if isinstance(db, FAISS): + documents = [v for k, v in db.docstore._dict.items()] + documents = dict(documents=documents) + elif isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db): + documents = db.get() + else: + # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 + # seems no way to get all metadata, so need to avoid this approach for weaviate + documents = [x for x in db.similarity_search("", k=10000)] + documents = dict(documents=documents) + return documents + + +def get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None): + if hasattr(db, '_persist_directory'): + lock_file = get_db_lock_file(db) + with filelock.FileLock(lock_file): + return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list) + else: + return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list) + + +def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None): + db_documents = [] + db_metadatas = [] + + if text_context_list: + db_documents += [x.page_content if hasattr(x, 'page_content') else x for x in text_context_list] + db_metadatas += [x.metadata if hasattr(x, 'metadata') else {} for x in text_context_list] + + from langchain.vectorstores import FAISS + if isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db): + db_get = db._collection.get(where=filter_kwargs.get('filter')) + db_metadatas += db_get['metadatas'] + db_documents += db_get['documents'] + elif isinstance(db, FAISS): + import itertools + db_metadatas += get_metadatas(db) + # FIXME: FAISS has no filter + if top_k_docs == -1: + db_documents += list(db.docstore._dict.values()) + else: + # slice dict first + db_documents += list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values()) + elif db is not None: + db_metadatas += get_metadatas(db) + db_documents += get_documents(db)['documents'] + + return db_documents, db_metadatas + + +def get_existing_files(db): + metadatas = get_metadatas(db) + metadata_sources = set([x['source'] for x in metadatas]) + return metadata_sources + + +def get_existing_hash_ids(db): + metadatas = get_metadatas(db) + # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks + metadata_hash_ids = {os.path.normpath(x['source']): x.get('hashid') for x in metadatas} + 
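+    # Illustrative shape: {'user_path/report.pdf': '<hash_file() digest>'}; the value
+    # may be None for chunks ingested before hashids were recorded, in which case
+    # the set-difference in path_to_docs() treats the file as new.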
return metadata_hash_ids + + +def run_qa_db(**kwargs): + func_names = list(inspect.signature(_run_qa_db).parameters) + # hard-coded defaults + kwargs['answer_with_sources'] = kwargs.get('answer_with_sources', True) + kwargs['show_rank'] = kwargs.get('show_rank', False) + kwargs['show_accordions'] = kwargs.get('show_accordions', True) + kwargs['show_link_in_sources'] = kwargs.get('show_link_in_sources', True) + kwargs['top_k_docs_max_show'] = kwargs.get('top_k_docs_max_show', 10) + kwargs['llamacpp_dict'] = {} # shouldn't be required unless from test using _run_qa_db + missing_kwargs = [x for x in func_names if x not in kwargs] + assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs + # only keep actual used + kwargs = {k: v for k, v in kwargs.items() if k in func_names} + try: + return _run_qa_db(**kwargs) + finally: + clear_torch_cache() + + +def _run_qa_db(query=None, + iinput=None, + context=None, + use_openai_model=False, use_openai_embedding=False, + first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512, + + # urls + use_unstructured=True, + use_playwright=False, + use_selenium=False, + + # pdfs + use_pymupdf='auto', + use_unstructured_pdf='auto', + use_pypdf='auto', + enable_pdf_ocr='auto', + enable_pdf_doctr='auto', + try_pdf_as_html='auto', + + # images + enable_ocr=False, + enable_doctr=False, + enable_pix2struct=False, + enable_captions=True, + captions_model=None, + caption_loader=None, + doctr_loader=None, + pix2struct_loader=None, + + # json + jq_schema='.[]', + + langchain_mode_paths={}, + langchain_mode_types={}, + detect_user_path_changes_every_query=False, + db_type=None, + model_name=None, model=None, tokenizer=None, inference_server=None, + langchain_only_model=False, + hf_embedding_model=None, + migrate_embedding_model=False, + auto_migrate_db=False, + stream_output=False, + async_output=True, + num_async=3, + prompter=None, + prompt_type=None, + prompt_dict=None, + answer_with_sources=True, + append_sources_to_answer=True, + cut_distance=1.64, + add_chat_history_to_context=True, + add_search_to_context=False, + keep_sources_in_context=False, + memory_restriction_level=0, + system_prompt='', + sanitize_bot_response=False, + show_rank=False, + show_accordions=True, + show_link_in_sources=True, + top_k_docs_max_show=10, + use_llm_if_no_docs=True, + load_db_if_exists=False, + db=None, + do_sample=False, + temperature=0.1, + top_k=40, + top_p=0.7, + num_beams=1, + max_new_tokens=512, + min_new_tokens=1, + early_stopping=False, + max_time=180, + repetition_penalty=1.0, + num_return_sequences=1, + langchain_mode=None, + langchain_action=None, + langchain_agents=None, + document_subset=DocumentSubset.Relevant.name, + document_choice=[DocumentChoice.ALL.value], + pre_prompt_query=None, + prompt_query=None, + pre_prompt_summary=None, + prompt_summary=None, + text_context_list=None, + chat_conversation=None, + visible_models=None, + h2ogpt_key=None, + docs_ordering_type='reverse_ucurve_sort', + min_max_new_tokens=256, + + n_jobs=-1, + llamacpp_dict=None, + verbose=False, + cli=False, + lora_weights='', + auto_reduce_chunks=True, + max_chunks=100, + total_tokens_for_docs=None, + headsize=50, + ): + """ + + :param query: + :param use_openai_model: + :param use_openai_embedding: + :param first_para: + :param text_limit: + :param top_k_docs: + :param chunk: + :param chunk_size: + :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from + :param db_type: 'faiss' for in-memory + 'chroma' (for chroma >= 0.4) 
+ 'chroma_old' (for chroma < 0.4)
+ 'weaviate' for persisted on disk
+ :param model_name: model name, used to switch behaviors
+ :param model: pre-initialized model, else will make new one
+ :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
+ :param answer_with_sources:
+ :return:
+ """
+ t_run = time.time()
+ if stream_output:
+ # threads and asyncio don't mix
+ async_output = False
+ if langchain_action in [LangChainAction.QUERY.value]:
+ # async output is only supported for summarization actions, not query
+ async_output = False
+
+ # in case None, e.g. lazy client, then set based upon actual model
+ pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary = \
+ get_langchain_prompts(pre_prompt_query, prompt_query,
+ pre_prompt_summary, prompt_summary,
+ model_name, inference_server,
+ llamacpp_dict.get('model_path_llama'))
+
+ assert db_type is not None
+ assert hf_embedding_model is not None
+ assert langchain_mode_paths is not None
+ assert langchain_mode_types is not None
+ if model is not None:
+ assert model_name is not None # require so can make decisions
+ assert query is not None
+ assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
+ if prompter is not None:
+ prompt_type = prompter.prompt_type
+ prompt_dict = prompter.prompt_dict
+ if model is not None:
+ assert prompt_type is not None
+ if prompt_type == PromptType.custom.name:
+ assert prompt_dict is not None # should at least be {} or ''
+ else:
+ prompt_dict = ''
+
+ if LangChainAgent.SEARCH.value in langchain_agents and 'llama' in model_name.lower():
+ system_prompt = """You are a zero-shot react agent.
+Consider the prompt of Question to be the original query from the user.
+Respond to the prompt of Thought with a thought that may lead to a reasonable new action choice.
+Respond to the prompt of Action with an action to take out of the tools given, giving exactly a single word for the tool name.
+Respond to the prompt of Action Input with an input to give the tool.
+Consider the prompt of Observation to be the response from the tool.
+Repeat this Thought, Action, Action Input, Observation, Thought sequence several times, with new and different thoughts and actions each time; do not repeat.
+Once satisfied that the thoughts and responses are sufficient to answer the question, respond to the prompt of Thought with: I now know the final answer
+Respond to the prompt of Final Answer with your final high-quality bullet-list answer to the original query.
+""" + prompter.system_prompt = system_prompt + + assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0 + # pass in context to LLM directly, since already has prompt_type structure + # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638 + llm, model_name, streamer, prompt_type_out, async_output, only_new_text = \ + get_llm(use_openai_model=use_openai_model, model_name=model_name, + model=model, + tokenizer=tokenizer, + inference_server=inference_server, + langchain_only_model=langchain_only_model, + stream_output=stream_output, + async_output=async_output, + num_async=num_async, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + early_stopping=early_stopping, + max_time=max_time, + repetition_penalty=repetition_penalty, + num_return_sequences=num_return_sequences, + prompt_type=prompt_type, + prompt_dict=prompt_dict, + prompter=prompter, + context=context, + iinput=iinput, + sanitize_bot_response=sanitize_bot_response, + system_prompt=system_prompt, + visible_models=visible_models, + h2ogpt_key=h2ogpt_key, + min_max_new_tokens=min_max_new_tokens, + n_jobs=n_jobs, + llamacpp_dict=llamacpp_dict, + cli=cli, + verbose=verbose, + ) + # in case change, override original prompter + if hasattr(llm, 'prompter'): + prompter = llm.prompter + if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'prompter'): + prompter = llm.pipeline.prompter + + if prompter is None: + if prompt_type is None: + prompt_type = prompt_type_out + # get prompter + chat = True # FIXME? + prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output, + system_prompt=system_prompt) + + use_docs_planned = False + scores = [] + chain = None + + # basic version of prompt without docs etc. 
+ data_point = dict(context=context, instruction=query, input=iinput)
+ prompt_basic = prompter.generate_prompt(data_point)
+
+ if isinstance(document_choice, str):
+ # support string as well
+ document_choice = [document_choice]
+
+ func_names = list(inspect.signature(get_chain).parameters)
+ sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
+ missing_kwargs = [x for x in func_names if x not in sim_kwargs]
+ assert not missing_kwargs, "Missing: %s" % missing_kwargs
+ docs, chain, scores, \
+ use_docs_planned, num_docs_before_cut, \
+ use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \
+ get_chain(**sim_kwargs)
+ if document_subset in non_query_commands:
+ formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
+ if not formatted_doc_chunks and not use_llm_if_no_docs:
+ yield dict(prompt=prompt_basic, response="No sources", sources='', num_prompt_tokens=0)
+ return
+ # if no sources, outside gpt_langchain, LLM will be used with '' input
+ scores = [1] * len(docs)
+ get_answer_args = tuple([query, docs, formatted_doc_chunks, scores, show_rank,
+ answer_with_sources,
+ append_sources_to_answer])
+ get_answer_kwargs = dict(show_accordions=show_accordions,
+ show_link_in_sources=show_link_in_sources,
+ top_k_docs_max_show=top_k_docs_max_show,
+ docs_ordering_type=docs_ordering_type,
+ num_docs_before_cut=num_docs_before_cut,
+ verbose=verbose)
+ ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
+ yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
+ return
+ if not use_llm_if_no_docs:
+ if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
+ LangChainAction.SUMMARIZE_ALL.value,
+ LangChainAction.SUMMARIZE_REFINE.value]:
+ ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
+ extra = ''
+ yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
+ return
+ if not docs and not llm_mode:
+ ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
+ extra = ''
+ yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
+ return
+
+ if chain is None and not langchain_only_model:
+ # here if no docs at all and not HF type
+ # can only return if HF type
+ return
+
+ # context handling similar to that used in evaluate()
+ import torch
+ device, torch_dtype, context_class = get_device_dtype()
+ conditional_type = hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'model') and hasattr(llm.pipeline.model,
+ 'conditional_type') and llm.pipeline.model.conditional_type
+ with torch.no_grad():
+ have_lora_weights = lora_weights not in [no_lora_str, '', None]
+ context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
+ if conditional_type:
+ # issues when casting to float16, can mess up t5 model, e.g. only when not streaming, or other odd behaviors
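+ # i.e. for encoder-decoder (t5-like) models, fall back to the no-op context manager
+ # instead of torch.autocast, trading speed for correct generation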
+ context_class_cast = NullContext
+ with context_class_cast(device):
+ if stream_output and streamer:
+ answer = None
+ import queue
+ bucket = queue.Queue()
+ thread = EThread(target=chain, streamer=streamer, bucket=bucket)
+ thread.start()
+ outputs = ""
+ try:
+ for new_text in streamer:
+ # print("new_text: %s" % new_text, flush=True)
+ if bucket.qsize() > 0 or thread.exc:
+ thread.join()
+ outputs += new_text
+ if prompter: # and False: # FIXME: pipeline can already use prompter
+ if conditional_type:
+ if prompter.botstr:
+ prompt = prompter.botstr
+ output_with_prompt = prompt + outputs
+ only_new_text = False
+ else:
+ prompt = None
+ output_with_prompt = outputs
+ only_new_text = True
+ else:
+ prompt = None # FIXME
+ output_with_prompt = outputs
+ # don't specify only_new_text here, use get_llm() value
+ output1 = prompter.get_response(output_with_prompt, prompt=prompt,
+ only_new_text=only_new_text,
+ sanitize_bot_response=sanitize_bot_response)
+ yield dict(prompt=prompt, response=output1, sources='', num_prompt_tokens=0)
+ else:
+ # no prompter, so report the basic prompt
+ yield dict(prompt=prompt_basic, response=outputs, sources='', num_prompt_tokens=0)
+ except BaseException:
+ # if any exception, raise that exception if was from thread, first
+ if thread.exc:
+ raise thread.exc
+ raise
+ finally:
+ # in case no exception and didn't join with thread yet, then join
+ if not thread.exc:
+ answer = thread.join()
+ if isinstance(answer, dict):
+ if 'output_text' in answer:
+ answer = answer['output_text']
+ elif 'output' in answer:
+ answer = answer['output']
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
+ if thread.exc:
+ raise thread.exc
+ else:
+ if async_output:
+ import asyncio
+ answer = asyncio.run(chain())
+ else:
+ answer = chain()
+ if isinstance(answer, dict):
+ if 'output_text' in answer:
+ answer = answer['output_text']
+ elif 'output' in answer:
+ answer = answer['output']
+
+ get_answer_args = tuple([query, docs, answer, scores, show_rank,
+ answer_with_sources,
+ append_sources_to_answer])
+ get_answer_kwargs = dict(show_accordions=show_accordions,
+ show_link_in_sources=show_link_in_sources,
+ top_k_docs_max_show=top_k_docs_max_show,
+ docs_ordering_type=docs_ordering_type,
+ num_docs_before_cut=num_docs_before_cut,
+ verbose=verbose,
+ t_run=t_run,
+ count_input_tokens=llm.count_input_tokens
+ if hasattr(llm, 'count_input_tokens') else None,
+ count_output_tokens=llm.count_output_tokens
+ if hasattr(llm, 'count_output_tokens') else None)
+
+ t_run = time.time() - t_run
+
+ # for final yield, get real prompt used
+ if hasattr(llm, 'prompter') and llm.prompter.prompt is not None:
+ prompt = llm.prompter.prompt
+ else:
+ prompt = prompt_basic
+ num_prompt_tokens = get_token_count(prompt, tokenizer)
+
+ if not use_docs_planned:
+ ret = answer
+ extra = ''
+ yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
+ elif answer is not None:
+ ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
+ yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
+ return
+
+
+def get_docs_with_score(query, k_db, filter_kwargs, db, db_type, text_context_list=None, verbose=False):
+ docs_with_score = []
+ got_db_docs = False
+
+ if text_context_list:
+ docs_with_score += [(x, x.metadata.get('score', 1.0)) for x in text_context_list]
+
+ # deal with bug in chroma where if there are (say) 234 doc chunks and one asks for 233+, it fails due to reduction misbehavior
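+ # the retry loop below shrinks k_db on repeated failures, e.g. (illustrative):
+ # 1000 -> 800 -> 600 -> 400 -> 350 -> ... -> 100 -> 95 -> ... -> 15 -> 10 -> 9 -> ... -> 1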
+ if hasattr(db, '_embedding_function') and isinstance(db._embedding_function, FakeEmbeddings): + top_k_docs = -1 + # don't add text_context_list twice + db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, + text_context_list=None) + # sort by order given to parser (file_id) and any chunk_id if chunked + doc_file_ids = [x.get('file_id', 0) for x in db_metadatas] + doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas] + docs_with_score_fake = [(Document(page_content=result[0], metadata=result[1] or {}), 1.0) + for result in zip(db_documents, db_metadatas)] + docs_with_score_fake = [x for fx, cx, x in + sorted(zip(doc_file_ids, doc_chunk_ids, docs_with_score_fake), + key=lambda x: (x[0], x[1])) + ] + got_db_docs |= len(docs_with_score_fake) > 0 + docs_with_score += docs_with_score_fake + elif db is not None and db_type in ['chroma', 'chroma_old']: + while True: + try: + docs_with_score_chroma = db.similarity_search_with_score(query, k=k_db, **filter_kwargs) + break + except (RuntimeError, AttributeError) as e: + # AttributeError is for people with wrong version of langchain + if verbose: + print("chroma bug: %s" % str(e), flush=True) + if k_db == 1: + raise + if k_db > 500: + k_db -= 200 + elif k_db > 100: + k_db -= 50 + elif k_db > 10: + k_db -= 5 + else: + k_db -= 1 + k_db = max(1, k_db) + got_db_docs |= len(docs_with_score_chroma) > 0 + docs_with_score += docs_with_score_chroma + elif db is not None: + docs_with_score_other = db.similarity_search_with_score(query, k=k_db, **filter_kwargs) + got_db_docs |= len(docs_with_score_other) > 0 + docs_with_score += docs_with_score_other + + # set in metadata original order of docs + [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)] + + return docs_with_score, got_db_docs + + +def get_chain(query=None, + iinput=None, + context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638 + use_openai_model=False, use_openai_embedding=False, + first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512, + + # urls + use_unstructured=True, + use_playwright=False, + use_selenium=False, + + # pdfs + use_pymupdf='auto', + use_unstructured_pdf='auto', + use_pypdf='auto', + enable_pdf_ocr='auto', + enable_pdf_doctr='auto', + try_pdf_as_html='auto', + + # images + enable_ocr=False, + enable_doctr=False, + enable_pix2struct=False, + enable_captions=True, + captions_model=None, + caption_loader=None, + doctr_loader=None, + pix2struct_loader=None, + + # json + jq_schema='.[]', + + langchain_mode_paths=None, + langchain_mode_types=None, + detect_user_path_changes_every_query=False, + db_type='faiss', + model_name=None, + inference_server='', + max_new_tokens=None, + langchain_only_model=False, + hf_embedding_model=None, + migrate_embedding_model=False, + auto_migrate_db=False, + prompter=None, + prompt_type=None, + prompt_dict=None, + system_prompt=None, + cut_distance=1.1, + add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638 + add_search_to_context=False, + keep_sources_in_context=False, + memory_restriction_level=0, + top_k_docs_max_show=10, + + load_db_if_exists=False, + db=None, + langchain_mode=None, + langchain_action=None, + langchain_agents=None, + document_subset=DocumentSubset.Relevant.name, + document_choice=[DocumentChoice.ALL.value], + pre_prompt_query=None, + prompt_query=None, + pre_prompt_summary=None, + prompt_summary=None, + text_context_list=None, + chat_conversation=None, + + n_jobs=-1, + # beyond 
run_db_query: + llm=None, + tokenizer=None, + verbose=False, + docs_ordering_type='reverse_ucurve_sort', + min_max_new_tokens=256, + stream_output=True, + async_output=True, + + # local + auto_reduce_chunks=True, + max_chunks=100, + total_tokens_for_docs=None, + use_llm_if_no_docs=None, + headsize=50, + ): + if inference_server is None: + inference_server = '' + assert hf_embedding_model is not None + assert langchain_agents is not None # should be at least [] + if text_context_list is None: + text_context_list = [] + + # default value: + llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0 + query_action = langchain_action == LangChainAction.QUERY.value + summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value, + LangChainAction.SUMMARIZE_ALL.value, + LangChainAction.SUMMARIZE_REFINE.value] + + if len(text_context_list) > 0: + # turn into documents to make easy to manage and add meta + # try to account for summarization vs. query + chunk_id = 0 if query_action else -1 + text_context_list = [ + Document(page_content=x, metadata=dict(source='text_context_list', score=1.0, chunk_id=chunk_id)) for x + in text_context_list] + + if add_search_to_context: + params = { + "engine": "duckduckgo", + "gl": "us", + "hl": "en", + } + search = H2OSerpAPIWrapper(params=params) + # if doing search, allow more docs + docs_search, top_k_docs = search.get_search_documents(query, + query_action=query_action, + chunk=chunk, chunk_size=chunk_size, + db_type=db_type, + headsize=headsize, + top_k_docs=top_k_docs) + text_context_list = docs_search + text_context_list + add_search_to_context &= len(docs_search) > 0 + top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search)) + + if len(text_context_list) > 0: + llm_mode = False + use_llm_if_no_docs = True + + from src.output_parser import H2OMRKLOutputParser + from langchain.agents import AgentType, load_tools, initialize_agent, create_vectorstore_agent, \ + create_pandas_dataframe_agent, create_json_agent, create_csv_agent + from langchain.agents.agent_toolkits import VectorStoreInfo, VectorStoreToolkit, create_python_agent, JsonToolkit + if LangChainAgent.SEARCH.value in langchain_agents: + output_parser = H2OMRKLOutputParser() + tools = load_tools(["serpapi"], llm=llm, serpapi_api_key=os.environ.get('SERPAPI_API_KEY')) + if inference_server.startswith('openai'): + agent_type = AgentType.OPENAI_FUNCTIONS + agent_executor_kwargs = {"handle_parsing_errors": True, 'output_parser': output_parser} + else: + agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION + agent_executor_kwargs = {'output_parser': output_parser} + chain = initialize_agent(tools, llm, agent=agent_type, + agent_executor_kwargs=agent_executor_kwargs, + agent_kwargs=dict(output_parser=output_parser, + format_instructions=output_parser.get_format_instructions()), + output_parser=output_parser, + max_iterations=10, + verbose=True) + chain_kwargs = dict(input=query) + target = wrapped_partial(chain, chain_kwargs) + + docs = [] + scores = [] + use_docs_planned = False + num_docs_before_cut = 0 + use_llm_if_no_docs = True + return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show + + if LangChainAgent.COLLECTION.value in langchain_agents: + output_parser = H2OMRKLOutputParser() + vectorstore_info = VectorStoreInfo( + name=langchain_mode, + description="DataBase of text from PDFs, Image Captions, or web URL content", + vectorstore=db, + ) + toolkit = 
VectorStoreToolkit(vectorstore_info=vectorstore_info)
+ chain = create_vectorstore_agent(llm=llm, toolkit=toolkit,
+ agent_executor_kwargs=dict(output_parser=output_parser),
+ verbose=True)
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
+ chain = create_python_agent(
+ llm=llm,
+ tool=PythonREPLTool(),
+ verbose=True,
+ agent_type=AgentType.OPENAI_FUNCTIONS,
+ agent_executor_kwargs={"handle_parsing_errors": True},
+ )
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
+ # FIXME: DATA
+ df = pd.DataFrame(None)
+ chain = create_pandas_dataframe_agent(
+ llm,
+ df,
+ verbose=True,
+ agent_type=AgentType.OPENAI_FUNCTIONS,
+ )
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if isinstance(document_choice, str):
+ document_choice = [document_choice]
+ if document_choice and document_choice[0] == DocumentChoice.ALL.value:
+ document_choice_agent = document_choice[1:]
+ else:
+ document_choice_agent = document_choice
+ document_choice_agent = [x for x in document_choice_agent if x.endswith('.json')]
+ if LangChainAgent.JSON.value in \
+ langchain_agents and \
+ inference_server.startswith('openai_chat') and \
+ len(document_choice_agent) == 1 and \
+ document_choice_agent[0].endswith('.json'):
+ # with open('src/openai.yaml') as f:
+ # data = yaml.load(f, Loader=yaml.FullLoader)
+ with open(document_choice_agent[0], 'rt') as f:
+ data = json.loads(f.read())
+ json_spec = JsonSpec(dict_=data, max_value_length=4000)
+ json_toolkit = JsonToolkit(spec=json_spec)
+
+ chain = create_json_agent(
+ llm=llm, toolkit=json_toolkit, verbose=True
+ )
+
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if isinstance(document_choice, str):
+ document_choice = [document_choice]
+ if document_choice and document_choice[0] == DocumentChoice.ALL.value:
+ document_choice_agent = document_choice[1:]
+ else:
+ document_choice_agent = document_choice
+ document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
+ if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[0].endswith('.csv'):
+ data_file = document_choice_agent[0]
+ if inference_server.startswith('openai_chat'):
+ chain = create_csv_agent(
+ llm,
+ data_file,
+ verbose=True,
+ agent_type=AgentType.OPENAI_FUNCTIONS,
+ )
+ else:
+ chain = create_csv_agent(
+ llm,
+ data_file,
+ verbose=True,
+ agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+ )
+ chain_kwargs = dict(input=query)
+ target = wrapped_partial(chain, chain_kwargs)
+
+ docs = []
+ scores = []
+ use_docs_planned = False
+ num_docs_before_cut = 0
+ use_llm_if_no_docs = True
+ return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ # determine whether use of context out of docs is planned
+ if (not use_openai_model and prompt_type not in ['plain']) or langchain_only_model:
+ if llm_mode:
+ use_docs_planned = False
+ else:
+ use_docs_planned = True
+ else:
+ use_docs_planned = True
+
+ # https://github.com/hwchase17/langchain/issues/1946
+ # FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
+ # Chroma collection MyData contains fewer than 4 elements.
+ # type logger error
+ if top_k_docs == -1:
+ k_db = 1000 if db_type in ['chroma', 'chroma_old'] else 100
+ else:
+ # top_k_docs=100 works ok too
+ k_db = 1000 if db_type in ['chroma', 'chroma_old'] else top_k_docs
+
+ # FIXME: For All just go over all dbs instead of a separate db for All
+ if not detect_user_path_changes_every_query and db is not None:
+ # avoid looking at user_path during similarity search db handling,
+ # if already have db and not updating from user_path every query
+ # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
+ if langchain_mode_paths is None:
+ langchain_mode_paths = {}
+ langchain_mode_paths = langchain_mode_paths.copy()
+ langchain_mode_paths[langchain_mode] = None
+ # once use_openai_embedding, hf_embedding_model passed in, possibly changed,
+ # but that's ok as not used below or in calling functions
+ db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
+ hf_embedding_model=hf_embedding_model,
+ migrate_embedding_model=migrate_embedding_model,
+ auto_migrate_db=auto_migrate_db,
+ first_para=first_para, text_limit=text_limit,
+ chunk=chunk, chunk_size=chunk_size,
+
+ # urls
+ use_unstructured=use_unstructured,
+ use_playwright=use_playwright,
+ use_selenium=use_selenium,
+
+ # pdfs
+ use_pymupdf=use_pymupdf,
+ use_unstructured_pdf=use_unstructured_pdf,
+ use_pypdf=use_pypdf,
+ enable_pdf_ocr=enable_pdf_ocr,
+ enable_pdf_doctr=enable_pdf_doctr,
+ try_pdf_as_html=try_pdf_as_html,
+
+ # images
+ enable_ocr=enable_ocr,
+ enable_doctr=enable_doctr,
+ enable_pix2struct=enable_pix2struct,
+ enable_captions=enable_captions,
+ captions_model=captions_model,
+ caption_loader=caption_loader,
+ doctr_loader=doctr_loader,
+ pix2struct_loader=pix2struct_loader,
+
+ # json
+ jq_schema=jq_schema,
+
+ langchain_mode=langchain_mode,
+ langchain_mode_paths=langchain_mode_paths,
+ langchain_mode_types=langchain_mode_types,
+ db_type=db_type,
+ load_db_if_exists=load_db_if_exists,
+ db=db,
+ n_jobs=n_jobs,
+ verbose=verbose)
+ num_docs_before_cut = 0
+ use_template = (not use_openai_model and prompt_type not in ['plain']) or langchain_only_model
+ got_db_docs = False # not yet at least
+ template, template_if_no_docs, auto_reduce_chunks, query = \
+ get_template(query, iinput,
+ pre_prompt_query, prompt_query,
+ pre_prompt_summary, prompt_summary,
+ langchain_action,
+ llm_mode,
+ use_docs_planned,
+ auto_reduce_chunks,
+ got_db_docs,
+ add_search_to_context)
+
+ max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
+ model_name=model_name, max_new_tokens=max_new_tokens)
+
+ if (db or 
text_context_list) and use_docs_planned: + if hasattr(db, '_persist_directory'): + lock_file = get_db_lock_file(db, lock_type='sim') + else: + base_path = 'locks' + base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True) + name_path = "sim.lock" + lock_file = os.path.join(base_path, name_path) + + if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)): + # only chroma supports filtering + filter_kwargs = {} + filter_kwargs_backup = {} + else: + import logging + logging.getLogger("chromadb").setLevel(logging.ERROR) + assert document_choice is not None, "Document choice was None" + if isinstance(db, Chroma): + filter_kwargs_backup = {} # shouldn't ever need backup + # chroma >= 0.4 + if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[ + 0] == DocumentChoice.ALL.value: + filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \ + {"filter": {"chunk_id": {"$eq": -1}}} + else: + if document_choice[0] == DocumentChoice.ALL.value: + document_choice = document_choice[1:] + if len(document_choice) == 0: + filter_kwargs = {} + elif len(document_choice) > 1: + or_filter = [ + {"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else { + "$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]} + for x in document_choice] + filter_kwargs = dict(filter={"$or": or_filter}) + else: + # still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else { + "source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice][0] + + filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), + dict(chunk_id=one_filter['chunk_id'])]}) + else: + # migration for chroma < 0.4 + if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[ + 0] == DocumentChoice.ALL.value: + filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \ + {"filter": {"chunk_id": {"$eq": -1}}} + filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}} + elif len(document_choice) >= 2: + if document_choice[0] == DocumentChoice.ALL.value: + document_choice = document_choice[1:] + or_filter = [ + {"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice] + filter_kwargs = dict(filter={"$or": or_filter}) + or_filter_backup = [ + {"source": {"$eq": x}} if query_action else {"source": {"$eq": x}} + for x in document_choice] + filter_kwargs_backup = dict(filter={"$or": or_filter_backup}) + elif len(document_choice) == 1: + # degenerate UX bug in chroma + one_filter = \ + [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, + "chunk_id": { + "$eq": -1}} + for x in document_choice][0] + filter_kwargs = dict(filter=one_filter) + one_filter_backup = \ + [{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}} + for x in document_choice][0] + filter_kwargs_backup = dict(filter=one_filter_backup) + else: + # shouldn't reach + filter_kwargs = {} + filter_kwargs_backup = {} + + if llm_mode: + docs = [] + scores = [] + elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']: + db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, + text_context_list=text_context_list) + if len(db_documents) == 0 and filter_kwargs_backup: + db_documents, 
db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup, + text_context_list=text_context_list) + + if top_k_docs == -1: + top_k_docs = len(db_documents) + # similar to langchain's chroma's _results_to_docs_and_scores + docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0) + for result in zip(db_documents, db_metadatas)] + # set in metadata original order of docs + [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)] + + # order documents + doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas] + if query_action: + doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas] + docs_with_score2 = [x for hx, cx, x in + sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1])) + if cx >= 0] + else: + assert summarize_action + doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas] + docs_with_score2 = [x for hx, cx, x in + sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1])) + if cx == -1 + ] + if len(docs_with_score2) == 0 and len(docs_with_score) > 0: + # old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive + # just do again and relax filter, let summarize operate on actual chunks if nothing else + docs_with_score2 = [x for hx, cx, x in + sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), + key=lambda x: (x[0], x[1])) + ] + docs_with_score = docs_with_score2 + + docs_with_score = docs_with_score[:top_k_docs] + docs = [x[0] for x in docs_with_score] + scores = [x[1] for x in docs_with_score] + num_docs_before_cut = len(docs) + else: + with filelock.FileLock(lock_file): + docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type, + text_context_list=text_context_list, + verbose=verbose) + if len(docs_with_score) == 0 and filter_kwargs_backup: + docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db, + db_type, + text_context_list=text_context_list, + verbose=verbose) + + tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server, + use_openai_model=use_openai_model, + db_type=db_type) + # NOTE: if map_reduce, then no need to auto reduce chunks + if query_action and (top_k_docs == -1 or auto_reduce_chunks): + top_k_docs_tokenize = 100 + docs_with_score = docs_with_score[:top_k_docs_tokenize] + + prompt_no_docs = template.format(context='', question=query) + + model_max_length = tokenizer.model_max_length + chat = True # FIXME? 
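+ # get_limited_prompt() trims chat history and doc chunks to fit the token budget;
+ # e.g. (illustrative) with model_max_length=2048 and max_new_tokens=512, roughly
+ # 1536 tokens remain for template + system prompt + chat history + document chunks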
+ + # first docs_with_score are most important with highest score + full_prompt, \ + instruction, iinput, context, \ + num_prompt_tokens, max_new_tokens, \ + num_prompt_tokens0, num_prompt_tokens_actual, \ + chat_index, top_k_docs_trial, one_doc_size = \ + get_limited_prompt(prompt_no_docs, + iinput, + tokenizer, + prompter=prompter, + inference_server=inference_server, + prompt_type=prompt_type, + prompt_dict=prompt_dict, + chat=chat, + max_new_tokens=max_new_tokens, + system_prompt=system_prompt, + context=context, + chat_conversation=chat_conversation, + text_context_list=[x[0].page_content for x in docs_with_score], + keep_sources_in_context=keep_sources_in_context, + model_max_length=model_max_length, + memory_restriction_level=memory_restriction_level, + langchain_mode=langchain_mode, + add_chat_history_to_context=add_chat_history_to_context, + min_max_new_tokens=min_max_new_tokens, + ) + # avoid craziness + if 0 < top_k_docs_trial < max_chunks: + # avoid craziness + if top_k_docs == -1: + top_k_docs = top_k_docs_trial + else: + top_k_docs = min(top_k_docs, top_k_docs_trial) + elif top_k_docs_trial >= max_chunks: + top_k_docs = max_chunks + if top_k_docs > 0: + docs_with_score = docs_with_score[:top_k_docs] + elif one_doc_size is not None: + docs_with_score = [docs_with_score[0][:one_doc_size]] + else: + docs_with_score = [] + else: + if total_tokens_for_docs is not None: + # used to limit tokens for summarization, e.g. public instance + top_k_docs, one_doc_size, num_doc_tokens = \ + get_docs_tokens(tokenizer, + text_context_list=[x[0].page_content for x in docs_with_score], + max_input_tokens=total_tokens_for_docs) + + docs_with_score = docs_with_score[:top_k_docs] + + # put most relevant chunks closest to question, + # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated + # BUT: for small models, e.g. 
6_9 pythia, if it sees some content related to h2oGPT first, it can latch onto that and ignore the rest
+ if docs_ordering_type in ['best_first']:
+ pass
+ elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
+ docs_with_score.reverse()
+ elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
+ docs_with_score = reverse_ucurve_list(docs_with_score)
+ else:
+ raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
+
+ # cut off so no high distance docs/sources considered
+ num_docs_before_cut = len(docs_with_score)
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
+ if len(scores) > 0 and verbose:
+ print("Distance: min: %s max: %s mean: %s median: %s" %
+ (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
+ else:
+ docs = []
+ scores = []
+
+ if not docs and use_docs_planned and not langchain_only_model:
+ # if HF type and have no docs, can bail out
+ return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ if document_subset in non_query_commands:
+ # no LLM use
+ return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+ # FIXME: WIP
+ common_words_file = "data/NGSL_1.2_stats.csv.zip"
+ if False and os.path.isfile(common_words_file) and langchain_action == LangChainAction.QUERY.value:
+ df = pd.read_csv(common_words_file)
+ import string
+ reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
+ reduced_query_words = reduced_query.split(' ')
+ set_common = set(df['Lemma'].values.tolist())
+ num_common = sum([x.lower() in set_common for x in reduced_query_words])
+ frac_common = num_common / len(reduced_query_words) if reduced_query else 0
+ # FIXME: report to user bad query that uses too many common words
+ if verbose:
+ print("frac_common: %s" % frac_common, flush=True)
+
+ if len(docs) == 0:
+ # avoid context == '' in prompt then
+ use_docs_planned = False
+ template = template_if_no_docs
+
+ got_db_docs = got_db_docs and len(text_context_list) < len(docs)
+ # update template in case situation changed or did get docs
+ # then no new documents from database or not used, redo template
+ # got template earlier as estimate of template token size, here is final used version
+ template, template_if_no_docs, auto_reduce_chunks, query = \
+ get_template(query, iinput,
+ pre_prompt_query, prompt_query,
+ pre_prompt_summary, prompt_summary,
+ langchain_action,
+ llm_mode,
+ use_docs_planned,
+ auto_reduce_chunks,
+ got_db_docs,
+ add_search_to_context)
+
+ if langchain_action == LangChainAction.QUERY.value:
+ if use_template:
+ # instruct-like, rather than few-shot prompt_type='plain' as default
+ # but then sources confuse the model with how inserted among rest of text, so avoid
+ prompt = PromptTemplate(
+ # input_variables=["summaries", "question"],
+ input_variables=["context", "question"],
+ template=template,
+ )
+ chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
+ else:
+ # only if use_openai_model = True, unused normally except in testing
+ chain = load_qa_with_sources_chain(llm)
+ if not use_docs_planned:
+ chain_kwargs = dict(input_documents=[], question=query)
+ else:
+ chain_kwargs = dict(input_documents=docs, question=query)
+ target = wrapped_partial(chain, chain_kwargs)
+ elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
+ LangChainAction.SUMMARIZE_REFINE.value,
+ 
LangChainAction.SUMMARIZE_ALL.value]: + if async_output: + return_intermediate_steps = False + else: + return_intermediate_steps = True + from langchain.chains.summarize import load_summarize_chain + if langchain_action == LangChainAction.SUMMARIZE_MAP.value: + prompt = PromptTemplate(input_variables=["text"], template=template) + chain = load_summarize_chain(llm, chain_type="map_reduce", + map_prompt=prompt, combine_prompt=prompt, + return_intermediate_steps=return_intermediate_steps, + token_max=max_input_tokens, verbose=verbose) + if async_output: + chain_func = chain.arun + else: + chain_func = chain + target = wrapped_partial(chain_func, {"input_documents": docs}) # , return_only_outputs=True) + elif langchain_action == LangChainAction.SUMMARIZE_ALL.value: + assert use_template + prompt = PromptTemplate(input_variables=["text"], template=template) + chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, + return_intermediate_steps=return_intermediate_steps, verbose=verbose) + if async_output: + chain_func = chain.arun + else: + chain_func = chain + target = wrapped_partial(chain_func) + elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value: + chain = load_summarize_chain(llm, chain_type="refine", + return_intermediate_steps=return_intermediate_steps, verbose=verbose) + if async_output: + chain_func = chain.arun + else: + chain_func = chain + target = wrapped_partial(chain_func) + else: + raise RuntimeError("No such langchain_action=%s" % langchain_action) + else: + raise RuntimeError("No such langchain_action=%s" % langchain_action) + + return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show + + +def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None): + if hasattr(tokenizer, 'model_max_length'): + return tokenizer.model_max_length + elif inference_server in ['openai', 'openai_azure']: + return llm.modelname_to_contextsize(model_name) + elif inference_server in ['openai_chat', 'openai_azure_chat']: + return model_token_mapping[model_name] + elif isinstance(tokenizer, FakeTokenizer): + # GGML + return tokenizer.model_max_length + else: + return 2048 + + +def get_max_input_tokens(llm=None, tokenizer=None, inference_server=None, model_name=None, max_new_tokens=None): + model_max_length = get_max_model_length(llm=llm, tokenizer=tokenizer, inference_server=inference_server, + model_name=model_name) + + if any([inference_server.startswith(x) for x in + ['openai', 'openai_azure', 'openai_chat', 'openai_azure_chat', 'vllm']]): + # openai can't handle tokens + max_new_tokens > max_tokens even if never generate those tokens + # and vllm uses OpenAI API with same limits + max_input_tokens = model_max_length - max_new_tokens + elif isinstance(tokenizer, FakeTokenizer): + # don't trust that fake tokenizer (e.g. 
GGML) will make lots of tokens normally, allow more input + max_input_tokens = model_max_length - min(256, max_new_tokens) + else: + if 'falcon' in model_name or inference_server.startswith('http'): + # allow for more input for falcon, assume won't make as long outputs as default max_new_tokens + # Also allow if TGI or Gradio, because we tell it input may be same as output, even if model can't actually handle + max_input_tokens = model_max_length - min(256, max_new_tokens) + else: + # trust that maybe model will make so many tokens, so limit input + max_input_tokens = model_max_length - max_new_tokens + + return max_input_tokens + + +def get_tokenizer(db=None, llm=None, tokenizer=None, inference_server=None, use_openai_model=False, + db_type='chroma'): + if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'): + # more accurate + return llm.pipeline.tokenizer + elif hasattr(llm, 'tokenizer'): + # e.g. TGI client mode etc. + return llm.tokenizer + elif inference_server in ['openai', 'openai_chat', 'openai_azure', + 'openai_azure_chat']: + return tokenizer + elif isinstance(tokenizer, FakeTokenizer): + return tokenizer + elif use_openai_model: + return FakeTokenizer() + elif (hasattr(db, '_embedding_function') and + hasattr(db._embedding_function, 'client') and + hasattr(db._embedding_function.client, 'tokenize')): + # in case model is not our pipeline with HF tokenizer + return db._embedding_function.client.tokenize + else: + # backup method + if os.getenv('HARD_ASSERTS'): + assert db_type in ['faiss', 'weaviate'] + # use tiktoken for faiss since embedding called differently + return FakeTokenizer() + + +def get_template(query, iinput, + pre_prompt_query, prompt_query, + pre_prompt_summary, prompt_summary, + langchain_action, + llm_mode, + use_docs_planned, + auto_reduce_chunks, + got_db_docs, + add_search_to_context): + if got_db_docs and add_search_to_context: + # modify prompts, assumes patterns like in predefined prompts. If user customizes, then they'd need to account for that. + prompt_query = prompt_query.replace('information in the document sources', + 'information in the document and web search sources (and their source dates and website source)') + prompt_summary = prompt_summary.replace('information in the document sources', + 'information in the document and web search sources (and their source dates and website source)') + elif got_db_docs and not add_search_to_context: + pass + elif not got_db_docs and add_search_to_context: + # modify prompts, assumes patterns like in predefined prompts. If user customizes, then they'd need to account for that. 
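+ # e.g. the predefined prompts contain the phrase 'information in the document sources',
+ # which the replacements below rewrite to mention web search sources instead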
+ prompt_query = prompt_query.replace('information in the document sources',
+ 'information in the web search sources (and their source dates and website source)')
+ prompt_summary = prompt_summary.replace('information in the document sources',
+ 'information in the web search sources (and their source dates and website source)')
+
+ if langchain_action == LangChainAction.QUERY.value:
+ if iinput:
+ query = "%s\n%s" % (query, iinput)
+ if llm_mode or not use_docs_planned:
+ template_if_no_docs = template = """{context}{question}"""
+ else:
+ template = """%s
+\"\"\"
+{context}
+\"\"\"
+%s{question}""" % (pre_prompt_query, prompt_query)
+ template_if_no_docs = """{context}{question}"""
+ elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]:
+ none = ['', '\n', None]
+
+ # modify prompt_summary if user passes query or iinput
+ if query not in none and iinput not in none:
+ prompt_summary = "Focusing on %s, %s, %s" % (query, iinput, prompt_summary)
+ elif query not in none:
+ prompt_summary = "Focusing on %s, %s" % (query, prompt_summary)
+ # don't auto reduce
+ auto_reduce_chunks = False
+ if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
+ fstring = '{text}'
+ else:
+ fstring = '{input_documents}'
+ template = """%s:
+\"\"\"
+%s
+\"\"\"\n%s""" % (pre_prompt_summary, fstring, prompt_summary)
+ template_if_no_docs = "Exactly only say: There are no documents to summarize."
+ elif langchain_action in [LangChainAction.SUMMARIZE_REFINE.value]:
+ template = '' # unused
+ template_if_no_docs = '' # unused
+ else:
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
+
+ return template, template_if_no_docs, auto_reduce_chunks, query
+
+
+def get_sources_answer(query, docs, answer, scores, show_rank,
+ answer_with_sources, append_sources_to_answer,
+ show_accordions=True,
+ show_link_in_sources=True,
+ top_k_docs_max_show=10,
+ docs_ordering_type='reverse_ucurve_sort',
+ num_docs_before_cut=0,
+ verbose=False,
+ t_run=None,
+ count_input_tokens=None, count_output_tokens=None):
+ if verbose:
+ print("query: %s" % query, flush=True)
+ print("answer: %s" % answer, flush=True)
+
+ if len(docs) == 0:
+ extra = ''
+ ret = answer + extra
+ return ret, extra
+
+ if answer_with_sources == -1:
+ extra = [dict(score=score, content=get_doc(x), source=get_source(x), orig_index=x.metadata.get('orig_index', 0))
+ for score, x in zip(scores, docs)][
+ :top_k_docs_max_show]
+ if append_sources_to_answer:
+ extra_str = [str(x) for x in extra]
+ ret = answer + '\n\n' + '\n'.join(extra_str)
+ else:
+ ret = answer
+ return ret, extra
+
+ # link
+ answer_sources = [(max(0.0, 1.5 - score) / 1.5,
+ get_url(doc, font_size=font_size),
+ get_accordion(doc, font_size=font_size, head_acc=head_acc)) for score, doc in
+ zip(scores, docs)]
+ if not show_accordions:
+ answer_sources_dict = defaultdict(list)
+ [answer_sources_dict[url].append(score) for score, url, _ in answer_sources]
+ answers_dict = {}
+ for url, scores_url in answer_sources_dict.items():
+ answers_dict[url] = np.max(scores_url)
+ answer_sources = [(score, url) for url, score in answers_dict.items()]
+ answer_sources.sort(key=lambda x: x[0], reverse=True)
+ if show_rank:
+ # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
+ # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
+ answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
+ answer_sources = answer_sources[:top_k_docs_max_show]
+ sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
+ else:
+ if show_accordions:
+ if show_link_in_sources:
+ answer_sources = ['<font size="%s"><li>%.2g | %s</li>%s</font>' % (font_size, score, url, accordion)
+ for score, url, accordion in answer_sources]
+ else:
+ answer_sources = ['<font size="%s"><li>%.2g</li>%s</font>' % (font_size, score, accordion)
+ for score, url, accordion in answer_sources]
+ else:
+ if show_link_in_sources:
+ answer_sources = ['<font size="%s"><li>%.2g | %s</li></font>' % (font_size, score, url)
+ for score, url in answer_sources]
+ else:
+ answer_sources = ['<font size="%s"><li>%.2g</li></font>' % (font_size, score)
+ for score, url in answer_sources]
+ answer_sources = answer_sources[:top_k_docs_max_show]
+ if show_accordions:
+ sorted_sources_urls = f"{source_prefix}