import streamlit as st
import os
import re
import sys
import time
import base64
import random
import logging

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

from dotenv import load_dotenv
load_dotenv()

for key in st.session_state.keys():
    #del st.session_state[key]
    print(f'session state entry: {key} {st.session_state[key]}')

__spaces__ = os.environ.get('__SPACES__')

if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log
    st.session_state.request_log = get_request_log()

# third party service access
# hf inference api
hf_api_key = os.environ['HF_TOKEN']
ch_api_key = os.environ['COHERE_TOKEN']
bs_api_key = os.environ['BASETEN_TOKEN']

#index_model = "Writer/camel-5b-hf"
index_model = "Arylwen/instruct-palmyra-20b-gptq-8"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"

MAX_LENGTH = 1024
MAX_NEW_TOKENS = 250

#import baseten
#@st.cache_resource
#def set_baseten_key(bs_api_key):
#    baseten.login(bs_api_key)
#set_baseten_key(bs_api_key)

def autoplay_video(video_path):
    with open(video_path, "rb") as f:
        video_content = f.read()
    video_str = f"data:video/mp4;base64,{base64.b64encode(video_content).decode()}"
    # The original HTML was lost in transcription; the element below is a reconstruction
    # that embeds the base64 data URI in an autoplaying, muted, looping video.
    st.markdown(f"""
        <video width="100%" autoplay loop muted playsinline>
            <source src="{video_str}" type="video/mp4">
        </video>
        """, unsafe_allow_html=True)

# sidebar
with st.sidebar:
    st.header('KG Questions')
    video, text = st.columns([2, 2])
    with video:
        autoplay_video('docs/images/kg_construction.mp4')
    with text:
        st.write(
            '''
            ###### The construction of a Knowledge Graph is mesmerizing.
            ###### Concepts in the middle are what most are doing. Are we considering anything different? Why? Why not?
            ###### Concepts on the edge are what few are doing. Are we considering that? Why? Why not?
            '''
        )
    st.caption('''###### corpus by [@ArxivHealthcareNLP@sigmoid.social](https://sigmoid.social/@ArxivHealthcareNLP)''')
    st.caption('''###### KG Questions by [arylwen](https://github.com/arylwen/mlk8s)''')
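# Optional sanity check (a sketch, not part of the original app): warn early if the
# persisted index directory is missing, instead of failing later inside
# load_index_from_storage.
if not os.path.isdir(persist_path):
    logger.warning('Persisted index not found at %s; queries will fail until it is built.', persist_path)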
from llama_index.core import StorageContext, ServiceContext, load_index_from_storage
#from llama_index import ServiceContext
#from llama_index import load_index_from_storage
from llama_index.core.node_parser import SentenceSplitter
#from llama_index.node_parser import SimpleNodeParser
from llama_index.core.service_context_elements.llm_predictor import LLMPredictor

from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere
#from langchain.llms import Baseten

import tiktoken
import openai

# extensions to llama_index to support openai compatible endpoints, e.g. llama-api
from kron.llm_predictor.KronOpenAILLM import KronOpenAI
# baseten deployment expects a specific request format
#from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor

# writer/camel uses endoftext
from llama_index.core.utils import globals_helper
enc = tiktoken.get_encoding("gpt2")
tokenizer = lambda text: enc.encode(text, allowed_special={"<|endoftext|>"})
globals_helper._tokenizer = tokenizer

def set_openai_local():
    openai.api_key = os.environ['LOCAL_OPENAI_API_KEY']
    openai.api_base = os.environ['LOCAL_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['LOCAL_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['LOCAL_OPENAI_API_BASE']

def set_openai():
    openai.api_key = os.environ['DAVINCI_OPENAI_API_KEY']
    openai.api_base = os.environ['DAVINCI_OPENAI_API_BASE']
    os.environ['OPENAI_API_KEY'] = os.environ['DAVINCI_OPENAI_API_KEY']
    os.environ['OPENAI_API_BASE'] = os.environ['DAVINCI_OPENAI_API_BASE']

from kron.llm_predictor.KronHFHubLLM import KronHuggingFaceHub

def get_hf_predictor(query_model):
    # no embeddings for now
    set_openai_local()
    #llm = HuggingFaceHub(repo_id=query_model, task="text-generation",
    llm = KronHuggingFaceHub(repo_id=query_model, task="text-generation",
                             #model_kwargs={"temperature": 0.01, "max_new_tokens": MAX_NEW_TOKENS, 'frequency_penalty': 1.17},
                             model_kwargs={"temperature": 0.01, "max_new_tokens": MAX_NEW_TOKENS},
                             huggingfacehub_api_token=hf_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor

def get_cohere_predictor(query_model):
    # no embeddings for now
    set_openai_local()
    llm = Cohere(model='command', temperature=0.01,
                 #model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
                 cohere_api_key=ch_api_key)
    llm_predictor = LLMPredictor(llm)
    return llm_predictor

#def get_baseten_predictor(query_model):
#    # no embeddings for now
#    set_openai_local()
#    llm = KronBasetenCamelLLM(model='3yd1ke3', temperature=0.01,
#                              #model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'repetition_penalty': 1.07},
#                              model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty': 1},
#                              cohere_api_key=ch_api_key)
#    llm_predictor = LLMPredictor(llm)
#    return llm_predictor

def get_kron_openai_predictor(query_model):
    # define LLM
    llm = KronOpenAI(temperature=0.01, model=query_model)
    llm.max_tokens = MAX_LENGTH
    llm_predictor = KronLLMPredictor(llm)
    return llm_predictor

def get_servce_context(llm_predictor):
    # define TextSplitter
    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    # define NodeParser
    #node_parser = SimpleNodeParser(text_splitter=text_splitter)
    node_parser = text_splitter
    # define ServiceContext
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)
    return service_context

# hack - on subsequent calls we can pass anything as index
@st.cache_data
def get_networkx_graph_nodes(_index, persist_path):
    g = _index.get_networkx_graph(100000)
    sorted_nodes = sorted(g.degree, key=lambda x: x[1], reverse=True)
    return sorted_nodes

@st.cache_data
def get_networkx_low_connected_components(_index, persist_path):
    g = _index.get_networkx_graph(100000)
    import networkx as nx
    sorted_c = [c for c in sorted(nx.connected_components(g), key=len, reverse=False)]
    #print(sorted_c[:100])
    low_terms = []
    for c in sorted_c:
        for cc in c:
            low_terms.extend([cc])
    #print(low_terms)
    return low_terms
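# Note on the two cached helpers above: the leading underscore in `_index` tells
# Streamlit not to hash the index object, so @st.cache_data keys the cache on
# persist_path alone. Once get_index() below has warmed the cache, later calls can
# pass "" in place of the index (see the model-selection branches further down).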
def get_index(service_context, persist_path):
    print(f'Loading index from {persist_path}')
    # rebuild storage context
    storage_context = StorageContext.from_defaults(persist_dir=persist_path)
    # load index
    index = load_index_from_storage(storage_context=storage_context,
                                    service_context=service_context,
                                    max_triplets_per_chunk=2,
                                    show_progress=False)
    get_networkx_graph_nodes(index, persist_path)
    get_networkx_low_connected_components(index, persist_path)
    return index

def get_query_engine(index):
    # writer/camel does not understand the refine prompt
    RESPONSE_MODE = 'accumulate'
    query_engine = index.as_query_engine(response_mode=RESPONSE_MODE)
    return query_engine

def load_query_engine(llm_predictor, persist_path):
    service_context = get_servce_context(llm_predictor)
    index = get_index(service_context, persist_path)
    print(f'No query engine for {persist_path}; creating')
    query_engine = get_query_engine(index)
    return query_engine

@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    llm_predictor = get_kron_openai_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine

@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    llm_predictor = get_hf_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine

@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    llm_predictor = get_cohere_predictor(query_model)
    query_engine = load_query_engine(llm_predictor, persist_path)
    return query_engine

#@st.cache_resource
#def build_baseten_query_engine(query_model, persist_path):
#    llm_predictor = get_baseten_predictor(query_model)
#    query_engine = load_query_engine(llm_predictor, persist_path)
#    return query_engine

def format_response(answer):
    # strip any run of two or more dashes from the response
    dashes = r'(\-{2,50})'
    answer.response = re.sub(dashes, '', answer.response)
    return answer.response or "None"

def clear_question(query_model):
    if not ('prev_model' in st.session_state) or (('prev_model' in st.session_state) and (st.session_state.prev_model != query_model)):
        if 'prev_model' in st.session_state:
            print(f'clearing question {st.session_state.prev_model} {query_model}')
        else:
            print(f'clearing question None {query_model}')
        if ('question_input' in st.session_state):
            st.session_state.question = st.session_state.question_input
            st.session_state.question_input = ''
            st.session_state.question_answered = False
            st.session_state.answer = ''
            st.session_state.answer_rating = 3
            st.session_state.elapsed = 0
        st.session_state.prev_model = query_model

query, measurable, explainable, ethical = st.tabs(["Query", "Measurable", "Explainable", "Ethical"])

initial_query = ''
if 'question' not in st.session_state:
    st.session_state.question = ''

if __spaces__:
    with query:
        answer_model = st.radio(
            "Choose the model used for inference:",
            ('hf/tiiuae/falcon-7b-instruct', 'cohere/command', 'openai/gpt-3.5-turbo-instruct')
            #TODO start hf inference container on demand
        )
else:
    with query:
        answer_model = st.radio(
            "Choose the model used for inference:",
            ('Writer/camel-5b-hf', 'mosaicml/mpt-7b-instruct', 'hf/tiiuae/falcon-7b-instruct',
             'cohere/command', 'baseten/Camel-5b', 'openai/gpt-3.5-turbo-instruct')
        )
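# Each branch below maps the radio selection to a concrete model id, resets the question
# state when the model changes, builds (or reuses, via @st.cache_resource) the matching
# query engine, and samples 5 highly connected and 5 sparsely connected graph terms for
# the sidebar suggestions.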
if answer_model == 'openai/gpt-3.5-turbo-instruct':
    print(answer_model)
    query_model = 'gpt-3.5-turbo-instruct'
    clear_question(query_model)
    set_openai()
    query_engine = build_kron_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
    print(answer_model)
    query_model = 'tiiuae/falcon-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'cohere/command':
    print(answer_model)
    query_model = 'cohere/command'
    clear_question(query_model)
    query_engine = build_cohere_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'baseten/Camel-5b':
    print(answer_model)
    query_model = 'baseten/Camel-5b'
    clear_question(query_model)
    # NOTE: build_baseten_query_engine (and the Baseten predictor it relies on) is
    # commented out above; selecting this option will fail until it is re-enabled.
    query_engine = build_baseten_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'Writer/camel-5b-hf':
    query_model = 'Writer/camel-5b-hf'
    print(answer_model)
    clear_question(query_model)
    set_openai_local()
    query_engine = build_kron_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
elif answer_model == 'mosaicml/mpt-7b-instruct':
    query_model = 'mosaicml/mpt-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
    graph_nodes = get_networkx_graph_nodes("", persist_path)
    most_connected = random.sample(graph_nodes[:100], 5)
    low_connected = get_networkx_low_connected_components("", persist_path)
    least_connected = random.sample(low_connected, 5)
else:
    print('This is a bug.')

# to clear the input box
def submit():
    st.session_state.question = st.session_state.question_input
    st.session_state.question_input = ''
    st.session_state.question_answered = False

with st.sidebar:
    import gensim
    m_connected = []
    for item in most_connected:
        if not item[0].lower() in gensim.parsing.preprocessing.STOPWORDS:
            m_connected.extend([item[0].lower()])
    option_1 = st.selectbox("What most are studying:", m_connected, disabled=True)
    option_2 = st.selectbox("What few are studying:", least_connected, disabled=True)
with query:
    st.caption('''###### Intended for educational and research purposes. Please do not enter any private or confidential information. Model, question, answer and rating are logged to improve KG Questions.''')
    question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?",
                             key='question_input', on_change=submit)

if (st.session_state.question):
    try:
        with query:
            col1, col2 = st.columns([2, 2])
            if not st.session_state.question_answered:
                with st.spinner(f'Answering: {st.session_state.question} with {query_model}.'):
                    start = time.time()
                    answer = query_engine.query(st.session_state.question)
                    st.session_state.answer = answer
                    st.session_state.question_answered = True
                    end = time.time()
                    st.session_state.elapsed = (end - start)
            else:
                answer = st.session_state.answer
            answer_str = format_response(answer)
            with col1:
                if answer_str:
                    elapsed = '{:.2f}'.format(st.session_state.elapsed)
                    st.write(f'Answered: {st.session_state.question} with {query_model} in {elapsed}s. Please rate this answer.')
            with col2:
                from streamlit_star_rating import st_star_rating
                stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
            st.write(answer_str)
        with measurable:
            from measurable import display_wordcloud
            display_wordcloud(answer, answer_str)
        with explainable:
            from explainable import explain
            explain(answer)
    except Exception as e:
        answer_str = f'{type(e)}, {e}'
        st.session_state.answer_rating = -1
        st.write(f'An error occurred, please try again. \n{answer_str}')
    finally:
        if 'question' in st.session_state:
            req = st.session_state.question
        if (__spaces__):
            st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)
else:
    with measurable:
        st.write('###### Ask a question to see a comparison between the corpus, answer and reference documents.')
    with explainable:
        st.write('###### Ask a question to see the knowledge graph and a list of reference documents.')

with ethical:
    from ethics import display_ethics
    display_ethics()
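# Local usage (a sketch; assumes this module is saved as app.py and the persisted index
# exists under storage/):
#   streamlit run app.py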