diff --git "a/gpt_langchain.py" "b/gpt_langchain.py" deleted file mode 100644--- "a/gpt_langchain.py" +++ /dev/null @@ -1,5443 +0,0 @@ -import ast -import asyncio -import copy -import functools -import glob -import gzip -import inspect -import json -import os -import pathlib -import pickle -import shutil -import subprocess -import tempfile -import time -import traceback -import types -import typing -import urllib.error -import uuid -import zipfile -from collections import defaultdict -from datetime import datetime -from functools import reduce -from operator import concat -import filelock -import tabulate -import yaml - -from joblib import delayed -from langchain.callbacks import streaming_stdout -from langchain.embeddings import HuggingFaceInstructEmbeddings -from langchain.llms.huggingface_pipeline import VALID_TASKS -from langchain.llms.utils import enforce_stop_tokens -from langchain.schema import LLMResult, Generation -from langchain.tools import PythonREPLTool -from langchain.tools.json.tool import JsonSpec -from tqdm import tqdm - -from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \ - get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \ - have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \ - get_list_or_str, have_pillow, only_selenium, only_playwright, only_unstructured_urls, get_sha, get_short_name, \ - get_accordion, have_jq, get_doc, get_source, have_chromamigdb, get_token_count, reverse_ucurve_list -from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \ - LangChainAction, LangChainMode, DocumentChoice, LangChainTypes, font_size, head_acc, super_source_prefix, \ - super_source_postfix, langchain_modes_intrinsic, get_langchain_prompts, LangChainAgent -from evaluate_params import gen_hyper, gen_hyper0 -from gen import get_model, SEED, get_limited_prompt, get_docs_tokens -from prompter import non_hf_types, PromptType, Prompter -from src.serpapi import H2OSerpAPIWrapper -from utils_langchain import StreamingGradioCallbackHandler, _chunk_sources, _add_meta, add_parser, fix_json_meta - -import_matplotlib() - -import numpy as np -import pandas as pd -import requests -from langchain.chains.qa_with_sources import load_qa_with_sources_chain -# , GCSDirectoryLoader, GCSFileLoader -# , OutlookMessageLoader # GPL3 -# ImageCaptionLoader, # use our own wrapper -# ReadTheDocsLoader, # no special file, some path, so have to give as special option -from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \ - UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \ - EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \ - UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \ - UnstructuredExcelLoader, JSONLoader -from langchain.text_splitter import Language -from langchain.chains.question_answering import load_qa_chain -from langchain.docstore.document import Document -from langchain import PromptTemplate, HuggingFaceTextGenInference, HuggingFacePipeline -from langchain.vectorstores import Chroma -from chromamig import ChromaMig - - -def split_list(input_list, split_size): - for i in range(0, len(input_list), split_size): - yield input_list[i:i + 
split_size] - - -def get_db(sources, use_openai_embedding=False, db_type='faiss', - persist_directory=None, load_db_if_exists=True, - langchain_mode='notset', - langchain_mode_paths={}, - langchain_mode_types={}, - collection_name=None, - hf_embedding_model=None, - migrate_embedding_model=False, - auto_migrate_db=False, - n_jobs=-1): - if not sources: - return None - user_path = langchain_mode_paths.get(langchain_mode) - if persist_directory is None: - langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value) - persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type) - langchain_mode_types[langchain_mode] = langchain_type - assert hf_embedding_model is not None - - # get freshly-determined embedding model - embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model) - assert collection_name is not None or langchain_mode != 'notset' - if collection_name is None: - collection_name = langchain_mode.replace(' ', '_') - - # Create vector database - if db_type == 'faiss': - from langchain.vectorstores import FAISS - db = FAISS.from_documents(sources, embedding) - elif db_type == 'weaviate': - import weaviate - from weaviate.embedded import EmbeddedOptions - from langchain.vectorstores import Weaviate - - if os.getenv('WEAVIATE_URL', None): - client = _create_local_weaviate_client() - else: - client = weaviate.Client( - embedded_options=EmbeddedOptions(persistence_data_path=persist_directory) - ) - index_name = collection_name.capitalize() - db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False, - index_name=index_name) - elif db_type in ['chroma', 'chroma_old']: - assert persist_directory is not None - # use_base already handled when making persist_directory, unless was passed into get_db() - makedirs(persist_directory, exist_ok=True) - - # see if already actually have persistent db, and deal with possible changes in embedding - db, use_openai_embedding, hf_embedding_model = \ - get_existing_db(None, persist_directory, load_db_if_exists, db_type, - use_openai_embedding, - langchain_mode, langchain_mode_paths, langchain_mode_types, - hf_embedding_model, migrate_embedding_model, auto_migrate_db, - verbose=False, - n_jobs=n_jobs) - if db is None: - import logging - logging.getLogger("chromadb").setLevel(logging.ERROR) - if db_type == 'chroma': - from chromadb.config import Settings - settings_extra_kwargs = dict(is_persistent=True) - else: - from chromamigdb.config import Settings - settings_extra_kwargs = dict(chroma_db_impl="duckdb+parquet") - client_settings = Settings(anonymized_telemetry=False, - persist_directory=persist_directory, - **settings_extra_kwargs) - if n_jobs in [None, -1]: - n_jobs = int(os.getenv('OMP_NUM_THREADS', str(os.cpu_count() // 2))) - num_threads = max(1, min(n_jobs, 8)) - else: - num_threads = max(1, n_jobs) - collection_metadata = {"hnsw:num_threads": num_threads} - from_kwargs = dict(embedding=embedding, - persist_directory=persist_directory, - collection_name=collection_name, - client_settings=client_settings, - collection_metadata=collection_metadata) - if db_type == 'chroma': - import chromadb - api = chromadb.PersistentClient(path=persist_directory) - max_batch_size = api._producer.max_batch_size - sources_batches = split_list(sources, max_batch_size) - for sources_batch in sources_batches: - db = Chroma.from_documents(documents=sources_batch, **from_kwargs) - db.persist() - else: - db = 
ChromaMig.from_documents(documents=sources, **from_kwargs) - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - else: - # then just add - # doesn't check or change embedding, just saves it in case not saved yet, after persisting - db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model) - else: - raise RuntimeError("No such db_type=%s" % db_type) - - # once here, db is not changing and embedding choices in calling functions does not matter - return db - - -def _get_unique_sources_in_weaviate(db): - batch_size = 100 - id_source_list = [] - result = db._client.data_object.get(class_name=db._index_name, limit=batch_size) - - while result['objects']: - id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']] - last_id = id_source_list[-1][0] - result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id) - - unique_sources = {source for _, source in id_source_list} - return unique_sources - - -def del_from_db(db, sources, db_type=None): - if db_type in ['chroma', 'chroma_old'] and db is not None: - # sources should be list of x.metadata['source'] from document metadatas - if isinstance(sources, str): - sources = [sources] - else: - assert isinstance(sources, (list, tuple, types.GeneratorType)) - metadatas = set(sources) - client_collection = db._client.get_collection(name=db._collection.name, - embedding_function=db._collection._embedding_function) - for source in metadatas: - meta = dict(source=source) - try: - client_collection.delete(where=meta) - except KeyError: - pass - - -def add_to_db(db, sources, db_type='faiss', - avoid_dup_by_file=False, - avoid_dup_by_content=True, - use_openai_embedding=False, - hf_embedding_model=None): - assert hf_embedding_model is not None - num_new_sources = len(sources) - if not sources: - return db, num_new_sources, [] - if db_type == 'faiss': - db.add_documents(sources) - elif db_type == 'weaviate': - # FIXME: only control by file name, not hash yet - if avoid_dup_by_file or avoid_dup_by_content: - unique_sources = _get_unique_sources_in_weaviate(db) - sources = [x for x in sources if x.metadata['source'] not in unique_sources] - num_new_sources = len(sources) - if num_new_sources == 0: - return db, num_new_sources, [] - db.add_documents(documents=sources) - elif db_type in ['chroma', 'chroma_old']: - collection = get_documents(db) - # files we already have: - metadata_files = set([x['source'] for x in collection['metadatas']]) - if avoid_dup_by_file: - # Too weak in case file changed content, assume parent shouldn't pass true for this for now - raise RuntimeError("Not desired code path") - if avoid_dup_by_content: - # look at hash, instead of page_content - # migration: If no hash previously, avoid updating, - # since don't know if need to update and may be expensive to redo all unhashed files - metadata_hash_ids = set( - [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]]) - # avoid sources with same hash - sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids] - num_nohash = len([x for x in sources if not x.metadata.get('hashid')]) - print("Found %s new sources (%d have no hash in original source," - " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True) - # get new file names that match existing file names. 
delete existing files we are overridding - dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files]) - print("Removing %s duplicate files from db because ingesting those as new documents" % len( - dup_metadata_files), flush=True) - client_collection = db._client.get_collection(name=db._collection.name, - embedding_function=db._collection._embedding_function) - for dup_file in dup_metadata_files: - dup_file_meta = dict(source=dup_file) - try: - client_collection.delete(where=dup_file_meta) - except KeyError: - pass - num_new_sources = len(sources) - if num_new_sources == 0: - return db, num_new_sources, [] - if hasattr(db, '_persist_directory'): - print("Existing db, adding to %s" % db._persist_directory, flush=True) - # chroma only - lock_file = get_db_lock_file(db) - context = filelock.FileLock - else: - lock_file = None - context = NullContext - with context(lock_file): - # this is place where add to db, but others maybe accessing db, so lock access. - # else see RuntimeError: Index seems to be corrupted or unsupported - import chromadb - api = chromadb.PersistentClient(path=db._persist_directory) - max_batch_size = api._producer.max_batch_size - sources_batches = split_list(sources, max_batch_size) - for sources_batch in sources_batches: - db.add_documents(documents=sources_batch) - db.persist() - clear_embedding(db) - # save here is for migration, in case old db directory without embedding saved - save_embed(db, use_openai_embedding, hf_embedding_model) - else: - raise RuntimeError("No such db_type=%s" % db_type) - - new_sources_metadata = [x.metadata for x in sources] - - return db, num_new_sources, new_sources_metadata - - -def create_or_update_db(db_type, persist_directory, collection_name, - user_path, langchain_type, - sources, use_openai_embedding, add_if_exists, verbose, - hf_embedding_model, migrate_embedding_model, auto_migrate_db, - n_jobs=-1): - if not os.path.isdir(persist_directory) or not add_if_exists: - if os.path.isdir(persist_directory): - if verbose: - print("Removing %s" % persist_directory, flush=True) - remove(persist_directory) - if verbose: - print("Generating db", flush=True) - if db_type == 'weaviate': - import weaviate - from weaviate.embedded import EmbeddedOptions - - if os.getenv('WEAVIATE_URL', None): - client = _create_local_weaviate_client() - else: - client = weaviate.Client( - embedded_options=EmbeddedOptions(persistence_data_path=persist_directory) - ) - - index_name = collection_name.replace(' ', '_').capitalize() - if client.schema.exists(index_name) and not add_if_exists: - client.schema.delete_class(index_name) - if verbose: - print("Removing %s" % index_name, flush=True) - elif db_type in ['chroma', 'chroma_old']: - pass - - if not add_if_exists: - if verbose: - print("Generating db", flush=True) - else: - if verbose: - print("Loading and updating db", flush=True) - - db = get_db(sources, - use_openai_embedding=use_openai_embedding, - db_type=db_type, - persist_directory=persist_directory, - langchain_mode=collection_name, - langchain_mode_paths={collection_name: user_path}, - langchain_mode_types={collection_name: langchain_type}, - hf_embedding_model=hf_embedding_model, - migrate_embedding_model=migrate_embedding_model, - auto_migrate_db=auto_migrate_db, - n_jobs=n_jobs) - - return db - - -from langchain.embeddings import FakeEmbeddings - - -class H2OFakeEmbeddings(FakeEmbeddings): - """Fake embedding model, but constant instead of random""" - - size: int - """The size of the embedding 
vector.""" - - def _get_embedding(self) -> typing.List[float]: - return [1] * self.size - - def embed_documents(self, texts: typing.List[str]) -> typing.List[typing.List[float]]: - return [self._get_embedding() for _ in texts] - - def embed_query(self, text: str) -> typing.List[float]: - return self._get_embedding() - - -def get_embedding(use_openai_embedding, hf_embedding_model=None, preload=False): - assert hf_embedding_model is not None - # Get embedding model - if use_openai_embedding: - assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY" - from langchain.embeddings import OpenAIEmbeddings - embedding = OpenAIEmbeddings(disallowed_special=()) - elif hf_embedding_model == 'fake': - embedding = H2OFakeEmbeddings(size=1) - else: - if isinstance(hf_embedding_model, str): - pass - elif isinstance(hf_embedding_model, dict): - # embedding itself preloaded globally - return hf_embedding_model['model'] - else: - # object - return hf_embedding_model - # to ensure can fork without deadlock - from langchain.embeddings import HuggingFaceEmbeddings - - device, torch_dtype, context_class = get_device_dtype() - model_kwargs = dict(device=device) - if 'instructor' in hf_embedding_model: - encode_kwargs = {'normalize_embeddings': True} - embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model, - model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs) - else: - embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs) - embedding.client.preload = preload - return embedding - - -def get_answer_from_sources(chain, sources, question): - return chain( - { - "input_documents": sources, - "question": question, - }, - return_only_outputs=True, - )["output_text"] - - -"""Wrapper around Huggingface text generation inference API.""" -from functools import partial -from typing import Any, Dict, List, Optional, Set, Iterable - -from pydantic import Extra, Field, root_validator - -from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun -from langchain.llms.base import LLM - - -class GradioInference(LLM): - """ - Gradio generation inference API. - """ - inference_server_url: str = "" - - temperature: float = 0.8 - top_p: Optional[float] = 0.95 - top_k: Optional[int] = None - num_beams: Optional[int] = 1 - max_new_tokens: int = 512 - min_new_tokens: int = 1 - early_stopping: bool = False - max_time: int = 180 - repetition_penalty: Optional[float] = None - num_return_sequences: Optional[int] = 1 - do_sample: bool = False - chat_client: bool = False - - return_full_text: bool = False - stream_output: bool = False - sanitize_bot_response: bool = False - - prompter: Any = None - context: Any = '' - iinput: Any = '' - client: Any = None - tokenizer: Any = None - - system_prompt: Any = None - visible_models: Any = None - h2ogpt_key: Any = None - - count_input_tokens: Any = 0 - count_output_tokens: Any = 0 - - min_max_new_tokens: Any = 256 - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that python package exists in environment.""" - - try: - if values['client'] is None: - import gradio_client - values["client"] = gradio_client.Client( - values["inference_server_url"] - ) - except ImportError: - raise ImportError( - "Could not import gradio_client python package. " - "Please install it with `pip install gradio_client`." 
- ) - return values - - @property - def _llm_type(self) -> str: - """Return type of llm.""" - return "gradio_inference" - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection, - # so server should get prompt_type or '', not plain - # This is good, so gradio server can also handle stopping.py conditions - # this is different than TGI server that uses prompter to inject prompt_type prompting - stream_output = self.stream_output - gr_client = self.client - client_langchain_mode = 'Disabled' - client_add_chat_history_to_context = True - client_add_search_to_context = False - client_chat_conversation = [] - client_langchain_action = LangChainAction.QUERY.value - client_langchain_agents = [] - top_k_docs = 1 - chunk = True - chunk_size = 512 - client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True - iinput=self.iinput if self.chat_client else '', # only for chat=True - context=self.context, - # streaming output is supported, loops over and outputs each generation in streaming mode - # but leave stream_output=False for simple input/output mode - stream_output=stream_output, - prompt_type=self.prompter.prompt_type, - prompt_dict='', - - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - num_beams=self.num_beams, - max_new_tokens=self.max_new_tokens, - min_new_tokens=self.min_new_tokens, - early_stopping=self.early_stopping, - max_time=self.max_time, - repetition_penalty=self.repetition_penalty, - num_return_sequences=self.num_return_sequences, - do_sample=self.do_sample, - chat=self.chat_client, - - instruction_nochat=prompt if not self.chat_client else '', - iinput_nochat=self.iinput if not self.chat_client else '', - langchain_mode=client_langchain_mode, - add_chat_history_to_context=client_add_chat_history_to_context, - langchain_action=client_langchain_action, - langchain_agents=client_langchain_agents, - top_k_docs=top_k_docs, - chunk=chunk, - chunk_size=chunk_size, - document_subset=DocumentSubset.Relevant.name, - document_choice=[DocumentChoice.ALL.value], - pre_prompt_query=None, - prompt_query=None, - pre_prompt_summary=None, - prompt_summary=None, - system_prompt=self.system_prompt, - image_loaders=None, # don't need to further do doc specific things - pdf_loaders=None, # don't need to further do doc specific things - url_loaders=None, # don't need to further do doc specific things - jq_schema=None, # don't need to further do doc specific things - visible_models=self.visible_models, - h2ogpt_key=self.h2ogpt_key, - add_search_to_context=client_add_search_to_context, - chat_conversation=client_chat_conversation, - text_context_list=None, - docs_ordering_type=None, - min_max_new_tokens=self.min_max_new_tokens, - ) - api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing - self.count_input_tokens += self.get_num_tokens(prompt) - - if not stream_output: - res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name) - res_dict = ast.literal_eval(res) - text = res_dict['response'] - ret = self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - self.count_output_tokens += self.get_num_tokens(ret) - return ret - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - - job = 
gr_client.submit(str(dict(client_kwargs)), api_name=api_name) - text0 = '' - while not job.done(): - if job.communicator.job.latest_status.code.name == 'FINISHED': - break - e = job.future._exception - if e is not None: - break - outputs_list = job.communicator.job.outputs - if outputs_list: - res = job.communicator.job.outputs[-1] - res_dict = ast.literal_eval(res) - text = res_dict['response'] - text = self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - # FIXME: derive chunk from full for now - text_chunk = text[len(text0):] - if not text_chunk: - continue - # save old - text0 = text - - if text_callback: - text_callback(text_chunk) - - time.sleep(0.01) - - # ensure get last output to avoid race - res_all = job.outputs() - if len(res_all) > 0: - res = res_all[-1] - res_dict = ast.literal_eval(res) - text = res_dict['response'] - # FIXME: derive chunk from full for now - else: - # go with old if failure - text = text0 - text_chunk = text[len(text0):] - if text_callback: - text_callback(text_chunk) - ret = self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - self.count_output_tokens += self.get_num_tokens(ret) - return ret - - def get_token_ids(self, text: str) -> List[int]: - return self.tokenizer.encode(text) - # avoid base method that is not aware of how to properly tokenize (uses GPT2) - # return _get_token_ids_default_method(text) - - -class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference): - max_new_tokens: int = 512 - do_sample: bool = False - top_k: Optional[int] = None - top_p: Optional[float] = 0.95 - typical_p: Optional[float] = 0.95 - temperature: float = 0.8 - repetition_penalty: Optional[float] = None - return_full_text: bool = False - stop_sequences: List[str] = Field(default_factory=list) - seed: Optional[int] = None - inference_server_url: str = "" - timeout: int = 300 - headers: dict = None - stream_output: bool = False - sanitize_bot_response: bool = False - prompter: Any = None - context: Any = '' - iinput: Any = '' - tokenizer: Any = None - async_sem: Any = None - count_input_tokens: Any = 0 - count_output_tokens: Any = 0 - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - if stop is None: - stop = self.stop_sequences.copy() - else: - stop += self.stop_sequences.copy() - stop_tmp = stop.copy() - stop = [] - [stop.append(x) for x in stop_tmp if x not in stop] - - # HF inference server needs control over input tokens - assert self.tokenizer is not None - from h2oai_pipeline import H2OTextGenerationPipeline - prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) - - # NOTE: TGI server does not add prompting, so must do here - data_point = dict(context=self.context, instruction=prompt, input=self.iinput) - prompt = self.prompter.generate_prompt(data_point) - self.count_input_tokens += self.get_num_tokens(prompt) - - gen_server_kwargs = dict(do_sample=self.do_sample, - stop_sequences=stop, - max_new_tokens=self.max_new_tokens, - top_k=self.top_k, - top_p=self.top_p, - typical_p=self.typical_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - return_full_text=self.return_full_text, - seed=self.seed, - ) - gen_server_kwargs.update(kwargs) - - # lower bound because client is re-used if multi-threading - self.client.timeout = max(300, self.timeout) - - if not 
self.stream_output: - res = self.client.generate( - prompt, - **gen_server_kwargs, - ) - if self.return_full_text: - gen_text = res.generated_text[len(prompt):] - else: - gen_text = res.generated_text - # remove stop sequences from the end of the generated text - for stop_seq in stop: - if stop_seq in gen_text: - gen_text = gen_text[:gen_text.index(stop_seq)] - text = prompt + gen_text - text = self.prompter.get_response(text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter - if text_callback: - text_callback(prompt) - text = "" - # Note: Streaming ignores return_full_text=True - for response in self.client.generate_stream(prompt, **gen_server_kwargs): - text_chunk = response.token.text - text += text_chunk - text = self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - # stream part - is_stop = False - for stop_seq in stop: - if stop_seq in text_chunk: - is_stop = True - break - if is_stop: - break - if not response.token.special: - if text_callback: - text_callback(text_chunk) - self.count_output_tokens += self.get_num_tokens(text) - return text - - async def _acall( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - # print("acall", flush=True) - if stop is None: - stop = self.stop_sequences.copy() - else: - stop += self.stop_sequences.copy() - stop_tmp = stop.copy() - stop = [] - [stop.append(x) for x in stop_tmp if x not in stop] - - # HF inference server needs control over input tokens - assert self.tokenizer is not None - from h2oai_pipeline import H2OTextGenerationPipeline - prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) - - # NOTE: TGI server does not add prompting, so must do here - data_point = dict(context=self.context, instruction=prompt, input=self.iinput) - prompt = self.prompter.generate_prompt(data_point) - - gen_text = await super()._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) - - # remove stop sequences from the end of the generated text - for stop_seq in stop: - if stop_seq in gen_text: - gen_text = gen_text[:gen_text.index(stop_seq)] - text = prompt + gen_text - text = self.prompter.get_response(text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - # print("acall done", flush=True) - return text - - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - """Run the LLM on the given prompt and input.""" - generations = [] - new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") - self.count_input_tokens += sum([self.get_num_tokens(prompt) for prompt in prompts]) - tasks = [ - asyncio.ensure_future(self._agenerate_one(prompt, stop=stop, run_manager=run_manager, - new_arg_supported=new_arg_supported, **kwargs)) - for prompt in prompts - ] - texts = await asyncio.gather(*tasks) - self.count_output_tokens += sum([self.get_num_tokens(text) for text in texts]) - [generations.append([Generation(text=text)]) for text in texts] - return LLMResult(generations=generations) - - async def _agenerate_one( - self, - prompt: str, - stop: 
Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - new_arg_supported=None, - **kwargs: Any, - ) -> str: - async with self.async_sem: # semaphore limits num of simultaneous downloads - return await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) \ - if new_arg_supported else \ - await self._acall(prompt, stop=stop, **kwargs) - - def get_token_ids(self, text: str) -> List[int]: - return self.tokenizer.encode(text) - # avoid base method that is not aware of how to properly tokenize (uses GPT2) - # return _get_token_ids_default_method(text) - - -from langchain.chat_models import ChatOpenAI, AzureChatOpenAI -from langchain.llms import OpenAI, AzureOpenAI, Replicate -from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \ - update_token_usage - - -class H2OOpenAI(OpenAI): - """ - New class to handle vLLM's use of OpenAI, no vllm_chat supported, so only need here - Handles prompting that OpenAI doesn't need, stopping as well - """ - stop_sequences: Any = None - sanitize_bot_response: bool = False - prompter: Any = None - context: Any = '' - iinput: Any = '' - tokenizer: Any = None - - @classmethod - def _all_required_field_names(cls) -> Set: - _all_required_field_names = super(OpenAI, cls)._all_required_field_names() - _all_required_field_names.update( - {'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter', - 'tokenizer', 'logit_bias'}) - return _all_required_field_names - - def _generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - stop_tmp = self.stop_sequences if not stop else self.stop_sequences + stop - stop = [] - [stop.append(x) for x in stop_tmp if x not in stop] - - # HF inference server needs control over input tokens - assert self.tokenizer is not None - from h2oai_pipeline import H2OTextGenerationPipeline - for prompti, prompt in enumerate(prompts): - prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) - # NOTE: OpenAI/vLLM server does not add prompting, so must do here - data_point = dict(context=self.context, instruction=prompt, input=self.iinput) - prompt = self.prompter.generate_prompt(data_point) - prompts[prompti] = prompt - - params = self._invocation_params - params = {**params, **kwargs} - sub_prompts = self.get_sub_prompts(params, prompts, stop) - choices = [] - token_usage: Dict[str, int] = {} - # Get the token usage from the response. - # Includes prompt, completion, and total tokens used. 
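Aside: the `_agenerate`/`_agenerate_one` pair above fans out one coroutine per prompt with `asyncio.gather`, while an `asyncio.Semaphore` (sized by `num_async`) caps how many requests are in flight against the inference server at once. A minimal standalone sketch of that bounded-concurrency pattern follows; `fake_generate` is a stand-in for the real async client call and is not part of the code above.

import asyncio


async def fake_generate(prompt: str) -> str:
    # stand-in for the real async generate call (e.g. a TGI client request)
    await asyncio.sleep(0.1)
    return prompt.upper()


async def generate_one(prompt: str, sem: asyncio.Semaphore) -> str:
    # semaphore limits the number of simultaneous requests, as in _agenerate_one
    async with sem:
        return await fake_generate(prompt)


async def generate_all(prompts, num_async: int = 3):
    sem = asyncio.Semaphore(num_async)
    tasks = [asyncio.ensure_future(generate_one(p, sem)) for p in prompts]
    # gather preserves prompt order, like _agenerate collecting one text per prompt
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    print(asyncio.run(generate_all(["a", "b", "c", "d"], num_async=2)))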
- _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} - text = '' - for _prompts in sub_prompts: - if self.streaming: - text_with_prompt = "" - prompt = _prompts[0] - if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") - params["stream"] = True - response = _streaming_response_template() - first = True - for stream_resp in completion_with_retry( - self, prompt=_prompts, **params - ): - if first: - stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"] - first = False - text_chunk = stream_resp["choices"][0]["text"] - text_with_prompt += text_chunk - text = self.prompter.get_response(text_with_prompt, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - if run_manager: - run_manager.on_llm_new_token( - text_chunk, - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) - _update_response(response, stream_resp) - choices.extend(response["choices"]) - else: - response = completion_with_retry(self, prompt=_prompts, **params) - choices.extend(response["choices"]) - if not self.streaming: - # Can't update token usage if streaming - update_token_usage(_keys, response, token_usage) - if self.streaming: - choices[0]['text'] = text - return self.create_llm_result(choices, prompts, token_usage) - - def get_token_ids(self, text: str) -> List[int]: - if self.tokenizer is not None: - return self.tokenizer.encode(text) - else: - # OpenAI uses tiktoken - return super().get_token_ids(text) - - -class H2OReplicate(Replicate): - stop_sequences: Any = None - sanitize_bot_response: bool = False - prompter: Any = None - context: Any = '' - iinput: Any = '' - tokenizer: Any = None - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Call to replicate endpoint.""" - stop_tmp = self.stop_sequences if not stop else self.stop_sequences + stop - stop = [] - [stop.append(x) for x in stop_tmp if x not in stop] - - # HF inference server needs control over input tokens - assert self.tokenizer is not None - from h2oai_pipeline import H2OTextGenerationPipeline - prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) - # Note Replicate handles the prompting of the specific model - return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs) - - def get_token_ids(self, text: str) -> List[int]: - return self.tokenizer.encode(text) - # avoid base method that is not aware of how to properly tokenize (uses GPT2) - # return _get_token_ids_default_method(text) - - -class H2OChatOpenAI(ChatOpenAI): - @classmethod - def _all_required_field_names(cls) -> Set: - _all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names() - _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'}) - return _all_required_field_names - - -class H2OAzureChatOpenAI(AzureChatOpenAI): - @classmethod - def _all_required_field_names(cls) -> Set: - _all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names() - _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'}) - return _all_required_field_names - - -class H2OAzureOpenAI(AzureOpenAI): - @classmethod - def _all_required_field_names(cls) -> Set: - _all_required_field_names = super(AzureOpenAI, cls)._all_required_field_names() - _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 
'logit_bias'}) - return _all_required_field_names - - -class H2OHuggingFacePipeline(HuggingFacePipeline): - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - response = self.pipeline(prompt, stop=stop) - if self.pipeline.task == "text-generation": - # Text generation return includes the starter text. - text = response[0]["generated_text"][len(prompt):] - elif self.pipeline.task == "text2text-generation": - text = response[0]["generated_text"] - elif self.pipeline.task == "summarization": - text = response[0]["summary_text"] - else: - raise ValueError( - f"Got invalid task {self.pipeline.task}, " - f"currently only {VALID_TASKS} are supported" - ) - if stop: - # This is a bit hacky, but I can't figure out a better way to enforce - # stop tokens when making calls to huggingface_hub. - text = enforce_stop_tokens(text, stop) - return text - - -def get_llm(use_openai_model=False, - model_name=None, - model=None, - tokenizer=None, - inference_server=None, - langchain_only_model=None, - stream_output=False, - async_output=True, - num_async=3, - do_sample=False, - temperature=0.1, - top_k=40, - top_p=0.7, - num_beams=1, - max_new_tokens=512, - min_new_tokens=1, - early_stopping=False, - max_time=180, - repetition_penalty=1.0, - num_return_sequences=1, - prompt_type=None, - prompt_dict=None, - prompter=None, - context=None, - iinput=None, - sanitize_bot_response=False, - system_prompt='', - visible_models=0, - h2ogpt_key=None, - min_max_new_tokens=None, - n_jobs=None, - cli=False, - llamacpp_dict=None, - verbose=False, - ): - # currently all but h2oai_pipeline case return prompt + new text, but could change - only_new_text = False - - if n_jobs in [None, -1]: - n_jobs = int(os.getenv('OMP_NUM_THREADS', str(os.cpu_count() // 2))) - if inference_server is None: - inference_server = '' - if inference_server.startswith('replicate'): - model_string = ':'.join(inference_server.split(':')[1:]) - if 'meta/llama' in model_string: - temperature = max(0.01, temperature if do_sample else 0) - else: - temperature =temperature if do_sample else 0 - gen_kwargs = dict(temperature=temperature, - seed=1234, - max_length=max_new_tokens, # langchain - max_new_tokens=max_new_tokens, # replicate docs - top_p=top_p if do_sample else 1, - top_k=top_k, # not always supported - repetition_penalty=repetition_penalty) - if system_prompt in [None, 'None', 'auto']: - if prompter.system_prompt: - system_prompt = prompter.system_prompt - else: - system_prompt = '' - if system_prompt: - gen_kwargs.update(dict(system_prompt=system_prompt)) - - # replicate handles prompting, so avoid get_response() filter - prompter.prompt_type = 'plain' - if stream_output: - callbacks = [StreamingGradioCallbackHandler()] - streamer = callbacks[0] if stream_output else None - llm = H2OReplicate( - streaming=True, - callbacks=callbacks, - model=model_string, - input=gen_kwargs, - stop=prompter.stop_sequences, - stop_sequences=prompter.stop_sequences, - sanitize_bot_response=sanitize_bot_response, - prompter=prompter, - context=context, - iinput=iinput, - tokenizer=tokenizer, - ) - else: - streamer = None - llm = H2OReplicate( - model=model_string, - input=gen_kwargs, - stop=prompter.stop_sequences, - stop_sequences=prompter.stop_sequences, - sanitize_bot_response=sanitize_bot_response, - prompter=prompter, - context=context, - iinput=iinput, - tokenizer=tokenizer, - ) - elif use_openai_model or inference_server.startswith('openai') or 
inference_server.startswith('vllm'): - if use_openai_model and model_name is None: - model_name = "gpt-3.5-turbo" - # FIXME: Will later import be ignored? I think so, so should be fine - openai, inf_type, deployment_name, base_url, api_version = set_openai(inference_server) - kwargs_extra = {} - if inf_type == 'openai_chat' or inf_type == 'vllm_chat': - cls = H2OChatOpenAI - # FIXME: Support context, iinput - # if inf_type == 'vllm_chat': - # kwargs_extra.update(dict(tokenizer=tokenizer)) - openai_api_key = openai.api_key - elif inf_type == 'openai_azure_chat': - cls = H2OAzureChatOpenAI - kwargs_extra.update(dict(openai_api_type='azure')) - # FIXME: Support context, iinput - if os.getenv('OPENAI_AZURE_KEY') is not None: - openai_api_key = os.getenv('OPENAI_AZURE_KEY') - else: - openai_api_key = openai.api_key - elif inf_type == 'openai_azure': - cls = H2OAzureOpenAI - kwargs_extra.update(dict(openai_api_type='azure')) - # FIXME: Support context, iinput - if os.getenv('OPENAI_AZURE_KEY') is not None: - openai_api_key = os.getenv('OPENAI_AZURE_KEY') - else: - openai_api_key = openai.api_key - else: - cls = H2OOpenAI - if inf_type == 'vllm': - kwargs_extra.update(dict(stop_sequences=prompter.stop_sequences, - sanitize_bot_response=sanitize_bot_response, - prompter=prompter, - context=context, - iinput=iinput, - tokenizer=tokenizer, - openai_api_base=openai.api_base, - client=None)) - else: - assert inf_type == 'openai' or use_openai_model - openai_api_key = openai.api_key - - if deployment_name: - kwargs_extra.update(dict(deployment_name=deployment_name)) - if api_version: - kwargs_extra.update(dict(openai_api_version=api_version)) - elif openai.api_version: - kwargs_extra.update(dict(openai_api_version=openai.api_version)) - elif inf_type in ['openai_azure', 'openai_azure_chat']: - kwargs_extra.update(dict(openai_api_version="2023-05-15")) - if base_url: - kwargs_extra.update(dict(openai_api_base=base_url)) - else: - kwargs_extra.update(dict(openai_api_base=openai.api_base)) - - callbacks = [StreamingGradioCallbackHandler()] - llm = cls(model_name=model_name, - temperature=temperature if do_sample else 0, - # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py - max_tokens=max_new_tokens, - top_p=top_p if do_sample else 1, - frequency_penalty=0, - presence_penalty=1.07 - repetition_penalty + 0.6, # so good default - callbacks=callbacks if stream_output else None, - openai_api_key=openai_api_key, - logit_bias=None if inf_type == 'vllm' else {}, - max_retries=6, - streaming=stream_output, - **kwargs_extra - ) - streamer = callbacks[0] if stream_output else None - if inf_type in ['openai', 'openai_chat', 'openai_azure', 'openai_azure_chat']: - prompt_type = inference_server - else: - # vllm goes here - prompt_type = prompt_type or 'plain' - elif inference_server and inference_server.startswith('sagemaker'): - callbacks = [StreamingGradioCallbackHandler()] # FIXME - streamer = None - - endpoint_name = ':'.join(inference_server.split(':')[1:2]) - region_name = ':'.join(inference_server.split(':')[2:]) - - from sagemaker import H2OSagemakerEndpoint, ChatContentHandler, BaseContentHandler - if inference_server.startswith('sagemaker_chat'): - content_handler = ChatContentHandler() - else: - content_handler = BaseContentHandler() - model_kwargs = dict(temperature=temperature if do_sample else 1E-10, - return_full_text=False, top_p=top_p, max_new_tokens=max_new_tokens) - llm = H2OSagemakerEndpoint( - endpoint_name=endpoint_name, - region_name=region_name, - 
aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'), - aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'), - model_kwargs=model_kwargs, - content_handler=content_handler, - endpoint_kwargs={'CustomAttributes': 'accept_eula=true'}, - ) - elif inference_server: - assert inference_server.startswith( - 'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server - - from gradio_utils.grclient import GradioClient - from text_generation import Client as HFClient - if isinstance(model, GradioClient): - gr_client = model - hf_client = None - else: - gr_client = None - hf_client = model - assert isinstance(hf_client, HFClient) - - inference_server, headers = get_hf_server(inference_server) - - # quick sanity check to avoid long timeouts, just see if can reach server - requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10'))) - callbacks = [StreamingGradioCallbackHandler()] - - if gr_client: - async_output = False # FIXME: not implemented yet - chat_client = False - llm = GradioInference( - inference_server_url=inference_server, - return_full_text=False, - - temperature=temperature, - top_p=top_p, - top_k=top_k, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - do_sample=do_sample, - chat_client=chat_client, - - callbacks=callbacks if stream_output else None, - stream_output=stream_output, - prompter=prompter, - context=context, - iinput=iinput, - client=gr_client, - sanitize_bot_response=sanitize_bot_response, - tokenizer=tokenizer, - system_prompt=system_prompt, - visible_models=visible_models, - h2ogpt_key=h2ogpt_key, - min_max_new_tokens=min_max_new_tokens, - ) - elif hf_client: - # no need to pass original client, no state and fast, so can use same validate_environment from base class - async_sem = asyncio.Semaphore(num_async) if async_output else NullContext() - llm = H2OHuggingFaceTextGenInference( - inference_server_url=inference_server, - do_sample=do_sample, - max_new_tokens=max_new_tokens, - repetition_penalty=repetition_penalty, - return_full_text=False, # this only controls internal behavior, still returns processed text - seed=SEED, - - stop_sequences=prompter.stop_sequences, - temperature=temperature, - top_k=top_k, - top_p=top_p, - # typical_p=top_p, - callbacks=callbacks if stream_output else None, - stream_output=stream_output, - prompter=prompter, - context=context, - iinput=iinput, - tokenizer=tokenizer, - timeout=max_time, - sanitize_bot_response=sanitize_bot_response, - async_sem=async_sem, - ) - else: - raise RuntimeError("No defined client") - streamer = callbacks[0] if stream_output else None - elif model_name in non_hf_types: - async_output = False # FIXME: not implemented yet - assert langchain_only_model - if model_name == 'llama': - callbacks = [StreamingGradioCallbackHandler()] - streamer = callbacks[0] if stream_output else None - else: - # stream_output = False - # doesn't stream properly as generator, but at least - callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()] - streamer = None - if prompter: - prompt_type = prompter.prompt_type - else: - prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output) - pass # assume inputted prompt_type is correct - from gpt4all_llm import get_llm_gpt4all - max_max_tokens = tokenizer.model_max_length - llm = 
get_llm_gpt4all(model_name, - model=model, - max_new_tokens=max_new_tokens, - temperature=temperature, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, - callbacks=callbacks, - n_jobs=n_jobs, - verbose=verbose, - streaming=stream_output, - prompter=prompter, - context=context, - iinput=iinput, - max_seq_len=max_max_tokens, - llamacpp_dict=llamacpp_dict, - ) - elif hasattr(model, 'is_exlama') and model.is_exlama(): - async_output = False # FIXME: not implemented yet - assert langchain_only_model - callbacks = [StreamingGradioCallbackHandler()] - streamer = callbacks[0] if stream_output else None - max_max_tokens = tokenizer.model_max_length - - from src.llm_exllama import Exllama - llm = Exllama(streaming=stream_output, - model_path=None, - model=model, - lora_path=None, - temperature=temperature, - top_k=top_k, - top_p=top_p, - typical=.7, - beams=1, - # beam_length = 40, - stop_sequences=prompter.stop_sequences, - callbacks=callbacks, - verbose=verbose, - max_seq_len=max_max_tokens, - fused_attn=False, - # alpha_value = 1.0, #For use with any models - # compress_pos_emb = 4.0, #For use with superhot - # set_auto_map = "3, 2" #Gpu split, this will split 3gigs/2gigs - prompter=prompter, - context=context, - iinput=iinput, - ) - else: - async_output = False # FIXME: not implemented yet - if model is None: - # only used if didn't pass model in - assert tokenizer is None - prompt_type = 'human_bot' - if model_name is None: - model_name = 'h2oai/h2ogpt-oasst1-512-12b' - # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' - # model_name = 'h2oai/h2ogpt-oasst1-512-20b' - inference_server = '' - model, tokenizer, device = get_model(load_8bit=True, base_model=model_name, - inference_server=inference_server, gpu_id=0) - - max_max_tokens = tokenizer.model_max_length - only_new_text = True - gen_kwargs = dict(do_sample=do_sample, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - return_full_text=not only_new_text, - handle_long_generation=None) - if do_sample: - gen_kwargs.update(dict(temperature=temperature, - top_k=top_k, - top_p=top_p)) - assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0 - else: - assert len(set(gen_hyper0).difference(gen_kwargs.keys())) == 0 - - if stream_output: - skip_prompt = only_new_text - from gen import H2OTextIteratorStreamer - decoder_kwargs = {} - streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs) - gen_kwargs.update(dict(streamer=streamer)) - else: - streamer = None - - from h2oai_pipeline import H2OTextGenerationPipeline - pipe = H2OTextGenerationPipeline(model=model, use_prompter=True, - prompter=prompter, - context=context, - iinput=iinput, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - sanitize_bot_response=sanitize_bot_response, - chat=False, stream_output=stream_output, - tokenizer=tokenizer, - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens=max_max_tokens - max(min_new_tokens, 256), - base_model=model_name, - **gen_kwargs) - # pipe.task = "text-generation" - # below makes it listen only to our prompt removal, - # not built in prompt removal that is less general and not specific for our model - pipe.task = "text2text-generation" - - llm = H2OHuggingFacePipeline(pipeline=pipe) - return llm, model_name, streamer, prompt_type, async_output, only_new_text - - -def 
get_device_dtype():
-    # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
-    import torch
-    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
-    device = 'cpu' if n_gpus == 0 else 'cuda'
-    # from utils import NullContext
-    # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
-    context_class = torch.device
-    torch_dtype = torch.float16 if device == 'cuda' else torch.float32
-    return device, torch_dtype, context_class
-
-
-def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
-    """
-    Get wikipedia data from online
-    :param title:
-    :param first_paragraph_only:
-    :param text_limit:
-    :param take_head:
-    :return:
-    """
-    filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
-    url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
-    if first_paragraph_only:
-        url += "&exintro=1"
-    import json
-    if not os.path.isfile(filename):
-        data = requests.get(url).json()
-        json.dump(data, open(filename, 'wt'))
-    else:
-        data = json.load(open(filename, "rt"))
-    page_content = list(data["query"]["pages"].values())[0]["extract"]
-    if take_head is not None and text_limit is not None:
-        page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
-    title_url = str(title).replace(' ', '_')
-    return Document(
-        page_content=str(page_content),
-        metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
-    )
-
-
-def get_wiki_sources(first_para=True, text_limit=None):
-    """
-    Get specific named sources from wikipedia
-    :param first_para:
-    :param text_limit:
-    :return:
-    """
-    default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
-    wiki_sources = list(os.getenv('WIKI_SOURCES', default_wiki_sources))
-    return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
-
-
-def get_github_docs(repo_owner, repo_name):
-    """
-    Access github from specific repo
-    :param repo_owner:
-    :param repo_name:
-    :return:
-    """
-    with tempfile.TemporaryDirectory() as d:
-        subprocess.check_call(
-            f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
-            cwd=d,
-            shell=True,
-        )
-        git_sha = (
-            subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
-            .decode("utf-8")
-            .strip()
-        )
-        repo_path = pathlib.Path(d)
-        markdown_files = list(repo_path.glob("*/*.md")) + list(
-            repo_path.glob("*/*.mdx")
-        )
-        for markdown_file in markdown_files:
-            with open(markdown_file, "r") as f:
-                relative_path = markdown_file.relative_to(repo_path)
-                github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
-                yield Document(page_content=str(f.read()), metadata={"source": github_url})
-
-
-def get_dai_pickle(dest="."):
-    from huggingface_hub import hf_hub_download
-    # True for case when locally already logged in with correct token, so don't have to set key
-    token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
-    path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
-    shutil.copy(path_to_zip_file, dest)
-
-
-def get_dai_docs(from_hf=False, get_pickle=True):
-    """
-    Consume DAI documentation, or consume from public pickle
-    :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
-    :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
-    :return:
-    """
-    import pickle
-
-    if get_pickle:
-        get_dai_pickle()
-
-    dai_store = 'dai_docs.pickle'
-    dst = 
"working_dir_docs" - if not os.path.isfile(dai_store): - from create_data import setup_dai_docs - dst = setup_dai_docs(dst=dst, from_hf=from_hf) - - import glob - files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True)) - - basedir = os.path.abspath(os.getcwd()) - from create_data import rst_to_outputs - new_outputs = rst_to_outputs(files) - os.chdir(basedir) - - pickle.dump(new_outputs, open(dai_store, 'wb')) - else: - new_outputs = pickle.load(open(dai_store, 'rb')) - - sources = [] - for line, file in new_outputs: - # gradio requires any linked file to be with app.py - sym_src = os.path.abspath(os.path.join(dst, file)) - sym_dst = os.path.abspath(os.path.join(os.getcwd(), file)) - if os.path.lexists(sym_dst): - os.remove(sym_dst) - os.symlink(sym_src, sym_dst) - itm = Document(page_content=str(line), metadata={"source": file}) - # NOTE: yield has issues when going into db, loses metadata - # yield itm - sources.append(itm) - return sources - - -def get_supported_types(): - non_image_types0 = ["pdf", "txt", "csv", "toml", "py", "rst", "xml", "rtf", - "md", - "html", "mhtml", "htm", - "enex", "eml", "epub", "odt", "pptx", "ppt", - "zip", - "gz", - "gzip", - "urls", - ] - # "msg", GPL3 - - video_types0 = ['WEBM', - 'MPG', 'MP2', 'MPEG', 'MPE', '.PV', - 'OGG', - 'MP4', 'M4P', 'M4V', - 'AVI', 'WMV', - 'MOV', 'QT', - 'FLV', 'SWF', - 'AVCHD'] - video_types0 = [x.lower() for x in video_types0] - if have_pillow: - from PIL import Image - exts = Image.registered_extensions() - image_types0 = {ex for ex, f in exts.items() if f in Image.OPEN if ex not in video_types0 + non_image_types0} - image_types0 = sorted(image_types0) - image_types0 = [x[1:] if x.startswith('.') else x for x in image_types0] - else: - image_types0 = [] - return non_image_types0, image_types0, video_types0 - - -non_image_types, image_types, video_types = get_supported_types() -set_image_types = set(image_types) - -if have_libreoffice or True: - # or True so it tries to load, e.g. on MAC/Windows, even if don't have libreoffice since works without that - non_image_types.extend(["docx", "doc", "xls", "xlsx"]) -if have_jq: - non_image_types.extend(["json", "jsonl"]) - -file_types = non_image_types + image_types - - -def try_as_html(file): - # try treating as html as occurs when scraping websites - from bs4 import BeautifulSoup - with open(file, "rt") as f: - try: - is_html = bool(BeautifulSoup(f.read(), "html.parser").find()) - except: # FIXME - is_html = False - if is_html: - file_url = 'file://' + file - doc1 = UnstructuredURLLoader(urls=[file_url]).load() - doc1 = [x for x in doc1 if x.page_content] - else: - doc1 = [] - return doc1 - - -def json_metadata_func(record: dict, metadata: dict) -> dict: - # Define the metadata extraction function. 
- - if isinstance(record, dict): - metadata["sender_name"] = record.get("sender_name") - metadata["timestamp_ms"] = record.get("timestamp_ms") - - if "source" in metadata: - metadata["source_json"] = metadata['source'] - if "seq_num" in metadata: - metadata["seq_num_json"] = metadata['seq_num'] - - return metadata - - -def file_to_doc(file, - filei=0, - base_path=None, verbose=False, fail_any_exception=False, - chunk=True, chunk_size=512, n_jobs=-1, - is_url=False, is_txt=False, - - # urls - use_unstructured=True, - use_playwright=False, - use_selenium=False, - - # pdfs - use_pymupdf='auto', - use_unstructured_pdf='auto', - use_pypdf='auto', - enable_pdf_ocr='auto', - try_pdf_as_html='auto', - enable_pdf_doctr='auto', - - # images - enable_ocr=False, - enable_doctr=False, - enable_pix2struct=False, - enable_captions=True, - captions_model=None, - model_loaders=None, - - # json - jq_schema='.[]', - - headsize=50, # see also H2OSerpAPIWrapper - db_type=None, - selected_file_types=None): - assert isinstance(model_loaders, dict) - if selected_file_types is not None: - set_image_types1 = set_image_types.intersection(set(selected_file_types)) - else: - set_image_types1 = set_image_types - - assert db_type is not None - chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type) - add_meta = functools.partial(_add_meta, headsize=headsize, filei=filei) - # FIXME: if zip, file index order will not be correct if other files involved - path_to_docs_func = functools.partial(path_to_docs, - verbose=verbose, - fail_any_exception=fail_any_exception, - n_jobs=n_jobs, - chunk=chunk, chunk_size=chunk_size, - # url=file if is_url else None, - # text=file if is_txt else None, - - # urls - use_unstructured=use_unstructured, - use_playwright=use_playwright, - use_selenium=use_selenium, - - # pdfs - use_pymupdf=use_pymupdf, - use_unstructured_pdf=use_unstructured_pdf, - use_pypdf=use_pypdf, - enable_pdf_ocr=enable_pdf_ocr, - enable_pdf_doctr=enable_pdf_doctr, - try_pdf_as_html=try_pdf_as_html, - - # images - enable_ocr=enable_ocr, - enable_doctr=enable_doctr, - enable_pix2struct=enable_pix2struct, - enable_captions=enable_captions, - captions_model=captions_model, - - caption_loader=model_loaders['caption'], - doctr_loader=model_loaders['doctr'], - pix2struct_loader=model_loaders['pix2struct'], - - # json - jq_schema=jq_schema, - - db_type=db_type, - ) - - if file is None: - if fail_any_exception: - raise RuntimeError("Unexpected None file") - else: - return [] - doc1 = [] # in case no support, or disabled support - if base_path is None and not is_txt and not is_url: - # then assume want to persist but don't care which path used - # can't be in base_path - dir_name = os.path.dirname(file) - base_name = os.path.basename(file) - # if from gradio, will have its own temp uuid too, but that's ok - base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10] - base_path = os.path.join(dir_name, base_name) - if is_url: - file = file.strip() # in case accidental spaces in front or at end - file_lower = file.lower() - case1 = file_lower.startswith('arxiv:') and len(file_lower.split('arxiv:')) == 2 - case2 = file_lower.startswith('https://arxiv.org/abs') and len(file_lower.split('https://arxiv.org/abs')) == 2 - case3 = file_lower.startswith('http://arxiv.org/abs') and len(file_lower.split('http://arxiv.org/abs')) == 2 - case4 = file_lower.startswith('arxiv.org/abs/') and len(file_lower.split('arxiv.org/abs/')) == 2 - if case1 or case2 or case3 or case4: - if case1: - 
query = file.lower().split('arxiv:')[1].strip()
- elif case2:
- query = file.lower().split('https://arxiv.org/abs/')[1].strip()
- elif case3:
- query = file.lower().split('http://arxiv.org/abs/')[1].strip()
- elif case4:
- query = file.lower().split('arxiv.org/abs/')[1].strip()
- else:
- raise RuntimeError("Unexpected arxiv error for %s" % file)
- if have_arxiv:
- trials = 3
- docs1 = []
- for trial in range(trials):
- try:
- docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
- break
- except urllib.error.URLError:
- pass
- if not docs1:
- print("Failed to get arxiv %s" % query, flush=True)
- # ensure string, sometimes None
- [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
- query_url = f"https://arxiv.org/abs/{query}"
- [x.metadata.update(
- dict(source=x.metadata.get('entry_id', query_url), query=query_url,
- input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now()))) for x in
- docs1]
- else:
- docs1 = []
- else:
- if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
- file = 'http://' + file
- docs1 = []
- do_unstructured = only_unstructured_urls or use_unstructured
- if only_selenium or only_playwright:
- do_unstructured = False
- do_playwright = have_playwright and (use_playwright or only_playwright)
- if only_unstructured_urls or only_selenium:
- do_playwright = False
- do_selenium = have_selenium and (use_selenium or only_selenium)
- if only_unstructured_urls or only_playwright:
- do_selenium = False
- if do_unstructured or use_unstructured:
- docs1a = UnstructuredURLLoader(urls=[file]).load()
- docs1a = [x for x in docs1a if x.page_content]
- add_parser(docs1a, 'UnstructuredURLLoader')
- docs1.extend(docs1a)
- if len(docs1) == 0 and have_playwright or do_playwright:
- # then something went wrong, try another loader:
- from langchain.document_loaders import PlaywrightURLLoader
- docs1a = asyncio.run(PlaywrightURLLoader(urls=[file]).aload())
- # docs1 = PlaywrightURLLoader(urls=[file]).load()
- docs1a = [x for x in docs1a if x.page_content]
- add_parser(docs1a, 'PlaywrightURLLoader')
- docs1.extend(docs1a)
- if len(docs1) == 0 and have_selenium or do_selenium:
- # then something went wrong, try another loader:
- # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException:
- # Message: unknown error: cannot find Chrome binary
- from langchain.document_loaders import SeleniumURLLoader
- from selenium.common.exceptions import WebDriverException
- try:
- docs1a = SeleniumURLLoader(urls=[file]).load()
- docs1a = [x for x in docs1a if x.page_content]
- add_parser(docs1a, 'SeleniumURLLoader')
- docs1.extend(docs1a)
- except WebDriverException as e:
- print("No web driver: %s" % str(e), flush=True)
- [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
- add_meta(docs1, file, parser="is_url")
- docs1 = clean_doc(docs1)
- doc1 = chunk_sources(docs1)
- elif is_txt:
- base_path = "user_paste"
- base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
- source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
- with open(source_file, "wt") as f:
- f.write(file)
- metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
- doc1 = Document(page_content=str(file), metadata=metadata)
- add_meta(doc1, file, parser="f.write")
- # Bit odd to change if was original text
- # doc1 = clean_doc(doc1)
- elif file.lower().endswith('.html') or 
file.lower().endswith('.mhtml') or file.lower().endswith('.htm'): - docs1 = UnstructuredHTMLLoader(file_path=file).load() - add_meta(docs1, file, parser='UnstructuredHTMLLoader') - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, language=Language.HTML) - elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True): - docs1 = UnstructuredWordDocumentLoader(file_path=file).load() - add_meta(docs1, file, parser='UnstructuredWordDocumentLoader') - doc1 = chunk_sources(docs1) - elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True): - docs1 = UnstructuredExcelLoader(file_path=file).load() - add_meta(docs1, file, parser='UnstructuredExcelLoader') - doc1 = chunk_sources(docs1) - elif file.lower().endswith('.odt'): - docs1 = UnstructuredODTLoader(file_path=file).load() - add_meta(docs1, file, parser='UnstructuredODTLoader') - doc1 = chunk_sources(docs1) - elif file.lower().endswith('pptx') or file.lower().endswith('ppt'): - docs1 = UnstructuredPowerPointLoader(file_path=file).load() - add_meta(docs1, file, parser='UnstructuredPowerPointLoader') - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1) - elif file.lower().endswith('.txt'): - # use UnstructuredFileLoader ? - docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load() - # makes just one, but big one - doc1 = chunk_sources(docs1) - # Bit odd to change if was original text - # doc1 = clean_doc(doc1) - add_meta(doc1, file, parser='TextLoader') - elif file.lower().endswith('.rtf'): - docs1 = UnstructuredRTFLoader(file).load() - add_meta(docs1, file, parser='UnstructuredRTFLoader') - doc1 = chunk_sources(docs1) - elif file.lower().endswith('.md'): - docs1 = UnstructuredMarkdownLoader(file).load() - add_meta(docs1, file, parser='UnstructuredMarkdownLoader') - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, language=Language.MARKDOWN) - elif file.lower().endswith('.enex'): - docs1 = EverNoteLoader(file).load() - add_meta(doc1, file, parser='EverNoteLoader') - doc1 = chunk_sources(docs1) - elif file.lower().endswith('.epub'): - docs1 = UnstructuredEPubLoader(file).load() - add_meta(docs1, file, parser='UnstructuredEPubLoader') - doc1 = chunk_sources(docs1) - elif any(file.lower().endswith(x) for x in set_image_types1): - docs1 = [] - if verbose: - print("BEGIN: Tesseract", flush=True) - if have_tesseract and enable_ocr: - # OCR, somewhat works, but not great - docs1a = UnstructuredImageLoader(file, strategy='ocr_only').load() - # docs1a = UnstructuredImageLoader(file, strategy='hi_res').load() - docs1a = [x for x in docs1a if x.page_content] - add_meta(docs1a, file, parser='UnstructuredImageLoader') - docs1.extend(docs1a) - if verbose: - print("END: Tesseract", flush=True) - if have_doctr and enable_doctr: - if verbose: - print("BEGIN: DocTR", flush=True) - if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)): - if verbose: - print("Reuse DocTR", flush=True) - model_loaders['doctr'].load_model() - else: - if verbose: - print("Fresh DocTR", flush=True) - from image_doctr import H2OOCRLoader - model_loaders['doctr'] = H2OOCRLoader() - model_loaders['doctr'].set_document_paths([file]) - docs1c = model_loaders['doctr'].load() - docs1c = [x for x in docs1c if x.page_content] - add_meta(docs1c, file, parser='H2OOCRLoader: %s' % 'DocTR') - # caption didn't set source, so fix-up meta - for doci in docs1c: - doci.metadata['source'] = doci.metadata.get('document_path', file) - 
doci.metadata['hashid'] = hash_file(doci.metadata['source']) - docs1.extend(docs1c) - if verbose: - print("END: DocTR", flush=True) - if enable_captions: - # BLIP - if verbose: - print("BEGIN: BLIP", flush=True) - if model_loaders['caption'] is not None and not isinstance(model_loaders['caption'], (str, bool)): - # assumes didn't fork into this process with joblib, else can deadlock - if verbose: - print("Reuse BLIP", flush=True) - model_loaders['caption'].load_model() - else: - if verbose: - print("Fresh BLIP", flush=True) - from image_captions import H2OImageCaptionLoader - model_loaders['caption'] = H2OImageCaptionLoader(caption_gpu=model_loaders['caption'] == 'gpu', - blip_model=captions_model, - blip_processor=captions_model) - model_loaders['caption'].set_image_paths([file]) - docs1c = model_loaders['caption'].load() - docs1c = [x for x in docs1c if x.page_content] - add_meta(docs1c, file, parser='H2OImageCaptionLoader: %s' % captions_model) - # caption didn't set source, so fix-up meta - for doci in docs1c: - doci.metadata['source'] = doci.metadata.get('image_path', file) - doci.metadata['hashid'] = hash_file(doci.metadata['source']) - docs1.extend(docs1c) - - if verbose: - print("END: BLIP", flush=True) - if enable_pix2struct: - # BLIP - if verbose: - print("BEGIN: Pix2Struct", flush=True) - if model_loaders['pix2struct'] is not None and not isinstance(model_loaders['pix2struct'], (str, bool)): - if verbose: - print("Reuse pix2struct", flush=True) - model_loaders['pix2struct'].load_model() - else: - if verbose: - print("Fresh pix2struct", flush=True) - from image_pix2struct import H2OPix2StructLoader - model_loaders['pix2struct'] = H2OPix2StructLoader() - model_loaders['pix2struct'].set_image_paths([file]) - docs1c = model_loaders['pix2struct'].load() - docs1c = [x for x in docs1c if x.page_content] - add_meta(docs1c, file, parser='H2OPix2StructLoader: %s' % model_loaders['pix2struct']) - # caption didn't set source, so fix-up meta - for doci in docs1c: - doci.metadata['source'] = doci.metadata.get('image_path', file) - doci.metadata['hashid'] = hash_file(doci.metadata['source']) - docs1.extend(docs1c) - if verbose: - print("END: Pix2Struct", flush=True) - doc1 = chunk_sources(docs1) - elif file.lower().endswith('.msg'): - raise RuntimeError("Not supported, GPL3 license") - # docs1 = OutlookMessageLoader(file).load() - # docs1[0].metadata['source'] = file - elif file.lower().endswith('.eml'): - try: - docs1 = UnstructuredEmailLoader(file).load() - add_meta(docs1, file, parser='UnstructuredEmailLoader') - doc1 = chunk_sources(docs1) - except ValueError as e: - if 'text/html content not found in email' in str(e): - pass - else: - raise - doc1 = [x for x in doc1 if x.page_content] - if len(doc1) == 0: - # e.g. 
plain/text dict key exists, but not - # doc1 = TextLoader(file, encoding="utf8").load() - docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load() - docs1 = [x for x in docs1 if x.page_content] - add_meta(docs1, file, parser='UnstructuredEmailLoader text/plain') - doc1 = chunk_sources(docs1) - # elif file.lower().endswith('.gcsdir'): - # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load() - # elif file.lower().endswith('.gcsfile'): - # doc1 = GCSFileLoader(project_name, bucket, blob).load() - elif file.lower().endswith('.rst'): - with open(file, "r") as f: - doc1 = Document(page_content=str(f.read()), metadata={"source": file}) - add_meta(doc1, file, parser='f.read()') - doc1 = chunk_sources(doc1, language=Language.RST) - elif file.lower().endswith('.json'): - # 10k rows, 100 columns-like parts 4 bytes each - JSON_SIZE_LIMIT = int(os.getenv('JSON_SIZE_LIMIT', str(10 * 10 * 1024 * 10 * 4))) - if os.path.getsize(file) > JSON_SIZE_LIMIT: - raise ValueError( - "JSON file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % JSON_SIZE_LIMIT) - loader = JSONLoader( - file_path=file, - # jq_schema='.messages[].content', - jq_schema=jq_schema, - text_content=False, - metadata_func=json_metadata_func) - doc1 = loader.load() - add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema) - fix_json_meta(doc1) - elif file.lower().endswith('.jsonl'): - loader = JSONLoader( - file_path=file, - # jq_schema='.messages[].content', - jq_schema=jq_schema, - json_lines=True, - text_content=False, - metadata_func=json_metadata_func) - doc1 = loader.load() - add_meta(doc1, file, parser='JSONLoader: %s' % jq_schema) - fix_json_meta(doc1) - elif file.lower().endswith('.pdf'): - # migration - if isinstance(use_pymupdf, bool): - if use_pymupdf == False: - use_pymupdf = 'off' - if use_pymupdf == True: - use_pymupdf = 'on' - if isinstance(use_unstructured_pdf, bool): - if use_unstructured_pdf == False: - use_unstructured_pdf = 'off' - if use_unstructured_pdf == True: - use_unstructured_pdf = 'on' - if isinstance(use_pypdf, bool): - if use_pypdf == False: - use_pypdf = 'off' - if use_pypdf == True: - use_pypdf = 'on' - if isinstance(enable_pdf_ocr, bool): - if enable_pdf_ocr == False: - enable_pdf_ocr = 'off' - if enable_pdf_ocr == True: - enable_pdf_ocr = 'on' - if isinstance(try_pdf_as_html, bool): - if try_pdf_as_html == False: - try_pdf_as_html = 'off' - if try_pdf_as_html == True: - try_pdf_as_html = 'on' - - doc1 = [] - tried_others = False - handled = False - did_pymupdf = False - did_unstructured = False - e = None - if have_pymupdf and (len(doc1) == 0 and use_pymupdf == 'auto' or use_pymupdf == 'on'): - # GPL, only use if installed - from langchain.document_loaders import PyMuPDFLoader - # load() still chunks by pages, but every page has title at start to help - try: - doc1a = PyMuPDFLoader(file).load() - did_pymupdf = True - except BaseException as e0: - doc1a = [] - print("PyMuPDFLoader: %s" % str(e0), flush=True) - e = e0 - # remove empty documents - handled |= len(doc1a) > 0 - doc1a = [x for x in doc1a if x.page_content] - doc1a = clean_doc(doc1a) - add_parser(doc1a, 'PyMuPDFLoader') - doc1.extend(doc1a) - if len(doc1) == 0 and use_unstructured_pdf == 'auto' or use_unstructured_pdf == 'on': - tried_others = True - try: - doc1a = UnstructuredPDFLoader(file).load() - did_unstructured = True - except BaseException as e0: - doc1a = [] - print("UnstructuredPDFLoader: %s" % str(e0), flush=True) - e = e0 - handled |= len(doc1a) > 0 - # remove empty 
documents - doc1a = [x for x in doc1a if x.page_content] - add_parser(doc1a, 'UnstructuredPDFLoader') - # seems to not need cleaning in most cases - doc1.extend(doc1a) - if len(doc1) == 0 and use_pypdf == 'auto' or use_pypdf == 'on': - tried_others = True - # open-source fallback - # load() still chunks by pages, but every page has title at start to help - try: - doc1a = PyPDFLoader(file).load() - except BaseException as e0: - doc1a = [] - print("PyPDFLoader: %s" % str(e0), flush=True) - e = e0 - handled |= len(doc1a) > 0 - # remove empty documents - doc1a = [x for x in doc1a if x.page_content] - doc1a = clean_doc(doc1a) - add_parser(doc1a, 'PyPDFLoader') - doc1.extend(doc1a) - if not did_pymupdf and ((have_pymupdf and len(doc1) == 0) and tried_others): - # try again in case only others used, but only if didn't already try (2nd part of and) - # GPL, only use if installed - from langchain.document_loaders import PyMuPDFLoader - # load() still chunks by pages, but every page has title at start to help - try: - doc1a = PyMuPDFLoader(file).load() - except BaseException as e0: - doc1a = [] - print("PyMuPDFLoader: %s" % str(e0), flush=True) - e = e0 - handled |= len(doc1a) > 0 - # remove empty documents - doc1a = [x for x in doc1a if x.page_content] - doc1a = clean_doc(doc1a) - add_parser(doc1a, 'PyMuPDFLoader2') - doc1.extend(doc1a) - did_pdf_ocr = False - if len(doc1) == 0 and (enable_pdf_ocr == 'auto' and enable_pdf_doctr != 'on') or enable_pdf_ocr == 'on': - did_pdf_ocr = True - # no did_unstructured condition here because here we do OCR, and before we did not - # try OCR in end since slowest, but works on pure image pages well - doc1a = UnstructuredPDFLoader(file, strategy='ocr_only').load() - handled |= len(doc1a) > 0 - # remove empty documents - doc1a = [x for x in doc1a if x.page_content] - add_parser(doc1a, 'UnstructuredPDFLoader ocr_only') - # seems to not need cleaning in most cases - doc1.extend(doc1a) - # Some PDFs return nothing or junk from PDFMinerLoader - if len(doc1) == 0 and enable_pdf_doctr == 'auto' or enable_pdf_doctr == 'on': - if verbose: - print("BEGIN: DocTR", flush=True) - if model_loaders['doctr'] is not None and not isinstance(model_loaders['doctr'], (str, bool)): - model_loaders['doctr'].load_model() - else: - from image_doctr import H2OOCRLoader - model_loaders['doctr'] = H2OOCRLoader() - model_loaders['doctr'].set_document_paths([file]) - doc1a = model_loaders['doctr'].load() - doc1a = [x for x in doc1a if x.page_content] - add_meta(doc1a, file, parser='H2OOCRLoader: %s' % 'DocTR') - handled |= len(doc1a) > 0 - # caption didn't set source, so fix-up meta - for doci in doc1a: - doci.metadata['source'] = doci.metadata.get('document_path', file) - doci.metadata['hashid'] = hash_file(doci.metadata['source']) - doc1.extend(doc1a) - if verbose: - print("END: DocTR", flush=True) - if try_pdf_as_html in ['auto', 'on']: - doc1a = try_as_html(file) - add_parser(doc1a, 'try_as_html') - doc1.extend(doc1a) - - if len(doc1) == 0: - # if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all. 
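- # 'handled' is True if any PDF parser returned documents at all (even with empty page_content),
- # so a "parsed but no valid text" case can be reported differently from a complete parse failure.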
- if handled: - raise ValueError("%s had no valid text, but meta data was parsed" % file) - else: - raise ValueError("%s had no valid text and no meta data was parsed: %s" % (file, str(e))) - add_meta(doc1, file, parser='pdf') - doc1 = chunk_sources(doc1) - elif file.lower().endswith('.csv'): - CSV_SIZE_LIMIT = int(os.getenv('CSV_SIZE_LIMIT', str(10 * 1024 * 10 * 4))) - if os.path.getsize(file) > CSV_SIZE_LIMIT: - raise ValueError( - "CSV file sizes > %s not supported for naive parsing and embedding, requires Agents enabled" % CSV_SIZE_LIMIT) - doc1 = CSVLoader(file).load() - add_meta(doc1, file, parser='CSVLoader') - if isinstance(doc1, list): - # each row is a Document, identify - [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(doc1)] - if db_type in ['chroma', 'chroma_old']: - # then separate summarize list - sdoc1 = clone_documents(doc1) - [x.metadata.update(dict(chunk_id=-1)) for chunk_id, x in enumerate(sdoc1)] - doc1 = sdoc1 + doc1 - elif file.lower().endswith('.py'): - doc1 = PythonLoader(file).load() - add_meta(doc1, file, parser='PythonLoader') - doc1 = chunk_sources(doc1, language=Language.PYTHON) - elif file.lower().endswith('.toml'): - doc1 = TomlLoader(file).load() - add_meta(doc1, file, parser='TomlLoader') - doc1 = chunk_sources(doc1) - elif file.lower().endswith('.xml'): - from langchain.document_loaders import UnstructuredXMLLoader - loader = UnstructuredXMLLoader(file_path=file) - doc1 = loader.load() - add_meta(doc1, file, parser='UnstructuredXMLLoader') - elif file.lower().endswith('.urls'): - with open(file, "r") as f: - urls = f.readlines() - # recurse - doc1 = path_to_docs_func(None, url=urls) - elif file.lower().endswith('.zip'): - with zipfile.ZipFile(file, 'r') as zip_ref: - # don't put into temporary path, since want to keep references to docs inside zip - # so just extract in path where - zip_ref.extractall(base_path) - # recurse - doc1 = path_to_docs_func(base_path) - elif file.lower().endswith('.gz') or file.lower().endswith('.gzip'): - if file.lower().endswith('.gz'): - de_file = file.lower().replace('.gz', '') - else: - de_file = file.lower().replace('.gzip', '') - with gzip.open(file, 'rb') as f_in: - with open(de_file, 'wb') as f_out: - shutil.copyfileobj(f_in, f_out) - # recurse - doc1 = file_to_doc(de_file, - filei=filei, # single file, same file index as outside caller - base_path=base_path, verbose=verbose, fail_any_exception=fail_any_exception, - chunk=chunk, chunk_size=chunk_size, n_jobs=n_jobs, - is_url=is_url, is_txt=is_txt, - - # urls - use_unstructured=use_unstructured, - use_playwright=use_playwright, - use_selenium=use_selenium, - - # pdfs - use_pymupdf=use_pymupdf, - use_unstructured_pdf=use_unstructured_pdf, - use_pypdf=use_pypdf, - enable_pdf_ocr=enable_pdf_ocr, - enable_pdf_doctr=enable_pdf_doctr, - try_pdf_as_html=try_pdf_as_html, - - # images - enable_ocr=enable_ocr, - enable_doctr=enable_doctr, - enable_pix2struct=enable_pix2struct, - enable_captions=enable_captions, - captions_model=captions_model, - model_loaders=model_loaders, - - # json - jq_schema=jq_schema, - - headsize=headsize, - db_type=db_type, - selected_file_types=selected_file_types) - else: - raise RuntimeError("No file handler for %s" % os.path.basename(file)) - - # allow doc1 to be list or not. 
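- # a bare Document has not been chunked yet; a single-element list is re-chunked to be safe;
- # longer lists are assumed to have been chunked already by the branch that produced them.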
- if not isinstance(doc1, list): - # If not list, did not chunk yet, so chunk now - docs = chunk_sources([doc1]) - elif isinstance(doc1, list) and len(doc1) == 1: - # if list of length one, don't trust and chunk it, chunk_id's will still be correct if repeat - docs = chunk_sources(doc1) - else: - docs = doc1 - - assert isinstance(docs, list) - return docs - - -def path_to_doc1(file, - filei=0, - verbose=False, fail_any_exception=False, return_file=True, - chunk=True, chunk_size=512, - n_jobs=-1, - is_url=False, is_txt=False, - - # urls - use_unstructured=True, - use_playwright=False, - use_selenium=False, - - # pdfs - use_pymupdf='auto', - use_unstructured_pdf='auto', - use_pypdf='auto', - enable_pdf_ocr='auto', - enable_pdf_doctr='auto', - try_pdf_as_html='auto', - - # images - enable_ocr=False, - enable_doctr=False, - enable_pix2struct=False, - enable_captions=True, - captions_model=None, - model_loaders=None, - - # json - jq_schema='.[]', - - db_type=None, - selected_file_types=None): - assert db_type is not None - if verbose: - if is_url: - print("Ingesting URL: %s" % file, flush=True) - elif is_txt: - print("Ingesting Text: %s" % file, flush=True) - else: - print("Ingesting file: %s" % file, flush=True) - res = None - try: - # don't pass base_path=path, would infinitely recurse - res = file_to_doc(file, - filei=filei, - base_path=None, verbose=verbose, fail_any_exception=fail_any_exception, - chunk=chunk, chunk_size=chunk_size, - n_jobs=n_jobs, - is_url=is_url, is_txt=is_txt, - - # urls - use_unstructured=use_unstructured, - use_playwright=use_playwright, - use_selenium=use_selenium, - - # pdfs - use_pymupdf=use_pymupdf, - use_unstructured_pdf=use_unstructured_pdf, - use_pypdf=use_pypdf, - enable_pdf_ocr=enable_pdf_ocr, - enable_pdf_doctr=enable_pdf_doctr, - try_pdf_as_html=try_pdf_as_html, - - # images - enable_ocr=enable_ocr, - enable_doctr=enable_doctr, - enable_pix2struct=enable_pix2struct, - enable_captions=enable_captions, - captions_model=captions_model, - model_loaders=model_loaders, - - # json - jq_schema=jq_schema, - - db_type=db_type, - selected_file_types=selected_file_types) - except BaseException as e: - print("Failed to ingest %s due to %s" % (file, traceback.format_exc())) - if fail_any_exception: - raise - else: - exception_doc = Document( - page_content='', - metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)), - "traceback": traceback.format_exc()}) - res = [exception_doc] - if verbose: - if is_url: - print("DONE Ingesting URL: %s" % file, flush=True) - elif is_txt: - print("DONE Ingesting Text: %s" % file, flush=True) - else: - print("DONE Ingesting file: %s" % file, flush=True) - if return_file: - base_tmp = "temp_path_to_doc1" - if not os.path.isdir(base_tmp): - base_tmp = makedirs(base_tmp, exist_ok=True, tmp_ok=True, use_base=True) - filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle") - with open(filename, 'wb') as f: - pickle.dump(res, f) - return filename - return res - - -def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1, - chunk=True, chunk_size=512, - url=None, text=None, - - # urls - use_unstructured=True, - use_playwright=False, - use_selenium=False, - - # pdfs - use_pymupdf='auto', - use_unstructured_pdf='auto', - use_pypdf='auto', - enable_pdf_ocr='auto', - enable_pdf_doctr='auto', - try_pdf_as_html='auto', - - # images - enable_ocr=False, - enable_doctr=False, - enable_pix2struct=False, - enable_captions=True, - captions_model=None, - - caption_loader=None, - doctr_loader=None, - 
pix2struct_loader=None, - - # json - jq_schema='.[]', - - existing_files=[], - existing_hash_ids={}, - db_type=None, - selected_file_types=None, - ): - if verbose: - print("BEGIN Consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True) - if selected_file_types is not None: - non_image_types1 = [x for x in non_image_types if x in selected_file_types] - image_types1 = [x for x in image_types if x in selected_file_types] - else: - non_image_types1 = non_image_types.copy() - image_types1 = image_types.copy() - - assert db_type is not None - # path_or_paths could be str, list, tuple, generator - globs_image_types = [] - globs_non_image_types = [] - if not path_or_paths and not url and not text: - return [] - elif url: - url = get_list_or_str(url) - globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url] - elif text: - globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text] - elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths): - # single path, only consume allowed files - path = path_or_paths - # Below globs should match patterns in file_to_doc() - [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True)) - for ftype in image_types1] - globs_image_types = [os.path.normpath(x) for x in globs_image_types] - [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True)) - for ftype in non_image_types1] - globs_non_image_types = [os.path.normpath(x) for x in globs_non_image_types] - else: - if isinstance(path_or_paths, str): - if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths): - path_or_paths = [path_or_paths] - else: - # path was deleted etc. - return [] - # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows) - assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \ - "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths)) - # reform out of allowed types - globs_image_types.extend( - flatten_list([[os.path.normpath(x) for x in path_or_paths if x.endswith(y)] for y in image_types1])) - # could do below: - # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types1]) - # But instead, allow fail so can collect unsupported too - set_globs_image_types = set(globs_image_types) - globs_non_image_types.extend([os.path.normpath(x) for x in path_or_paths if x not in set_globs_image_types]) - - # filter out any files to skip (e.g. if already processed them) - # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[] - assert not existing_files, "DEV: assume not using this approach" - if existing_files: - set_skip_files = set(existing_files) - globs_image_types = [x for x in globs_image_types if x not in set_skip_files] - globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files] - if existing_hash_ids: - # assume consistent with add_meta() use of hash_file(file) - # also assume consistent with get_existing_hash_ids for dict creation - # assume hashable values - existing_hash_ids_set = set(existing_hash_ids.items()) - hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items()) - hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items()) - # don't use symmetric diff. 
If file is gone, ignore and don't remove or something - # just consider existing files (key) having new hash or not (value) - new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys()) - new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys()) - globs_image_types = [x for x in globs_image_types if x in new_files_image] - globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image] - - # could use generator, but messes up metadata handling in recursive case - if caption_loader and not isinstance(caption_loader, (bool, str)) and caption_loader.device != 'cpu' or \ - get_device() == 'cuda': - # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context - # get_device() == 'cuda' because presume faster to process image from (temporarily) preloaded model - n_jobs_image = 1 - else: - n_jobs_image = n_jobs - if enable_doctr or enable_pdf_doctr in [True, 'auto', 'on']: - if doctr_loader and not isinstance(doctr_loader, (bool, str)) and doctr_loader.device != 'cpu': - # can't fork cuda context - n_jobs = 1 - - return_file = True # local choice - is_url = url is not None - is_txt = text is not None - model_loaders = dict(caption=caption_loader, - doctr=doctr_loader, - pix2struct=pix2struct_loader) - model_loaders0 = model_loaders.copy() - kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception, - return_file=return_file, - chunk=chunk, chunk_size=chunk_size, - n_jobs=n_jobs, - is_url=is_url, - is_txt=is_txt, - - # urls - use_unstructured=use_unstructured, - use_playwright=use_playwright, - use_selenium=use_selenium, - - # pdfs - use_pymupdf=use_pymupdf, - use_unstructured_pdf=use_unstructured_pdf, - use_pypdf=use_pypdf, - enable_pdf_ocr=enable_pdf_ocr, - enable_pdf_doctr=enable_pdf_doctr, - try_pdf_as_html=try_pdf_as_html, - - # images - enable_ocr=enable_ocr, - enable_doctr=enable_doctr, - enable_pix2struct=enable_pix2struct, - enable_captions=enable_captions, - captions_model=captions_model, - model_loaders=model_loaders, - - # json - jq_schema=jq_schema, - - db_type=db_type, - selected_file_types=selected_file_types, - ) - if n_jobs != 1 and len(globs_non_image_types) > 1: - # avoid nesting, e.g. upload 1 zip and then inside many files - # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib - documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')( - delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_non_image_types) - ) - else: - documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in - enumerate(tqdm(globs_non_image_types))] - - # do images separately since can't fork after cuda in parent, so can't be parallel - if n_jobs_image != 1 and len(globs_image_types) > 1: - # avoid nesting, e.g. 
upload 1 zip and then inside many files - # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib - image_documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')( - delayed(path_to_doc1)(file, filei=filei, **kwargs) for filei, file in enumerate(globs_image_types) - ) - else: - image_documents = [path_to_doc1(file, filei=filei, **kwargs) for filei, file in - enumerate(tqdm(globs_image_types))] - - # unload loaders (image loaders, includes enable_pdf_doctr that uses same loader) - for name, loader in model_loaders.items(): - loader0 = model_loaders0[name] - real_model_initial = loader0 is not None and not isinstance(loader0, (str, bool)) - real_model_final = model_loaders[name] is not None and not isinstance(model_loaders[name], (str, bool)) - if not real_model_initial and real_model_final: - # clear off GPU newly added model - model_loaders[name].unload_model() - - # add image docs in - documents += image_documents - - if return_file: - # then documents really are files - files = documents.copy() - documents = [] - for fil in files: - with open(fil, 'rb') as f: - documents.extend(pickle.load(f)) - # remove temp pickle - remove(fil) - else: - documents = reduce(concat, documents) - - if verbose: - print("END consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True) - return documents - - -def prep_langchain(persist_directory, - load_db_if_exists, - db_type, use_openai_embedding, - langchain_mode, langchain_mode_paths, langchain_mode_types, - hf_embedding_model, - migrate_embedding_model, - auto_migrate_db, - n_jobs=-1, kwargs_make_db={}, - verbose=False): - """ - do prep first time, involving downloads - # FIXME: Add github caching then add here - :return: - """ - if os.getenv("HARD_ASSERTS"): - assert langchain_mode not in ['MyData'], "Should not prep scratch/personal data" - - if langchain_mode in langchain_modes_intrinsic: - return None - - db_dir_exists = os.path.isdir(persist_directory) - user_path = langchain_mode_paths.get(langchain_mode) - - if db_dir_exists and user_path is None: - if verbose: - print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True) - db, use_openai_embedding, hf_embedding_model = \ - get_existing_db(None, persist_directory, load_db_if_exists, - db_type, use_openai_embedding, - langchain_mode, langchain_mode_paths, langchain_mode_types, - hf_embedding_model, migrate_embedding_model, auto_migrate_db, - n_jobs=n_jobs) - else: - if db_dir_exists and user_path is not None: - if verbose: - print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % ( - persist_directory, user_path), flush=True) - elif not db_dir_exists: - if verbose: - print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True) - db = None - if langchain_mode in ['DriverlessAI docs']: - # FIXME: Could also just use dai_docs.pickle directly and upload that - get_dai_docs(from_hf=True) - - if langchain_mode in ['wiki']: - get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit']) - - langchain_kwargs = kwargs_make_db.copy() - langchain_kwargs.update(locals()) - db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs) - - return db - - -import posthog - -posthog.disabled = True - - -class FakeConsumer(object): - def __init__(self, *args, **kwargs): - pass - - def run(self): - pass - - def pause(self): - pass - - def upload(self): - 
pass - - def next(self): - pass - - def request(self, batch): - pass - - -posthog.Consumer = FakeConsumer - - -def check_update_chroma_embedding(db, - db_type, - use_openai_embedding, - hf_embedding_model, migrate_embedding_model, auto_migrate_db, - langchain_mode, langchain_mode_paths, langchain_mode_types, - n_jobs=-1): - changed_db = False - embed_tuple = load_embed(db=db) - if embed_tuple not in [(True, use_openai_embedding, hf_embedding_model), - (False, use_openai_embedding, hf_embedding_model)]: - print("Detected new embedding %s vs. %s %s, updating db: %s" % ( - use_openai_embedding, hf_embedding_model, embed_tuple, langchain_mode), flush=True) - # handle embedding changes - db_get = get_documents(db) - sources = [Document(page_content=result[0], metadata=result[1] or {}) - for result in zip(db_get['documents'], db_get['metadatas'])] - # delete index, has to be redone - persist_directory = db._persist_directory - shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak") - assert db_type in ['chroma', 'chroma_old'] - load_db_if_exists = False - db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type, - persist_directory=persist_directory, load_db_if_exists=load_db_if_exists, - langchain_mode=langchain_mode, - langchain_mode_paths=langchain_mode_paths, - langchain_mode_types=langchain_mode_types, - collection_name=None, - hf_embedding_model=hf_embedding_model, - migrate_embedding_model=migrate_embedding_model, - auto_migrate_db=auto_migrate_db, - n_jobs=n_jobs, - ) - changed_db = True - print("Done updating db for new embedding: %s" % langchain_mode, flush=True) - - return db, changed_db - - -def migrate_meta_func(db, langchain_mode): - changed_db = False - db_get = get_documents(db) - # just check one doc - if len(db_get['metadatas']) > 0 and 'chunk_id' not in db_get['metadatas'][0]: - print("Detected old metadata, adding additional information", flush=True) - t0 = time.time() - # handle meta changes - [x.update(dict(chunk_id=x.get('chunk_id', 0))) for x in db_get['metadatas']] - client_collection = db._client.get_collection(name=db._collection.name, - embedding_function=db._collection._embedding_function) - client_collection.update(ids=db_get['ids'], metadatas=db_get['metadatas']) - # check - db_get = get_documents(db) - assert 'chunk_id' in db_get['metadatas'][0], "Failed to add meta" - changed_db = True - print("Done updating db for new meta: %s in %s seconds" % (langchain_mode, time.time() - t0), flush=True) - - return db, changed_db - - -def get_existing_db(db, persist_directory, - load_db_if_exists, db_type, use_openai_embedding, - langchain_mode, langchain_mode_paths, langchain_mode_types, - hf_embedding_model, - migrate_embedding_model, - auto_migrate_db=False, - verbose=False, check_embedding=True, migrate_meta=True, - n_jobs=-1): - if load_db_if_exists and db_type in ['chroma', 'chroma_old'] and os.path.isdir(persist_directory): - if os.path.isfile(os.path.join(persist_directory, 'chroma.sqlite3')): - must_migrate = False - elif os.path.isdir(os.path.join(persist_directory, 'index')): - must_migrate = True - else: - return db, use_openai_embedding, hf_embedding_model - chroma_settings = dict(is_persistent=True) - use_chromamigdb = False - if must_migrate: - if auto_migrate_db: - print("Detected chromadb<0.4 database, require migration, doing now....", flush=True) - from chroma_migrate.import_duckdb import migrate_from_duckdb - import chromadb - api = chromadb.PersistentClient(path=persist_directory) - did_migration = 
migrate_from_duckdb(api, persist_directory) - assert did_migration, "Failed to migrate chroma collection at %s, see https://docs.trychroma.com/migration for CLI tool" % persist_directory - elif have_chromamigdb: - print( - "Detected chroma<0.4 database but --auto_migrate_db=False, but detected chromamigdb package, so using old database that still requires duckdb", - flush=True) - chroma_settings = dict(chroma_db_impl="duckdb+parquet") - use_chromamigdb = True - else: - raise ValueError( - "Detected chromadb<0.4 database, require migration, but did not detect chromamigdb package or did not choose auto_migrate_db=False (see FAQ.md)") - - if db is None: - if verbose: - print("DO Loading db: %s" % langchain_mode, flush=True) - got_embedding, use_openai_embedding0, hf_embedding_model0 = load_embed(persist_directory=persist_directory) - if got_embedding: - use_openai_embedding, hf_embedding_model = use_openai_embedding0, hf_embedding_model0 - embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model) - import logging - logging.getLogger("chromadb").setLevel(logging.ERROR) - if use_chromamigdb: - from chromamigdb.config import Settings - chroma_class = ChromaMig - else: - from chromadb.config import Settings - chroma_class = Chroma - client_settings = Settings(anonymized_telemetry=False, - **chroma_settings, - persist_directory=persist_directory) - db = chroma_class(persist_directory=persist_directory, embedding_function=embedding, - collection_name=langchain_mode.replace(' ', '_'), - client_settings=client_settings) - try: - db.similarity_search('') - except BaseException as e: - # migration when no embed_info - if 'Dimensionality of (768) does not match index dimensionality (384)' in str(e) or \ - 'Embedding dimension 768 does not match collection dimensionality 384' in str(e): - hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" - embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model) - db = chroma_class(persist_directory=persist_directory, embedding_function=embedding, - collection_name=langchain_mode.replace(' ', '_'), - client_settings=client_settings) - # should work now, let fail if not - db.similarity_search('') - save_embed(db, use_openai_embedding, hf_embedding_model) - else: - raise - - if verbose: - print("DONE Loading db: %s" % langchain_mode, flush=True) - else: - if not migrate_embedding_model: - # OVERRIDE embedding choices if could load embedding info when not migrating - got_embedding, use_openai_embedding, hf_embedding_model = load_embed(db=db) - if verbose: - print("USING already-loaded db: %s" % langchain_mode, flush=True) - if check_embedding: - db_trial, changed_db = check_update_chroma_embedding(db, - db_type, - use_openai_embedding, - hf_embedding_model, - migrate_embedding_model, - auto_migrate_db, - langchain_mode, - langchain_mode_paths, - langchain_mode_types, - n_jobs=n_jobs) - if changed_db: - db = db_trial - # only call persist if really changed db, else takes too long for large db - if db is not None: - db.persist() - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - if migrate_meta and db is not None: - db_trial, changed_db = migrate_meta_func(db, langchain_mode) - if changed_db: - db = db_trial - return db, use_openai_embedding, hf_embedding_model - return db, use_openai_embedding, hf_embedding_model - - -def clear_embedding(db): - if db is None: - return - # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed - try: - 
if hasattr(db._embedding_function, 'client') and hasattr(db._embedding_function.client, 'cpu'):
- # only push back to CPU if each db/user has own embedding model, else if shared share on GPU
- if hasattr(db._embedding_function.client, 'preload') and not db._embedding_function.client.preload:
- db._embedding_function.client.cpu()
- clear_torch_cache()
- except RuntimeError as e:
- print("clear_embedding error: %s" % ''.join(traceback.format_tb(e.__traceback__)), flush=True)
- 
- 
- def make_db(**langchain_kwargs):
- func_names = list(inspect.signature(_make_db).parameters)
- missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
- defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
- for k in missing_kwargs:
- if k in defaults_db:
- langchain_kwargs[k] = defaults_db[k]
- # final check for missing
- missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
- assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
- # only keep actual used
- langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
- return _make_db(**langchain_kwargs)
- 
- 
- embed_lock_name = 'embed.lock'
- 
- 
- def get_embed_lock_file(db, persist_directory=None):
- if hasattr(db, '_persist_directory') or persist_directory:
- if persist_directory is None:
- persist_directory = db._persist_directory
- check_persist_directory(persist_directory)
- base_path = os.path.join('locks', persist_directory)
- base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
- lock_file = os.path.join(base_path, embed_lock_name)
- makedirs(os.path.dirname(lock_file))
- return lock_file
- return None
- 
- 
- def save_embed(db, use_openai_embedding, hf_embedding_model):
- if hasattr(db, '_persist_directory'):
- persist_directory = db._persist_directory
- lock_file = get_embed_lock_file(db)
- with filelock.FileLock(lock_file):
- embed_info_file = os.path.join(persist_directory, 'embed_info')
- with open(embed_info_file, 'wb') as f:
- if isinstance(hf_embedding_model, str):
- hf_embedding_model_save = hf_embedding_model
- elif hasattr(hf_embedding_model, 'model_name'):
- hf_embedding_model_save = hf_embedding_model.model_name
- elif isinstance(hf_embedding_model, dict) and 'name' in hf_embedding_model:
- hf_embedding_model_save = hf_embedding_model['name']
- else:
- if os.getenv('HARD_ASSERTS'):
- # unexpected in testing or normally
- raise RuntimeError("HERE")
- hf_embedding_model_save = 'hkunlp/instructor-large'
- pickle.dump((use_openai_embedding, hf_embedding_model_save), f)
- return use_openai_embedding, hf_embedding_model
- 
- 
- def load_embed(db=None, persist_directory=None):
- if hasattr(db, 'embeddings') and hasattr(db.embeddings, 'model_name'):
- hf_embedding_model = db.embeddings.model_name if 'openai' not in db.embeddings.model_name.lower() else None
- use_openai_embedding = hf_embedding_model is None
- save_embed(db, use_openai_embedding, hf_embedding_model)
- return True, use_openai_embedding, hf_embedding_model
- if persist_directory is None:
- persist_directory = db._persist_directory
- embed_info_file = os.path.join(persist_directory, 'embed_info')
- if os.path.isfile(embed_info_file):
- lock_file = get_embed_lock_file(db, persist_directory=persist_directory)
- with filelock.FileLock(lock_file):
- with open(embed_info_file, 'rb') as f:
- try:
- use_openai_embedding, hf_embedding_model = pickle.load(f)
- if not isinstance(hf_embedding_model, 
str): - # work-around bug introduced here: https://github.com/h2oai/h2ogpt/commit/54c4414f1ce3b5b7c938def651c0f6af081c66de - hf_embedding_model = 'hkunlp/instructor-large' - # fix file - save_embed(db, use_openai_embedding, hf_embedding_model) - got_embedding = True - except EOFError: - use_openai_embedding, hf_embedding_model = False, 'hkunlp/instructor-large' - got_embedding = False - if os.getenv('HARD_ASSERTS'): - # unexpected in testing or normally - raise - else: - # migration, assume defaults - use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2" - got_embedding = False - assert isinstance(hf_embedding_model, str) - return got_embedding, use_openai_embedding, hf_embedding_model - - -def get_persist_directory(langchain_mode, langchain_type=None, db1s=None, dbs=None): - if langchain_mode in [LangChainMode.DISABLED.value, LangChainMode.LLM.value]: - # not None so join works but will fail to find db - return '', langchain_type - - userid = get_userid_direct(db1s) - username = get_username_direct(db1s) - - # sanity for bad code - assert userid != 'None' - assert username != 'None' - - dirid = username or userid - if langchain_type == LangChainTypes.SHARED.value and not dirid: - dirid = './' # just to avoid error - if langchain_type == LangChainTypes.PERSONAL.value and not dirid: - # e.g. from client when doing transient calls with MyData - if db1s is None: - # just trick to get filled locally - db1s = {LangChainMode.MY_DATA.value: [None, None, None]} - set_userid_direct(db1s, str(uuid.uuid4()), str(uuid.uuid4())) - userid = get_userid_direct(db1s) - username = get_username_direct(db1s) - dirid = username or userid - langchain_type = LangChainTypes.PERSONAL.value - - # deal with existing locations - user_base_dir = os.getenv('USERS_BASE_DIR', 'users') - persist_directory = os.path.join(user_base_dir, dirid, 'db_dir_%s' % langchain_mode) - if userid and \ - (os.path.isdir(persist_directory) or - db1s is not None and langchain_mode in db1s or - langchain_type == LangChainTypes.PERSONAL.value): - langchain_type = LangChainTypes.PERSONAL.value - persist_directory = makedirs(persist_directory, use_base=True) - check_persist_directory(persist_directory) - return persist_directory, langchain_type - - persist_directory = 'db_dir_%s' % langchain_mode - if (os.path.isdir(persist_directory) or - dbs is not None and langchain_mode in dbs or - langchain_type == LangChainTypes.SHARED.value): - # ensure consistent - langchain_type = LangChainTypes.SHARED.value - persist_directory = makedirs(persist_directory, use_base=True) - check_persist_directory(persist_directory) - return persist_directory, langchain_type - - # dummy return for prep_langchain() or full personal space - base_others = 'db_nonusers' - persist_directory = os.path.join(base_others, 'db_dir_%s' % str(uuid.uuid4())) - persist_directory = makedirs(persist_directory, use_base=True) - langchain_type = LangChainTypes.PERSONAL.value - - check_persist_directory(persist_directory) - return persist_directory, langchain_type - - -def check_persist_directory(persist_directory): - # deal with some cases when see intrinsic names being used as shared - for langchain_mode in langchain_modes_intrinsic: - if persist_directory == 'db_dir_%s' % langchain_mode: - raise RuntimeError("Illegal access to %s" % persist_directory) - - -def _make_db(use_openai_embedding=False, - hf_embedding_model=None, - migrate_embedding_model=False, - auto_migrate_db=False, - first_para=False, text_limit=None, - chunk=True, chunk_size=512, 
- - # urls - use_unstructured=True, - use_playwright=False, - use_selenium=False, - - # pdfs - use_pymupdf='auto', - use_unstructured_pdf='auto', - use_pypdf='auto', - enable_pdf_ocr='auto', - enable_pdf_doctr='auto', - try_pdf_as_html='auto', - - # images - enable_ocr=False, - enable_doctr=False, - enable_pix2struct=False, - enable_captions=True, - captions_model=None, - caption_loader=None, - doctr_loader=None, - pix2struct_loader=None, - - # json - jq_schema='.[]', - - langchain_mode=None, - langchain_mode_paths=None, - langchain_mode_types=None, - db_type='faiss', - load_db_if_exists=True, - db=None, - n_jobs=-1, - verbose=False): - assert hf_embedding_model is not None - user_path = langchain_mode_paths.get(langchain_mode) - langchain_type = langchain_mode_types.get(langchain_mode, LangChainTypes.EITHER.value) - persist_directory, langchain_type = get_persist_directory(langchain_mode, langchain_type=langchain_type) - langchain_mode_types[langchain_mode] = langchain_type - # see if can get persistent chroma db - db_trial, use_openai_embedding, hf_embedding_model = \ - get_existing_db(db, persist_directory, load_db_if_exists, db_type, - use_openai_embedding, - langchain_mode, langchain_mode_paths, langchain_mode_types, - hf_embedding_model, migrate_embedding_model, auto_migrate_db, verbose=verbose, - n_jobs=n_jobs) - if db_trial is not None: - db = db_trial - - sources = [] - if not db: - chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type) - if langchain_mode in ['wiki_full']: - from read_wiki_full import get_all_documents - small_test = None - print("Generating new wiki", flush=True) - sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2) - print("Got new wiki", flush=True) - sources1 = chunk_sources(sources1, chunk=chunk) - print("Chunked new wiki", flush=True) - sources.extend(sources1) - elif langchain_mode in ['wiki']: - sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit) - sources1 = chunk_sources(sources1, chunk=chunk) - sources.extend(sources1) - elif langchain_mode in ['github h2oGPT']: - # sources = get_github_docs("dagster-io", "dagster") - sources1 = get_github_docs("h2oai", "h2ogpt") - # FIXME: always chunk for now - sources1 = chunk_sources(sources1) - sources.extend(sources1) - elif langchain_mode in ['DriverlessAI docs']: - sources1 = get_dai_docs(from_hf=True) - # FIXME: DAI docs are already chunked well, should only chunk more if over limit - sources1 = chunk_sources(sources1, chunk=False) - sources.extend(sources1) - if user_path: - # UserData or custom, which has to be from user's disk - if db is not None: - # NOTE: Ignore file names for now, only go by hash ids - # existing_files = get_existing_files(db) - existing_files = [] - existing_hash_ids = get_existing_hash_ids(db) - else: - # pretend no existing files so won't filter - existing_files = [] - existing_hash_ids = [] - # chunk internally for speed over multiple docs - # FIXME: If first had old Hash=None and switch embeddings, - # then re-embed, and then hit here and reload so have hash, and then re-embed. 
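- # only new or changed files get re-ingested: path_to_docs() compares hash_file() of each candidate
- # against existing_hash_ids from the db and keeps only entries whose hash is new or different.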
- sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size, - # urls - use_unstructured=use_unstructured, - use_playwright=use_playwright, - use_selenium=use_selenium, - - # pdfs - use_pymupdf=use_pymupdf, - use_unstructured_pdf=use_unstructured_pdf, - use_pypdf=use_pypdf, - enable_pdf_ocr=enable_pdf_ocr, - enable_pdf_doctr=enable_pdf_doctr, - try_pdf_as_html=try_pdf_as_html, - - # images - enable_ocr=enable_ocr, - enable_doctr=enable_doctr, - enable_pix2struct=enable_pix2struct, - enable_captions=enable_captions, - captions_model=captions_model, - caption_loader=caption_loader, - doctr_loader=doctr_loader, - pix2struct_loader=pix2struct_loader, - - # json - jq_schema=jq_schema, - - existing_files=existing_files, existing_hash_ids=existing_hash_ids, - db_type=db_type) - new_metadata_sources = set([x.metadata['source'] for x in sources1]) - if new_metadata_sources: - if os.getenv('NO_NEW_FILES') is not None: - raise RuntimeError("Expected no new files! %s" % new_metadata_sources) - print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode), - flush=True) - if verbose: - print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True) - sources.extend(sources1) - if len(sources) > 0 and os.getenv('NO_NEW_FILES') is not None: - raise RuntimeError("Expected no new files! %s" % langchain_mode) - if len(sources) == 0 and os.getenv('SHOULD_NEW_FILES') is not None: - raise RuntimeError("Expected new files! %s" % langchain_mode) - print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True) - - # see if got sources - if not sources: - if verbose: - if db is not None: - print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True) - else: - print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True) - return db, 0, [] - if verbose: - if db is not None: - print("Generating db", flush=True) - else: - print("Adding to db", flush=True) - if not db: - if sources: - db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type, - persist_directory=persist_directory, - langchain_mode=langchain_mode, - langchain_mode_paths=langchain_mode_paths, - langchain_mode_types=langchain_mode_types, - hf_embedding_model=hf_embedding_model, - migrate_embedding_model=migrate_embedding_model, - auto_migrate_db=auto_migrate_db, - n_jobs=n_jobs) - if verbose: - print("Generated db", flush=True) - elif langchain_mode not in langchain_modes_intrinsic: - print("Did not generate db for %s since no sources" % langchain_mode, flush=True) - new_sources_metadata = [x.metadata for x in sources] - elif user_path is not None: - print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True) - db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model) - print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True) - else: - new_sources_metadata = [x.metadata for x in sources] - - return db, len(new_sources_metadata), new_sources_metadata - - -def get_metadatas(db): - metadatas = [] - from langchain.vectorstores import FAISS - if isinstance(db, FAISS): - metadatas = [v.metadata for k, v in db.docstore._dict.items()] - elif isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db): - metadatas = 
get_documents(db)['metadatas'] - elif db is not None: - # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 - # seems no way to get all metadata, so need to avoid this approach for weaviate - metadatas = [x.metadata for x in db.similarity_search("", k=10000)] - return metadatas - - -def get_db_lock_file(db, lock_type='getdb'): - if hasattr(db, '_persist_directory'): - persist_directory = db._persist_directory - check_persist_directory(persist_directory) - base_path = os.path.join('locks', persist_directory) - base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True) - lock_file = os.path.join(base_path, "%s.lock" % lock_type) - makedirs(os.path.dirname(lock_file)) # ensure made - return lock_file - return None - - -def get_documents(db): - if hasattr(db, '_persist_directory'): - lock_file = get_db_lock_file(db) - with filelock.FileLock(lock_file): - # get segfaults and other errors when multiple threads access this - return _get_documents(db) - else: - return _get_documents(db) - - -def _get_documents(db): - from langchain.vectorstores import FAISS - if isinstance(db, FAISS): - documents = [v for k, v in db.docstore._dict.items()] - documents = dict(documents=documents) - elif isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db): - documents = db.get() - else: - # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 - # seems no way to get all metadata, so need to avoid this approach for weaviate - documents = [x for x in db.similarity_search("", k=10000)] - documents = dict(documents=documents) - return documents - - -def get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None): - if hasattr(db, '_persist_directory'): - lock_file = get_db_lock_file(db) - with filelock.FileLock(lock_file): - return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list) - else: - return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, text_context_list=text_context_list) - - -def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}, text_context_list=None): - db_documents = [] - db_metadatas = [] - - if text_context_list: - db_documents += [x.page_content if hasattr(x, 'page_content') else x for x in text_context_list] - db_metadatas += [x.metadata if hasattr(x, 'metadata') else {} for x in text_context_list] - - from langchain.vectorstores import FAISS - if isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db): - db_get = db._collection.get(where=filter_kwargs.get('filter')) - db_metadatas += db_get['metadatas'] - db_documents += db_get['documents'] - elif isinstance(db, FAISS): - import itertools - db_metadatas += get_metadatas(db) - # FIXME: FAISS has no filter - if top_k_docs == -1: - db_documents += list(db.docstore._dict.values()) - else: - # slice dict first - db_documents += list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values()) - elif db is not None: - db_metadatas += get_metadatas(db) - db_documents += get_documents(db)['documents'] - - return db_documents, db_metadatas - - -def get_existing_files(db): - metadatas = get_metadatas(db) - metadata_sources = set([x['source'] for x in metadatas]) - return metadata_sources - - -def get_existing_hash_ids(db): - metadatas = get_metadatas(db) - # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks - metadata_hash_ids = {os.path.normpath(x['source']): x.get('hashid') for x in metadatas} - 
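# maps normalized source path -> file hash stored in chunk metadata; used as existing_hash_ids to skip unchanged files
- 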
return metadata_hash_ids - - -def run_qa_db(**kwargs): - func_names = list(inspect.signature(_run_qa_db).parameters) - # hard-coded defaults - kwargs['answer_with_sources'] = kwargs.get('answer_with_sources', True) - kwargs['show_rank'] = kwargs.get('show_rank', False) - kwargs['show_accordions'] = kwargs.get('show_accordions', True) - kwargs['show_link_in_sources'] = kwargs.get('show_link_in_sources', True) - kwargs['top_k_docs_max_show'] = kwargs.get('top_k_docs_max_show', 10) - kwargs['llamacpp_dict'] = {} # shouldn't be required unless from test using _run_qa_db - missing_kwargs = [x for x in func_names if x not in kwargs] - assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs - # only keep actual used - kwargs = {k: v for k, v in kwargs.items() if k in func_names} - try: - return _run_qa_db(**kwargs) - finally: - clear_torch_cache() - - -def _run_qa_db(query=None, - iinput=None, - context=None, - use_openai_model=False, use_openai_embedding=False, - first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512, - - # urls - use_unstructured=True, - use_playwright=False, - use_selenium=False, - - # pdfs - use_pymupdf='auto', - use_unstructured_pdf='auto', - use_pypdf='auto', - enable_pdf_ocr='auto', - enable_pdf_doctr='auto', - try_pdf_as_html='auto', - - # images - enable_ocr=False, - enable_doctr=False, - enable_pix2struct=False, - enable_captions=True, - captions_model=None, - caption_loader=None, - doctr_loader=None, - pix2struct_loader=None, - - # json - jq_schema='.[]', - - langchain_mode_paths={}, - langchain_mode_types={}, - detect_user_path_changes_every_query=False, - db_type=None, - model_name=None, model=None, tokenizer=None, inference_server=None, - langchain_only_model=False, - hf_embedding_model=None, - migrate_embedding_model=False, - auto_migrate_db=False, - stream_output=False, - async_output=True, - num_async=3, - prompter=None, - prompt_type=None, - prompt_dict=None, - answer_with_sources=True, - append_sources_to_answer=True, - cut_distance=1.64, - add_chat_history_to_context=True, - add_search_to_context=False, - keep_sources_in_context=False, - memory_restriction_level=0, - system_prompt='', - sanitize_bot_response=False, - show_rank=False, - show_accordions=True, - show_link_in_sources=True, - top_k_docs_max_show=10, - use_llm_if_no_docs=True, - load_db_if_exists=False, - db=None, - do_sample=False, - temperature=0.1, - top_k=40, - top_p=0.7, - num_beams=1, - max_new_tokens=512, - min_new_tokens=1, - early_stopping=False, - max_time=180, - repetition_penalty=1.0, - num_return_sequences=1, - langchain_mode=None, - langchain_action=None, - langchain_agents=None, - document_subset=DocumentSubset.Relevant.name, - document_choice=[DocumentChoice.ALL.value], - pre_prompt_query=None, - prompt_query=None, - pre_prompt_summary=None, - prompt_summary=None, - text_context_list=None, - chat_conversation=None, - visible_models=None, - h2ogpt_key=None, - docs_ordering_type='reverse_ucurve_sort', - min_max_new_tokens=256, - - n_jobs=-1, - llamacpp_dict=None, - verbose=False, - cli=False, - lora_weights='', - auto_reduce_chunks=True, - max_chunks=100, - total_tokens_for_docs=None, - headsize=50, - ): - """ - - :param query: - :param use_openai_model: - :param use_openai_embedding: - :param first_para: - :param text_limit: - :param top_k_docs: - :param chunk: - :param chunk_size: - :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from - :param db_type: 'faiss' for in-memory - 'chroma' (for chroma >= 0.4) 
- 'chroma_old' (for chroma < 0.4) - 'weaviate' for persisted on disk - :param model_name: model name, used to switch behaviors - :param model: pre-initialized model, else will make new one - :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None - :param answer_with_sources - :return: - """ - t_run = time.time() - if stream_output: - # threads and asyncio don't mix - async_output = False - if langchain_action in [LangChainAction.QUERY.value]: - # only summarization supported - async_output = False - - # in case None, e.g. lazy client, then set based upon actual model - pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary = \ - get_langchain_prompts(pre_prompt_query, prompt_query, - pre_prompt_summary, prompt_summary, - model_name, inference_server, - llamacpp_dict.get('model_path_llama')) - - assert db_type is not None - assert hf_embedding_model is not None - assert langchain_mode_paths is not None - assert langchain_mode_types is not None - if model is not None: - assert model_name is not None # require so can make decisions - assert query is not None - assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate - if prompter is not None: - prompt_type = prompter.prompt_type - prompt_dict = prompter.prompt_dict - if model is not None: - assert prompt_type is not None - if prompt_type == PromptType.custom.name: - assert prompt_dict is not None # should at least be {} or '' - else: - prompt_dict = '' - - if LangChainAgent.SEARCH.value in langchain_agents and 'llama' in model_name.lower(): - system_prompt = """You are a zero shot react agent. -Consider to prompt of Question that was original query from the user. -Respond to prompt of Thought with a thought that may lead to a reasonable new action choice. -Respond to prompt of Action with an action to take out of the tools given, giving exactly single word for the tool name. -Respond to prompt of Action Input with an input to give the tool. -Consider to prompt of Observation that was response from the tool. -Repeat this Thought, Action, Action Input, Observation, Thought sequence several times with new and different thoughts and actions each time, do not repeat. -Once satisfied that the thoughts, responses are sufficient to answer the question, then respond to prompt of Thought with: I now know the final answer -Respond to prompt of Final Answer with your final high-quality bullet list answer to the original query. 
-""" - prompter.system_prompt = system_prompt - - assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0 - # pass in context to LLM directly, since already has prompt_type structure - # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638 - llm, model_name, streamer, prompt_type_out, async_output, only_new_text = \ - get_llm(use_openai_model=use_openai_model, model_name=model_name, - model=model, - tokenizer=tokenizer, - inference_server=inference_server, - langchain_only_model=langchain_only_model, - stream_output=stream_output, - async_output=async_output, - num_async=num_async, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - prompter=prompter, - context=context, - iinput=iinput, - sanitize_bot_response=sanitize_bot_response, - system_prompt=system_prompt, - visible_models=visible_models, - h2ogpt_key=h2ogpt_key, - min_max_new_tokens=min_max_new_tokens, - n_jobs=n_jobs, - llamacpp_dict=llamacpp_dict, - cli=cli, - verbose=verbose, - ) - # in case change, override original prompter - if hasattr(llm, 'prompter'): - prompter = llm.prompter - if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'prompter'): - prompter = llm.pipeline.prompter - - if prompter is None: - if prompt_type is None: - prompt_type = prompt_type_out - # get prompter - chat = True # FIXME? - prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output, - system_prompt=system_prompt) - - use_docs_planned = False - scores = [] - chain = None - - # basic version of prompt without docs etc. 
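- # For illustration (hedged sketch, values are hypothetical): with a typical instruct-style
- # prompt_type, the "basic" prompt built next is just the bare query wrapped by the
- # prompter's template, with no retrieved documents inserted yet, e.g.:
- #   data_point = dict(context='', instruction='What is h2oGPT?', input='')
- #   prompter.generate_prompt(data_point)  # -> roughly '<human>: What is h2oGPT?\n<bot>:'
- # The exact wrapper text depends on prompt_type/prompt_dict.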
- data_point = dict(context=context, instruction=query, input=iinput) - prompt_basic = prompter.generate_prompt(data_point) - - if isinstance(document_choice, str): - # support string as well - document_choice = [document_choice] - - func_names = list(inspect.signature(get_chain).parameters) - sim_kwargs = {k: v for k, v in locals().items() if k in func_names} - missing_kwargs = [x for x in func_names if x not in sim_kwargs] - assert not missing_kwargs, "Missing: %s" % missing_kwargs - docs, chain, scores, \ - use_docs_planned, num_docs_before_cut, \ - use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \ - get_chain(**sim_kwargs) - if document_subset in non_query_commands: - formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs]) - if not formatted_doc_chunks and not use_llm_if_no_docs: - yield dict(prompt=prompt_basic, response="No sources", sources='', num_prompt_tokens=0) - return - # if no souces, outside gpt_langchain, LLM will be used with '' input - scores = [1] * len(docs) - get_answer_args = tuple([query, docs, formatted_doc_chunks, scores, show_rank, - answer_with_sources, - append_sources_to_answer]) - get_answer_kwargs = dict(show_accordions=show_accordions, - show_link_in_sources=show_link_in_sources, - top_k_docs_max_show=top_k_docs_max_show, - docs_ordering_type=docs_ordering_type, - num_docs_before_cut=num_docs_before_cut, - verbose=verbose) - ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs) - yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0) - return - if not use_llm_if_no_docs: - if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value, - LangChainAction.SUMMARIZE_ALL.value, - LangChainAction.SUMMARIZE_REFINE.value]: - ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.' - extra = '' - yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0) - return - if not docs and not llm_mode: - ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).' - extra = '' - yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0) - return - - if chain is None and not langchain_only_model: - # here if no docs at all and not HF type - # can only return if HF type - return - - # context stuff similar to used in evaluate() - import torch - device, torch_dtype, context_class = get_device_dtype() - conditional_type = hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'model') and hasattr(llm.pipeline.model, - 'conditional_type') and llm.pipeline.model.conditional_type - with torch.no_grad(): - have_lora_weights = lora_weights not in [no_lora_str, '', None] - context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast - if conditional_type: - # issues when casting to float16, can mess up t5 model, e.g. 
only when not streaming, or other odd behaviors - context_class_cast = NullContext - with context_class_cast(device): - if stream_output and streamer: - answer = None - import queue - bucket = queue.Queue() - thread = EThread(target=chain, streamer=streamer, bucket=bucket) - thread.start() - outputs = "" - try: - for new_text in streamer: - # print("new_text: %s" % new_text, flush=True) - if bucket.qsize() > 0 or thread.exc: - thread.join() - outputs += new_text - if prompter: # and False: # FIXME: pipeline can already use prompter - if conditional_type: - if prompter.botstr: - prompt = prompter.botstr - output_with_prompt = prompt + outputs - only_new_text = False - else: - prompt = None - output_with_prompt = outputs - only_new_text = True - else: - prompt = None # FIXME - output_with_prompt = outputs - # don't specify only_new_text here, use get_llm() value - output1 = prompter.get_response(output_with_prompt, prompt=prompt, - only_new_text=only_new_text, - sanitize_bot_response=sanitize_bot_response) - yield dict(prompt=prompt, response=output1, sources='', num_prompt_tokens=0) - else: - yield dict(prompt=prompt, response=outputs, sources='', num_prompt_tokens=0) - except BaseException: - # if any exception, raise that exception if was from thread, first - if thread.exc: - raise thread.exc - raise - finally: - # in case no exception and didn't join with thread yet, then join - if not thread.exc: - answer = thread.join() - if isinstance(answer, dict): - if 'output_text' in answer: - answer = answer['output_text'] - elif 'output' in answer: - answer = answer['output'] - # in case raise StopIteration or broke queue loop in streamer, but still have exception - if thread.exc: - raise thread.exc - else: - if async_output: - import asyncio - answer = asyncio.run(chain()) - else: - answer = chain() - if isinstance(answer, dict): - if 'output_text' in answer: - answer = answer['output_text'] - elif 'output' in answer: - answer = answer['output'] - - get_answer_args = tuple([query, docs, answer, scores, show_rank, - answer_with_sources, - append_sources_to_answer]) - get_answer_kwargs = dict(show_accordions=show_accordions, - show_link_in_sources=show_link_in_sources, - top_k_docs_max_show=top_k_docs_max_show, - docs_ordering_type=docs_ordering_type, - num_docs_before_cut=num_docs_before_cut, - verbose=verbose, - t_run=t_run, - count_input_tokens=llm.count_input_tokens - if hasattr(llm, 'count_input_tokens') else None, - count_output_tokens=llm.count_output_tokens - if hasattr(llm, 'count_output_tokens') else None) - - t_run = time.time() - t_run - - # for final yield, get real prompt used - if hasattr(llm, 'prompter') and llm.prompter.prompt is not None: - prompt = llm.prompter.prompt - else: - prompt = prompt_basic - num_prompt_tokens = get_token_count(prompt, tokenizer) - - if not use_docs_planned: - ret = answer - extra = '' - yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens) - elif answer is not None: - ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs) - yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens) - return - - -def get_docs_with_score(query, k_db, filter_kwargs, db, db_type, text_context_list=None, verbose=False): - docs_with_score = [] - got_db_docs = False - - if text_context_list: - docs_with_score += [(x, x.metadata.get('score', 1.0)) for x in text_context_list] - - # deal with bug in chroma where if (say) 234 doc chunks and ask for 233+ then fails due to reduction misbehavior 
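- # Hedged sketch of the workaround used in the chroma branch below: similarity_search_with_score
- # is retried with a progressively smaller k_db until it succeeds (step 200 while k_db > 500,
- # 50 while > 100, 5 while > 10, else 1, never dropping below 1), e.g. starting from k_db=1000:
- #   while k_db > 1:
- #       k_db = max(1, k_db - (200 if k_db > 500 else 50 if k_db > 100 else 5 if k_db > 10 else 1))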
- if hasattr(db, '_embedding_function') and isinstance(db._embedding_function, FakeEmbeddings): - top_k_docs = -1 - # don't add text_context_list twice - db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, - text_context_list=None) - # sort by order given to parser (file_id) and any chunk_id if chunked - doc_file_ids = [x.get('file_id', 0) for x in db_metadatas] - doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas] - docs_with_score_fake = [(Document(page_content=result[0], metadata=result[1] or {}), 1.0) - for result in zip(db_documents, db_metadatas)] - docs_with_score_fake = [x for fx, cx, x in - sorted(zip(doc_file_ids, doc_chunk_ids, docs_with_score_fake), - key=lambda x: (x[0], x[1])) - ] - got_db_docs |= len(docs_with_score_fake) > 0 - docs_with_score += docs_with_score_fake - elif db is not None and db_type in ['chroma', 'chroma_old']: - while True: - try: - docs_with_score_chroma = db.similarity_search_with_score(query, k=k_db, **filter_kwargs) - break - except (RuntimeError, AttributeError) as e: - # AttributeError is for people with wrong version of langchain - if verbose: - print("chroma bug: %s" % str(e), flush=True) - if k_db == 1: - raise - if k_db > 500: - k_db -= 200 - elif k_db > 100: - k_db -= 50 - elif k_db > 10: - k_db -= 5 - else: - k_db -= 1 - k_db = max(1, k_db) - got_db_docs |= len(docs_with_score_chroma) > 0 - docs_with_score += docs_with_score_chroma - elif db is not None: - docs_with_score_other = db.similarity_search_with_score(query, k=k_db, **filter_kwargs) - got_db_docs |= len(docs_with_score_other) > 0 - docs_with_score += docs_with_score_other - - # set in metadata original order of docs - [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)] - - return docs_with_score, got_db_docs - - -def get_chain(query=None, - iinput=None, - context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638 - use_openai_model=False, use_openai_embedding=False, - first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512, - - # urls - use_unstructured=True, - use_playwright=False, - use_selenium=False, - - # pdfs - use_pymupdf='auto', - use_unstructured_pdf='auto', - use_pypdf='auto', - enable_pdf_ocr='auto', - enable_pdf_doctr='auto', - try_pdf_as_html='auto', - - # images - enable_ocr=False, - enable_doctr=False, - enable_pix2struct=False, - enable_captions=True, - captions_model=None, - caption_loader=None, - doctr_loader=None, - pix2struct_loader=None, - - # json - jq_schema='.[]', - - langchain_mode_paths=None, - langchain_mode_types=None, - detect_user_path_changes_every_query=False, - db_type='faiss', - model_name=None, - inference_server='', - max_new_tokens=None, - langchain_only_model=False, - hf_embedding_model=None, - migrate_embedding_model=False, - auto_migrate_db=False, - prompter=None, - prompt_type=None, - prompt_dict=None, - system_prompt=None, - cut_distance=1.1, - add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638 - add_search_to_context=False, - keep_sources_in_context=False, - memory_restriction_level=0, - top_k_docs_max_show=10, - - load_db_if_exists=False, - db=None, - langchain_mode=None, - langchain_action=None, - langchain_agents=None, - document_subset=DocumentSubset.Relevant.name, - document_choice=[DocumentChoice.ALL.value], - pre_prompt_query=None, - prompt_query=None, - pre_prompt_summary=None, - prompt_summary=None, - text_context_list=None, - chat_conversation=None, - - n_jobs=-1, - # beyond 
run_db_query: - llm=None, - tokenizer=None, - verbose=False, - docs_ordering_type='reverse_ucurve_sort', - min_max_new_tokens=256, - stream_output=True, - async_output=True, - - # local - auto_reduce_chunks=True, - max_chunks=100, - total_tokens_for_docs=None, - use_llm_if_no_docs=None, - headsize=50, - ): - if inference_server is None: - inference_server = '' - assert hf_embedding_model is not None - assert langchain_agents is not None # should be at least [] - if text_context_list is None: - text_context_list = [] - - # default value: - llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0 - query_action = langchain_action == LangChainAction.QUERY.value - summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value, - LangChainAction.SUMMARIZE_ALL.value, - LangChainAction.SUMMARIZE_REFINE.value] - - if len(text_context_list) > 0: - # turn into documents to make easy to manage and add meta - # try to account for summarization vs. query - chunk_id = 0 if query_action else -1 - text_context_list = [ - Document(page_content=x, metadata=dict(source='text_context_list', score=1.0, chunk_id=chunk_id)) for x - in text_context_list] - - if add_search_to_context: - params = { - "engine": "duckduckgo", - "gl": "us", - "hl": "en", - } - search = H2OSerpAPIWrapper(params=params) - # if doing search, allow more docs - docs_search, top_k_docs = search.get_search_documents(query, - query_action=query_action, - chunk=chunk, chunk_size=chunk_size, - db_type=db_type, - headsize=headsize, - top_k_docs=top_k_docs) - text_context_list = docs_search + text_context_list - add_search_to_context &= len(docs_search) > 0 - top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search)) - - if len(text_context_list) > 0: - llm_mode = False - use_llm_if_no_docs = True - - from src.output_parser import H2OMRKLOutputParser - from langchain.agents import AgentType, load_tools, initialize_agent, create_vectorstore_agent, \ - create_pandas_dataframe_agent, create_json_agent, create_csv_agent - from langchain.agents.agent_toolkits import VectorStoreInfo, VectorStoreToolkit, create_python_agent, JsonToolkit - if LangChainAgent.SEARCH.value in langchain_agents: - output_parser = H2OMRKLOutputParser() - tools = load_tools(["serpapi"], llm=llm, serpapi_api_key=os.environ.get('SERPAPI_API_KEY')) - if inference_server.startswith('openai'): - agent_type = AgentType.OPENAI_FUNCTIONS - agent_executor_kwargs = {"handle_parsing_errors": True, 'output_parser': output_parser} - else: - agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION - agent_executor_kwargs = {'output_parser': output_parser} - chain = initialize_agent(tools, llm, agent=agent_type, - agent_executor_kwargs=agent_executor_kwargs, - agent_kwargs=dict(output_parser=output_parser, - format_instructions=output_parser.get_format_instructions()), - output_parser=output_parser, - max_iterations=10, - verbose=True) - chain_kwargs = dict(input=query) - target = wrapped_partial(chain, chain_kwargs) - - docs = [] - scores = [] - use_docs_planned = False - num_docs_before_cut = 0 - use_llm_if_no_docs = True - return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - if LangChainAgent.COLLECTION.value in langchain_agents: - output_parser = H2OMRKLOutputParser() - vectorstore_info = VectorStoreInfo( - name=langchain_mode, - description="DataBase of text from PDFs, Image Captions, or web URL content", - vectorstore=db, - ) - toolkit = 
VectorStoreToolkit(vectorstore_info=vectorstore_info) - chain = create_vectorstore_agent(llm=llm, toolkit=toolkit, - agent_executor_kwargs=dict(output_parser=output_parser), - verbose=True) - - chain_kwargs = dict(input=query) - target = wrapped_partial(chain, chain_kwargs) - - docs = [] - scores = [] - use_docs_planned = False - num_docs_before_cut = 0 - use_llm_if_no_docs = True - return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'): - chain = create_python_agent( - llm=llm, - tool=PythonREPLTool(), - verbose=True, - agent_type=AgentType.OPENAI_FUNCTIONS, - agent_executor_kwargs={"handle_parsing_errors": True}, - ) - - chain_kwargs = dict(input=query) - target = wrapped_partial(chain, chain_kwargs) - - docs = [] - scores = [] - use_docs_planned = False - num_docs_before_cut = 0 - use_llm_if_no_docs = True - return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'): - # FIXME: DATA - df = pd.DataFrame(None) - chain = create_pandas_dataframe_agent( - llm, - df, - verbose=True, - agent_type=AgentType.OPENAI_FUNCTIONS, - ) - - chain_kwargs = dict(input=query) - target = wrapped_partial(chain, chain_kwargs) - - docs = [] - scores = [] - use_docs_planned = False - num_docs_before_cut = 0 - use_llm_if_no_docs = True - return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - if isinstance(document_choice, str): - document_choice = [document_choice] - if document_choice and document_choice[0] == DocumentChoice.ALL.value: - document_choice_agent = document_choice[1:] - else: - document_choice_agent = document_choice - document_choice_agent = [x for x in document_choice_agent if x.endswith('.json')] - if LangChainAgent.JSON.value in \ - langchain_agents and \ - inference_server.startswith('openai_chat') and \ - len(document_choice_agent) == 1 and \ - document_choice_agent[0].endswith('.json'): - # with open('src/openai.yaml') as f: - # data = yaml.load(f, Loader=yaml.FullLoader) - with open(document_choice[0], 'rt') as f: - data = json.loads(f.read()) - json_spec = JsonSpec(dict_=data, max_value_length=4000) - json_toolkit = JsonToolkit(spec=json_spec) - - chain = create_json_agent( - llm=llm, toolkit=json_toolkit, verbose=True - ) - - chain_kwargs = dict(input=query) - target = wrapped_partial(chain, chain_kwargs) - - docs = [] - scores = [] - use_docs_planned = False - num_docs_before_cut = 0 - use_llm_if_no_docs = True - return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - if isinstance(document_choice, str): - document_choice = [document_choice] - if document_choice and document_choice[0] == DocumentChoice.ALL.value: - document_choice_agent = document_choice[1:] - else: - document_choice_agent = document_choice - document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')] - if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[ - 0].endswith( - '.csv'): - data_file = document_choice[0] - if inference_server.startswith('openai_chat'): - chain = create_csv_agent( - llm, - data_file, - verbose=True, - agent_type=AgentType.OPENAI_FUNCTIONS, - ) - else: - chain = 
create_csv_agent( - llm, - data_file, - verbose=True, - agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - ) - chain_kwargs = dict(input=query) - target = wrapped_partial(chain, chain_kwargs) - - docs = [] - scores = [] - use_docs_planned = False - num_docs_before_cut = 0 - use_llm_if_no_docs = True - return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - # determine whether use of context out of docs is planned - if not use_openai_model and prompt_type not in ['plain'] or langchain_only_model: - if llm_mode: - use_docs_planned = False - else: - use_docs_planned = True - else: - use_docs_planned = True - - # https://github.com/hwchase17/langchain/issues/1946 - # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid - # Chroma collection MyData contains fewer than 4 elements. - # type logger error - if top_k_docs == -1: - k_db = 1000 if db_type in ['chroma', 'chroma_old'] else 100 - else: - # top_k_docs=100 works ok too - k_db = 1000 if db_type in ['chroma', 'chroma_old'] else top_k_docs - - # FIXME: For All just go over all dbs instead of a separate db for All - if not detect_user_path_changes_every_query and db is not None: - # avoid looking at user_path during similarity search db handling, - # if already have db and not updating from user_path every query - # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was - if langchain_mode_paths is None: - langchain_mode_paths = {} - langchain_mode_paths = langchain_mode_paths.copy() - langchain_mode_paths[langchain_mode] = None - # once use_openai_embedding, hf_embedding_model passed in, possibly changed, - # but that's ok as not used below or in calling functions - db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model, - migrate_embedding_model=migrate_embedding_model, - auto_migrate_db=auto_migrate_db, - first_para=first_para, text_limit=text_limit, - chunk=chunk, chunk_size=chunk_size, - - # urls - use_unstructured=use_unstructured, - use_playwright=use_playwright, - use_selenium=use_selenium, - - # pdfs - use_pymupdf=use_pymupdf, - use_unstructured_pdf=use_unstructured_pdf, - use_pypdf=use_pypdf, - enable_pdf_ocr=enable_pdf_ocr, - enable_pdf_doctr=enable_pdf_doctr, - try_pdf_as_html=try_pdf_as_html, - - # images - enable_ocr=enable_ocr, - enable_doctr=enable_doctr, - enable_pix2struct=enable_pix2struct, - enable_captions=enable_captions, - captions_model=captions_model, - caption_loader=caption_loader, - doctr_loader=doctr_loader, - pix2struct_loader=pix2struct_loader, - - # json - jq_schema=jq_schema, - - langchain_mode=langchain_mode, - langchain_mode_paths=langchain_mode_paths, - langchain_mode_types=langchain_mode_types, - db_type=db_type, - load_db_if_exists=load_db_if_exists, - db=db, - n_jobs=n_jobs, - verbose=verbose) - num_docs_before_cut = 0 - use_template = not use_openai_model and prompt_type not in ['plain'] or langchain_only_model - got_db_docs = False # not yet at least - template, template_if_no_docs, auto_reduce_chunks, query = \ - get_template(query, iinput, - pre_prompt_query, prompt_query, - pre_prompt_summary, prompt_summary, - langchain_action, - llm_mode, - use_docs_planned, - auto_reduce_chunks, - got_db_docs, - add_search_to_context) - - max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server, - model_name=model_name, max_new_tokens=max_new_tokens) - - if (db or 
text_context_list) and use_docs_planned: - if hasattr(db, '_persist_directory'): - lock_file = get_db_lock_file(db, lock_type='sim') - else: - base_path = 'locks' - base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True) - name_path = "sim.lock" - lock_file = os.path.join(base_path, name_path) - - if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)): - # only chroma supports filtering - filter_kwargs = {} - filter_kwargs_backup = {} - else: - import logging - logging.getLogger("chromadb").setLevel(logging.ERROR) - assert document_choice is not None, "Document choice was None" - if isinstance(db, Chroma): - filter_kwargs_backup = {} # shouldn't ever need backup - # chroma >= 0.4 - if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[ - 0] == DocumentChoice.ALL.value: - filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \ - {"filter": {"chunk_id": {"$eq": -1}}} - else: - if document_choice[0] == DocumentChoice.ALL.value: - document_choice = document_choice[1:] - if len(document_choice) == 0: - filter_kwargs = {} - elif len(document_choice) > 1: - or_filter = [ - {"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else { - "$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]} - for x in document_choice] - filter_kwargs = dict(filter={"$or": or_filter}) - else: - # still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter - one_filter = \ - [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else { - "source": {"$eq": x}, - "chunk_id": { - "$eq": -1}} - for x in document_choice][0] - - filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']), - dict(chunk_id=one_filter['chunk_id'])]}) - else: - # migration for chroma < 0.4 - if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[ - 0] == DocumentChoice.ALL.value: - filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \ - {"filter": {"chunk_id": {"$eq": -1}}} - filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}} - elif len(document_choice) >= 2: - if document_choice[0] == DocumentChoice.ALL.value: - document_choice = document_choice[1:] - or_filter = [ - {"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, - "chunk_id": { - "$eq": -1}} - for x in document_choice] - filter_kwargs = dict(filter={"$or": or_filter}) - or_filter_backup = [ - {"source": {"$eq": x}} if query_action else {"source": {"$eq": x}} - for x in document_choice] - filter_kwargs_backup = dict(filter={"$or": or_filter_backup}) - elif len(document_choice) == 1: - # degenerate UX bug in chroma - one_filter = \ - [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x}, - "chunk_id": { - "$eq": -1}} - for x in document_choice][0] - filter_kwargs = dict(filter=one_filter) - one_filter_backup = \ - [{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}} - for x in document_choice][0] - filter_kwargs_backup = dict(filter=one_filter_backup) - else: - # shouldn't reach - filter_kwargs = {} - filter_kwargs_backup = {} - - if llm_mode: - docs = [] - scores = [] - elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']: - db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs, - text_context_list=text_context_list) - if len(db_documents) == 0 and filter_kwargs_backup: - db_documents, 
db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup, - text_context_list=text_context_list) - - if top_k_docs == -1: - top_k_docs = len(db_documents) - # similar to langchain's chroma's _results_to_docs_and_scores - docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0) - for result in zip(db_documents, db_metadatas)] - # set in metadata original order of docs - [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)] - - # order documents - doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas] - if query_action: - doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas] - docs_with_score2 = [x for hx, cx, x in - sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1])) - if cx >= 0] - else: - assert summarize_action - doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas] - docs_with_score2 = [x for hx, cx, x in - sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1])) - if cx == -1 - ] - if len(docs_with_score2) == 0 and len(docs_with_score) > 0: - # old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive - # just do again and relax filter, let summarize operate on actual chunks if nothing else - docs_with_score2 = [x for hx, cx, x in - sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), - key=lambda x: (x[0], x[1])) - ] - docs_with_score = docs_with_score2 - - docs_with_score = docs_with_score[:top_k_docs] - docs = [x[0] for x in docs_with_score] - scores = [x[1] for x in docs_with_score] - num_docs_before_cut = len(docs) - else: - with filelock.FileLock(lock_file): - docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type, - text_context_list=text_context_list, - verbose=verbose) - if len(docs_with_score) == 0 and filter_kwargs_backup: - docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db, - db_type, - text_context_list=text_context_list, - verbose=verbose) - - tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server, - use_openai_model=use_openai_model, - db_type=db_type) - # NOTE: if map_reduce, then no need to auto reduce chunks - if query_action and (top_k_docs == -1 or auto_reduce_chunks): - top_k_docs_tokenize = 100 - docs_with_score = docs_with_score[:top_k_docs_tokenize] - - prompt_no_docs = template.format(context='', question=query) - - model_max_length = tokenizer.model_max_length - chat = True # FIXME? 
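- # Rough token-budget example (hypothetical numbers, actual limits come from the tokenizer and
- # model): with model_max_length=4096 and max_new_tokens=512, about 4096 - 512 = 3584 tokens
- # remain for the system prompt, chat history, template, and document chunks; get_limited_prompt()
- # below trims history/docs to fit that budget and reports how many docs (top_k_docs_trial) fit.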
- - # first docs_with_score are most important with highest score - full_prompt, \ - instruction, iinput, context, \ - num_prompt_tokens, max_new_tokens, \ - num_prompt_tokens0, num_prompt_tokens_actual, \ - chat_index, top_k_docs_trial, one_doc_size = \ - get_limited_prompt(prompt_no_docs, - iinput, - tokenizer, - prompter=prompter, - inference_server=inference_server, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - chat=chat, - max_new_tokens=max_new_tokens, - system_prompt=system_prompt, - context=context, - chat_conversation=chat_conversation, - text_context_list=[x[0].page_content for x in docs_with_score], - keep_sources_in_context=keep_sources_in_context, - model_max_length=model_max_length, - memory_restriction_level=memory_restriction_level, - langchain_mode=langchain_mode, - add_chat_history_to_context=add_chat_history_to_context, - min_max_new_tokens=min_max_new_tokens, - ) - # avoid craziness - if 0 < top_k_docs_trial < max_chunks: - # avoid craziness - if top_k_docs == -1: - top_k_docs = top_k_docs_trial - else: - top_k_docs = min(top_k_docs, top_k_docs_trial) - elif top_k_docs_trial >= max_chunks: - top_k_docs = max_chunks - if top_k_docs > 0: - docs_with_score = docs_with_score[:top_k_docs] - elif one_doc_size is not None: - docs_with_score = [docs_with_score[0][:one_doc_size]] - else: - docs_with_score = [] - else: - if total_tokens_for_docs is not None: - # used to limit tokens for summarization, e.g. public instance - top_k_docs, one_doc_size, num_doc_tokens = \ - get_docs_tokens(tokenizer, - text_context_list=[x[0].page_content for x in docs_with_score], - max_input_tokens=total_tokens_for_docs) - - docs_with_score = docs_with_score[:top_k_docs] - - # put most relevant chunks closest to question, - # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated - # BUT: for small models, e.g. 
6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest - if docs_ordering_type in ['best_first']: - pass - elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']: - docs_with_score.reverse() - elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']: - docs_with_score = reverse_ucurve_list(docs_with_score) - else: - raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type) - - # cut off so no high distance docs/sources considered - num_docs_before_cut = len(docs_with_score) - docs = [x[0] for x in docs_with_score if x[1] < cut_distance] - scores = [x[1] for x in docs_with_score if x[1] < cut_distance] - if len(scores) > 0 and verbose: - print("Distance: min: %s max: %s mean: %s median: %s" % - (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True) - else: - docs = [] - scores = [] - - if not docs and use_docs_planned and not langchain_only_model: - # if HF type and have no docs, can bail out - return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - if document_subset in non_query_commands: - # no LLM use - return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - # FIXME: WIP - common_words_file = "data/NGSL_1.2_stats.csv.zip" - if False and os.path.isfile(common_words_file) and langchain_action == LangChainAction.QUERY.value: - df = pd.read_csv("data/NGSL_1.2_stats.csv.zip") - import string - reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip() - reduced_query_words = reduced_query.split(' ') - set_common = set(df['Lemma'].values.tolist()) - num_common = len([x.lower() in set_common for x in reduced_query_words]) - frac_common = num_common / len(reduced_query) if reduced_query else 0 - # FIXME: report to user bad query that uses too many common words - if verbose: - print("frac_common: %s" % frac_common, flush=True) - - if len(docs) == 0: - # avoid context == in prompt then - use_docs_planned = False - template = template_if_no_docs - - got_db_docs = got_db_docs and len(text_context_list) < len(docs) - # update template in case situation changed or did get docs - # then no new documents from database or not used, redo template - # got template earlier as estimate of template token size, here is final used version - template, template_if_no_docs, auto_reduce_chunks, query = \ - get_template(query, iinput, - pre_prompt_query, prompt_query, - pre_prompt_summary, prompt_summary, - langchain_action, - llm_mode, - use_docs_planned, - auto_reduce_chunks, - got_db_docs, - add_search_to_context) - - if langchain_action == LangChainAction.QUERY.value: - if use_template: - # instruct-like, rather than few-shot prompt_type='plain' as default - # but then sources confuse the model with how inserted among rest of text, so avoid - prompt = PromptTemplate( - # input_variables=["summaries", "question"], - input_variables=["context", "question"], - template=template, - ) - chain = load_qa_chain(llm, prompt=prompt, verbose=verbose) - else: - # only if use_openai_model = True, unused normally except in testing - chain = load_qa_with_sources_chain(llm) - if not use_docs_planned: - chain_kwargs = dict(input_documents=[], question=query) - else: - chain_kwargs = dict(input_documents=docs, question=query) - target = wrapped_partial(chain, chain_kwargs) - elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value, - LangChainAction.SUMMARIZE_REFINE, - 
LangChainAction.SUMMARIZE_ALL.value]: - if async_output: - return_intermediate_steps = False - else: - return_intermediate_steps = True - from langchain.chains.summarize import load_summarize_chain - if langchain_action == LangChainAction.SUMMARIZE_MAP.value: - prompt = PromptTemplate(input_variables=["text"], template=template) - chain = load_summarize_chain(llm, chain_type="map_reduce", - map_prompt=prompt, combine_prompt=prompt, - return_intermediate_steps=return_intermediate_steps, - token_max=max_input_tokens, verbose=verbose) - if async_output: - chain_func = chain.arun - else: - chain_func = chain - target = wrapped_partial(chain_func, {"input_documents": docs}) # , return_only_outputs=True) - elif langchain_action == LangChainAction.SUMMARIZE_ALL.value: - assert use_template - prompt = PromptTemplate(input_variables=["text"], template=template) - chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, - return_intermediate_steps=return_intermediate_steps, verbose=verbose) - if async_output: - chain_func = chain.arun - else: - chain_func = chain - target = wrapped_partial(chain_func) - elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value: - chain = load_summarize_chain(llm, chain_type="refine", - return_intermediate_steps=return_intermediate_steps, verbose=verbose) - if async_output: - chain_func = chain.arun - else: - chain_func = chain - target = wrapped_partial(chain_func) - else: - raise RuntimeError("No such langchain_action=%s" % langchain_action) - else: - raise RuntimeError("No such langchain_action=%s" % langchain_action) - - return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show - - -def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None): - if hasattr(tokenizer, 'model_max_length'): - return tokenizer.model_max_length - elif inference_server in ['openai', 'openai_azure']: - return llm.modelname_to_contextsize(model_name) - elif inference_server in ['openai_chat', 'openai_azure_chat']: - return model_token_mapping[model_name] - elif isinstance(tokenizer, FakeTokenizer): - # GGML - return tokenizer.model_max_length - else: - return 2048 - - -def get_max_input_tokens(llm=None, tokenizer=None, inference_server=None, model_name=None, max_new_tokens=None): - model_max_length = get_max_model_length(llm=llm, tokenizer=tokenizer, inference_server=inference_server, - model_name=model_name) - - if any([inference_server.startswith(x) for x in - ['openai', 'openai_azure', 'openai_chat', 'openai_azure_chat', 'vllm']]): - # openai can't handle tokens + max_new_tokens > max_tokens even if never generate those tokens - # and vllm uses OpenAI API with same limits - max_input_tokens = model_max_length - max_new_tokens - elif isinstance(tokenizer, FakeTokenizer): - # don't trust that fake tokenizer (e.g. 
GGML) will make lots of tokens normally, allow more input - max_input_tokens = model_max_length - min(256, max_new_tokens) - else: - if 'falcon' in model_name or inference_server.startswith('http'): - # allow for more input for falcon, assume won't make as long outputs as default max_new_tokens - # Also allow if TGI or Gradio, because we tell it input may be same as output, even if model can't actually handle - max_input_tokens = model_max_length - min(256, max_new_tokens) - else: - # trust that maybe model will make so many tokens, so limit input - max_input_tokens = model_max_length - max_new_tokens - - return max_input_tokens - - -def get_tokenizer(db=None, llm=None, tokenizer=None, inference_server=None, use_openai_model=False, - db_type='chroma'): - if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'): - # more accurate - return llm.pipeline.tokenizer - elif hasattr(llm, 'tokenizer'): - # e.g. TGI client mode etc. - return llm.tokenizer - elif inference_server in ['openai', 'openai_chat', 'openai_azure', - 'openai_azure_chat']: - return tokenizer - elif isinstance(tokenizer, FakeTokenizer): - return tokenizer - elif use_openai_model: - return FakeTokenizer() - elif (hasattr(db, '_embedding_function') and - hasattr(db._embedding_function, 'client') and - hasattr(db._embedding_function.client, 'tokenize')): - # in case model is not our pipeline with HF tokenizer - return db._embedding_function.client.tokenize - else: - # backup method - if os.getenv('HARD_ASSERTS'): - assert db_type in ['faiss', 'weaviate'] - # use tiktoken for faiss since embedding called differently - return FakeTokenizer() - - -def get_template(query, iinput, - pre_prompt_query, prompt_query, - pre_prompt_summary, prompt_summary, - langchain_action, - llm_mode, - use_docs_planned, - auto_reduce_chunks, - got_db_docs, - add_search_to_context): - if got_db_docs and add_search_to_context: - # modify prompts, assumes patterns like in predefined prompts. If user customizes, then they'd need to account for that. - prompt_query = prompt_query.replace('information in the document sources', - 'information in the document and web search sources (and their source dates and website source)') - prompt_summary = prompt_summary.replace('information in the document sources', - 'information in the document and web search sources (and their source dates and website source)') - elif got_db_docs and not add_search_to_context: - pass - elif not got_db_docs and add_search_to_context: - # modify prompts, assumes patterns like in predefined prompts. If user customizes, then they'd need to account for that. 
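- # Example of the substitution below: any occurrence of the phrase
- # 'information in the document sources' inside prompt_query/prompt_summary is rewritten to
- # 'information in the web search sources (and their source dates and website source)',
- # so the final prompt attributes the retrieved context to web search results.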
- prompt_query = prompt_query.replace('information in the document sources', - 'information in the web search sources (and their source dates and website source)') - prompt_summary = prompt_summary.replace('information in the document sources', - 'information in the web search sources (and their source dates and website source)') - - if langchain_action == LangChainAction.QUERY.value: - if iinput: - query = "%s\n%s" % (query, iinput) - if llm_mode or not use_docs_planned: - template_if_no_docs = template = """{context}{question}""" - else: - template = """%s -\"\"\" -{context} -\"\"\" -%s{question}""" % (pre_prompt_query, prompt_query) - template_if_no_docs = """{context}{question}""" - elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]: - none = ['', '\n', None] - - # modify prompt_summary if user passes query or iinput - if query not in none and iinput not in none: - prompt_summary = "Focusing on %s, %s, %s" % (query, iinput, prompt_summary) - elif query not in none: - prompt_summary = "Focusing on %s, %s" % (query, prompt_summary) - # don't auto reduce - auto_reduce_chunks = False - if langchain_action == LangChainAction.SUMMARIZE_MAP.value: - fstring = '{text}' - else: - fstring = '{input_documents}' - template = """%s: -\"\"\" -%s -\"\"\"\n%s""" % (pre_prompt_summary, fstring, prompt_summary) - template_if_no_docs = "Exactly only say: There are no documents to summarize." - elif langchain_action in [LangChainAction.SUMMARIZE_REFINE]: - template = '' # unused - template_if_no_docs = '' # unused - else: - raise RuntimeError("No such langchain_action=%s" % langchain_action) - - return template, template_if_no_docs, auto_reduce_chunks, query - - -def get_sources_answer(query, docs, answer, scores, show_rank, - answer_with_sources, append_sources_to_answer, - show_accordions=True, - show_link_in_sources=True, - top_k_docs_max_show=10, - docs_ordering_type='reverse_ucurve_sort', - num_docs_before_cut=0, - verbose=False, - t_run=None, - count_input_tokens=None, count_output_tokens=None): - if verbose: - print("query: %s" % query, flush=True) - print("answer: %s" % answer, flush=True) - - if len(docs) == 0: - extra = '' - ret = answer + extra - return ret, extra - - if answer_with_sources == -1: - extra = [dict(score=score, content=get_doc(x), source=get_source(x), orig_index=x.metadata.get('orig_index', 0)) - for score, x in zip(scores, docs)][ - :top_k_docs_max_show] - if append_sources_to_answer: - extra_str = [str(x) for x in extra] - ret = answer + '\n\n' + '\n'.join(extra_str) - else: - ret = answer - return ret, extra - - # link - answer_sources = [(max(0.0, 1.5 - score) / 1.5, - get_url(doc, font_size=font_size), - get_accordion(doc, font_size=font_size, head_acc=head_acc)) for score, doc in - zip(scores, docs)] - if not show_accordions: - answer_sources_dict = defaultdict(list) - [answer_sources_dict[url].append(score) for score, url in answer_sources] - answers_dict = {} - for url, scores_url in answer_sources_dict.items(): - answers_dict[url] = np.max(scores_url) - answer_sources = [(score, url) for url, score in answers_dict.items()] - answer_sources.sort(key=lambda x: x[0], reverse=True) - if show_rank: - # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)] - # sorted_sources_urls = "Sources [Rank | Link]:
" + "
".join(answer_sources) - answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)] - answer_sources = answer_sources[:top_k_docs_max_show] - sorted_sources_urls = "Ranked Sources:
" + "
".join(answer_sources) - else: - if show_accordions: - if show_link_in_sources: - answer_sources = ['
<font size="%s"><li>%.2g | %s</li><li>%s</li></font>' % (font_size, score, url, accordion) - for score, url, accordion in answer_sources] - else: - answer_sources = ['<font size="%s"><li>%.2g</li><li>%s</li></font>' % (font_size, score, accordion) - for score, url, accordion in answer_sources] - else: - if show_link_in_sources: - answer_sources = ['<font size="%s"><li>%.2g | %s</li></font>' % (font_size, score, url) - for score, url in answer_sources] - else: - answer_sources = ['<font size="%s"><li>%.2g</li></font>' % (font_size, score) - for score, url in answer_sources] - answer_sources = answer_sources[:top_k_docs_max_show] - if show_accordions: - sorted_sources_urls = f"{source_prefix}