"""Streamlit app that answers questions over a persisted LlamaIndex knowledge graph.

The user picks an inference backend (OpenAI-compatible, Hugging Face Hub, Cohere
or Baseten), asks a question, sees the answer and can rate it. When running with
``__SPACES__`` set, each request is logged through ``kron.persistence``.
"""
import base64
import logging
import os
import re
import sys

import streamlit as st

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

# Load .env before anything below reads os.environ (the API tokens, and
# presumably library defaults read at import time — kept in original order).
from dotenv import load_dotenv

load_dotenv()

# Debug aid: dump the session state on every Streamlit rerun.
for key in st.session_state.keys():
    print(f'session state entry: {key} {st.session_state[key]}')

# Truthy when the __SPACES__ env var is set; enables the reduced model list
# and request logging.
__spaces__ = os.environ.get('__SPACES__')
if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log

    st.session_state.request_log = get_request_log()

# Third-party service access (a missing variable raises KeyError at startup).
hf_api_key = os.environ['HF_TOKEN']  # hf inference api
ch_api_key = os.environ['COHERE_TOKEN']
bs_api_key = os.environ['BASETEN_TOKEN']

index_model = "Writer/camel-5b-hf"
INDEX_NAME = f"{index_model.replace('/', '-')}-default-no-coref"
persist_path = f"storage/{INDEX_NAME}"
MAX_LENGTH = 1024

# Runs of 2-50 dashes are stripped from model answers by format_response().
_DASHES = re.compile(r'-{2,50}')

import baseten


@st.cache_resource
def set_baseten_key(bs_api_key):
    """Log in to Baseten; st.cache_resource makes this run once per key value."""
    baseten.login(bs_api_key)


set_baseten_key(bs_api_key)


def autoplay_video(video_path):
    """Embed the MP4 at *video_path* as an inline, auto-playing, looping video.

    The file is inlined as a base64 data URI, so no static file serving is needed.
    """
    with open(video_path, "rb") as f:
        video_content = f.read()
    video_str = f"data:video/mp4;base64,{base64.b64encode(video_content).decode()}"
    # Browser autoplay policies generally block unmuted autoplay, hence `muted`.
    st.markdown(
        f"""
<video width="100%" autoplay loop muted playsinline>
  <source src="{video_str}" type="video/mp4">
</video>
""",
        unsafe_allow_html=True,
    )


_KG_CONSTRUCTION_BLURB = '''
###### The construction of a Knowledge Graph is mesmerizing.
###### Concepts in the middle are what most are doing. Are we considering anything different? Why? Why not?
###### Concepts on the edge are what few are doing. Are we considering that? Why? Why not?
'''

# NOTE(review): this heading appears to be missing words (it reads
# "How can help with ?"); the intended text could not be recovered here.
_HELP_BLURB = '''
#### How can help with ?
'''

# sidebar
with st.sidebar:
    st.header('KG Questions')
    video, text = st.columns([2, 2])
    with video:
        autoplay_video('docs/images/kg_construction.mp4')
    with text:
        st.write(_KG_CONSTRUCTION_BLURB)
    st.write(_HELP_BLURB)

# Heavy imports come after the sidebar so it renders before they load.
from llama_index import StorageContext
from llama_index import ServiceContext
from llama_index import load_index_from_storage
from llama_index.langchain_helpers.text_splitter import SentenceSplitter
from llama_index.node_parser import SimpleNodeParser
from llama_index import LLMPredictor

from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere
from langchain.llms import Baseten

import tiktoken
import openai

# extensions to llama_index to support openai compatible endpoints, e.g. llama-api
from kron.llm_predictor.KronOpenAILLM import KronOpenAI

# baseten deployment expects a specific request format
from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor

# writer/camel uses endoftext
from llama_index.utils import globals_helper

enc = tiktoken.get_encoding("gpt2")


def tokenizer(text):
    """Encode *text* with GPT-2 BPE, permitting the <|endoftext|> special token."""
    return enc.encode(text, allowed_special={"<|endoftext|>"})


# Replace llama_index's global tokenizer so <|endoftext|> does not raise.
globals_helper._tokenizer = tokenizer


def _configure_openai(prefix):
    """Point the openai module and OPENAI_* env vars at the {prefix}_OPENAI_* endpoint.

    Raises:
        KeyError: if {prefix}_OPENAI_API_KEY or {prefix}_OPENAI_API_BASE is unset.
    """
    api_key = os.environ[f'{prefix}_OPENAI_API_KEY']
    api_base = os.environ[f'{prefix}_OPENAI_API_BASE']
    openai.api_key = api_key
    openai.api_base = api_base
    os.environ['OPENAI_API_KEY'] = api_key
    os.environ['OPENAI_API_BASE'] = api_base


def set_openai_local():
    """Use the endpoint configured by LOCAL_OPENAI_API_KEY / LOCAL_OPENAI_API_BASE."""
    _configure_openai('LOCAL')


def set_openai():
    """Use the endpoint configured by DAVINCI_OPENAI_API_KEY / DAVINCI_OPENAI_API_BASE."""
    _configure_openai('DAVINCI')


def get_hf_predictor(query_model):
    """Return an LLMPredictor backed by the Hugging Face Hub model *query_model*."""
    # no embeddings for now
    set_openai_local()
    llm = HuggingFaceHub(
        repo_id=query_model,
        task="text-generation",
        model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH},
        huggingfacehub_api_token=hf_api_key,
    )
    return LLMPredictor(llm)


def get_cohere_predictor(query_model):
    """Return an LLMPredictor backed by Cohere 'command' (*query_model* is unused)."""
    # no embeddings for now
    set_openai_local()
    llm = Cohere(model='command', temperature=0.01, cohere_api_key=ch_api_key)
    return LLMPredictor(llm)


def get_baseten_predictor(query_model):
    """Return an LLMPredictor backed by Baseten deployment '3yd1ke3' (*query_model* unused)."""
    # no embeddings for now
    set_openai_local()
    # NOTE(review): cohere_api_key looks copy-pasted from the Cohere predictor;
    # confirm whether KronBasetenCamelLLM needs it before removing.
    llm = KronBasetenCamelLLM(
        model='3yd1ke3',
        temperature=0.01,
        model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty': 1},
        cohere_api_key=ch_api_key,
    )
    return LLMPredictor(llm)


def get_kron_openai_predictor(query_model):
    """Return a KronLLMPredictor for *query_model* on the currently configured OpenAI endpoint."""
    llm = KronOpenAI(temperature=0.01, model=query_model)
    llm.max_tokens = MAX_LENGTH
    return KronLLMPredictor(llm)


def get_servce_context(llm_predictor):
    """Build a ServiceContext using *llm_predictor* and a 192/48 sentence splitter."""
    text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    node_parser = SimpleNodeParser(text_splitter=text_splitter)
    return ServiceContext.from_defaults(llm_predictor=llm_predictor, node_parser=node_parser)


def get_index(service_context, persist_path):
    """Load the persisted index stored under *persist_path*."""
    print(f'Loading index from {persist_path}')
    storage_context = StorageContext.from_defaults(persist_dir=persist_path)
    return load_index_from_storage(
        storage_context=storage_context,
        service_context=service_context,
        max_triplets_per_chunk=2,
        show_progress=False,
    )


def get_query_engine(index):
    """Return a query engine over *index* using the 'accumulate' response mode."""
    # writer/camel does not understand the refine prompt
    response_mode = 'accumulate'
    return index.as_query_engine(response_mode=response_mode)


def load_query_engine(llm_predictor, persist_path):
    """Load the index at *persist_path* and wrap it in a query engine using *llm_predictor*."""
    service_context = get_servce_context(llm_predictor)
    index = get_index(service_context, persist_path)
    print(f'No query engine for {persist_path}; creating')
    return get_query_engine(index)


# The builders below are cached per (query_model, persist_path), so any side
# effects inside the predictor factories (e.g. set_openai_local) run only on
# the first build.
@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    return load_query_engine(get_kron_openai_predictor(query_model), persist_path)


@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    return load_query_engine(get_hf_predictor(query_model), persist_path)


@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    return load_query_engine(get_cohere_predictor(query_model), persist_path)


@st.cache_resource
def build_baseten_query_engine(query_model, persist_path):
    return load_query_engine(get_baseten_predictor(query_model), persist_path)


def format_response(answer):
    """Strip runs of 2-50 dashes from ``answer.response`` and return the text.

    ``answer.response`` is updated in place. Returns the string "None" when the
    response is empty or missing (previously a None response raised TypeError).
    """
    if answer.response:
        answer.response = _DASHES.sub('', answer.response)
    return answer.response or "None"


def clear_question(query_model):
    """Reset the question/answer state when the selected model changes (or on first run)."""
    prev_model = st.session_state.get('prev_model')
    if prev_model == query_model:
        return
    print(f'clearing question {prev_model} {query_model}')
    if 'question_input' in st.session_state:
        st.session_state.question = st.session_state.question_input
        st.session_state.question_input = ''
    st.session_state.question_answered = False
    st.session_state.answer = ''
    # Written before the rating widget is created in this run, which is required
    # for widget-keyed session state.
    st.session_state.answer_rating = 3
    st.session_state.prev_model = query_model


if 'question' not in st.session_state:
    st.session_state.question = ''

if __spaces__:
    answer_model = st.radio(
        "Choose the model used for inference:",
        ('baseten/Camel-5b', 'cohere/command', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003'),
        # TODO: start hf inference container on demand
    )
else:
    answer_model = st.radio(
        "Choose the model used for inference:",
        ('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003'),
    )

if answer_model == 'openai/text-davinci-003':
    print(answer_model)
    query_model = 'text-davinci-003'
    clear_question(query_model)
    set_openai()
    query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'hf/tiiuae/falcon-7b-instruct':
    print(answer_model)
    query_model = 'tiiuae/falcon-7b-instruct'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
elif answer_model == 'cohere/command':
    print(answer_model)
    query_model = 'cohere/command'
    clear_question(query_model)
    query_engine = build_cohere_query_engine(query_model, persist_path)
elif answer_model == 'baseten/Camel-5b':
    print(answer_model)
    query_model = 'baseten/Camel-5b'
    clear_question(query_model)
    query_engine = build_baseten_query_engine(query_model, persist_path)
elif answer_model == 'Local-Camel':
    query_model = 'Writer/camel-5b-hf'
    print(answer_model)
    clear_question(query_model)
    set_openai_local()
    query_engine = build_kron_query_engine(query_model, persist_path)
elif answer_model == 'HF-TKI':
    print(answer_model)
    query_model = 'allenai/tk-instruct-3b-def-pos-neg-expl'
    clear_question(query_model)
    query_engine = build_hf_query_engine(query_model, persist_path)
else:
    print('This is a bug.')


def submit():
    """on_change callback: move the typed question into state and clear the input box."""
    st.session_state.question = st.session_state.question_input
    st.session_state.question_input = ''
    st.session_state.question_answered = False


st.write('Model, question, answer and rating are logged to help with the improvement of this application.')
question = st.text_input(
    "Enter a question, e.g. What benchmarks can we use for QA?",
    key='question_input',
    on_change=submit,
)

if st.session_state.question:
    col1, col2 = st.columns([2, 2])
    with col1:
        st.write(f'Answering: {st.session_state.question} with {query_model}.')

    # Bound before `try` so the `finally` logger never hits a NameError if the
    # body is aborted before an answer string is produced.
    answer_str = None
    try:
        if not st.session_state.question_answered:
            answer = query_engine.query(st.session_state.question)
            st.session_state.answer = answer
            st.session_state.question_answered = True
        else:
            # Reuse the stored answer on reruns (e.g. when the rating changes).
            answer = st.session_state.answer
        answer_str = format_response(answer)
        st.write(answer_str)
        with col1:
            if answer_str:
                st.write(' Please rate this answer.')
        with col2:
            from streamlit_star_rating import st_star_rating

            stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
    except Exception as e:
        logger.exception('Answering the question failed')
        answer_str = f'{type(e)}, {e}'
        st.session_state.answer_rating = -1
        st.write(f'An error occured, please try again. \n{answer_str}')
    finally:
        if 'question' in st.session_state:
            req = st.session_state.question
            if __spaces__:
                st.session_state.request_log.add_request_log_entry(
                    query_model, req, answer_str, st.session_state.answer_rating
                )