mlk8s / app.py
Arylwen's picture
v0.0.6 sidebar
82caff6
raw
history blame
No virus
11.9 kB
import streamlit as st
import os
import re
import sys
import base64
import logging
# Log to stdout at INFO so messages show up in the hosting container's console.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

# Pull configuration (API tokens, endpoints) from a local .env file, if present.
from dotenv import load_dotenv
load_dotenv()

# Debug aid: print every session-state entry carried over from the previous rerun.
for key, value in st.session_state.items():
    print(f'session state entry: {key} {value}')
# Non-empty on the hosted deployment (presumably Hugging Face Spaces — confirm);
# only then is the DynamoDB-backed request log attached to the session.
__spaces__ = os.environ.get('__SPACES__')
if __spaces__:
    from kron.persistence.dynamodb_request_log import get_request_log
    st.session_state.request_log = get_request_log()

# Third-party service tokens (Hugging Face inference API, Cohere, Baseten).
# All three are mandatory: a missing variable raises KeyError at startup.
hf_api_key, ch_api_key, bs_api_key = (
    os.environ[var_name] for var_name in ('HF_TOKEN', 'COHERE_TOKEN', 'BASETEN_TOKEN')
)

# The persisted index was built with this model; its directory name is derived from it.
index_model = "Writer/camel-5b-hf"
INDEX_NAME = index_model.replace('/', '-') + '-default-no-coref'
persist_path = 'storage/' + INDEX_NAME
MAX_LENGTH = 1024
import baseten
@st.cache_resource
def set_baseten_key(bs_api_key):
    """Log the baseten client in with ``bs_api_key``.

    st.cache_resource memoizes on the argument, so across reruns the login
    happens once per distinct key value instead of on every script execution.
    """
    baseten.login(bs_api_key)


# Authenticate once at import time with the token read from the environment.
set_baseten_key(bs_api_key)
def autoplay_video(video_path):
    """Render an MP4 file as a centered, looping, autoplaying <video> element.

    The whole file is read and inlined into the page as a base64 data URI.

    Args:
        video_path: path to an MP4 file readable by the app process.
    """
    with open(video_path, "rb") as video_file:
        encoded = base64.b64encode(video_file.read()).decode()
    data_uri = "data:video/mp4;base64," + encoded
    html = (
        "\n"
        '<video style="display: block; margin: auto; width: 140px;" controls loop autoplay width="140" height="180">\n'
        f'<source src="{data_uri}" type="video/mp4">\n'
        "</video>\n"
    )
    # Raw HTML is required for the <video> tag; the only interpolated value is our own data URI.
    st.markdown(html, unsafe_allow_html=True)
# sidebar: knowledge-graph construction video beside prompting questions.
with st.sidebar:
    st.header('KG Questions')
    video_col, text_col = st.columns([2, 2])
    with video_col:
        autoplay_video('docs/images/kg_construction.mp4')
    with text_col:
        st.write(
            '\n'
            '###### The construction of a Knowledge Graph is mesmerizing.\n'
            '###### Concepts in the middle are what most are doing. Are we considering anything different? Why? Why not?\n'
            '###### Concepts on the edge are what few are doing. Are we considering that? Why? Why not?\n'
        )
    # Rendered below the two columns.
    st.write(
        '\n'
        '#### How can <what most are doing> help with <what few are doing>?\n'
    )
from llama_index import StorageContext
from llama_index import ServiceContext
from llama_index import load_index_from_storage
from llama_index.langchain_helpers.text_splitter import SentenceSplitter
from llama_index.node_parser import SimpleNodeParser
from llama_index import LLMPredictor
from langchain import HuggingFaceHub
from langchain.llms.cohere import Cohere
from langchain.llms import Baseten
import tiktoken
import openai
#extensions to llama_index to support openai compatible endpoints, e.g. llama-api
from kron.llm_predictor.KronOpenAILLM import KronOpenAI
#baseten deployment expects a specific request format
from kron.llm_predictor.KronBasetenCamelLLM import KronBasetenCamelLLM
from kron.llm_predictor.KronLLMPredictor import KronLLMPredictor
#writer/camel uses endoftext
from llama_index.utils import globals_helper
enc = tiktoken.get_encoding("gpt2")


def tokenizer(text):
    """Encode ``text`` with the GPT-2 BPE, permitting the literal "<|endoftext|>" marker.

    tiktoken rejects special tokens that are not explicitly allowed, and the
    Camel model's text contains <|endoftext|>, so that marker is whitelisted.

    Returns:
        list of token ids.
    """
    return enc.encode(text, allowed_special={"<|endoftext|>"})


# Install as llama_index's global tokenizer (presumably used for its token
# counting / chunking — confirm against the llama_index version in use).
globals_helper._tokenizer = tokenizer
def _apply_openai_credentials(key_var, base_var):
    """Point the openai client and the OPENAI_* env vars at one key/base-URL pair.

    Both source variables are read before anything is modified, so a missing
    variable raises KeyError without leaving the client half-switched
    (new key with the old base URL).

    Args:
        key_var: name of the environment variable holding the API key.
        base_var: name of the environment variable holding the API base URL.
    Raises:
        KeyError: if either environment variable is not set.
    """
    api_key = os.environ[key_var]
    api_base = os.environ[base_var]
    openai.api_key = api_key
    openai.api_base = api_base
    # Mirror into the environment for libraries that read OPENAI_* directly.
    os.environ['OPENAI_API_KEY'] = api_key
    os.environ['OPENAI_API_BASE'] = api_base


def set_openai_local():
    """Use the local OpenAI-compatible endpoint (LOCAL_OPENAI_API_KEY / _BASE)."""
    _apply_openai_credentials('LOCAL_OPENAI_API_KEY', 'LOCAL_OPENAI_API_BASE')


def set_openai():
    """Use the endpoint configured for text-davinci-003 (DAVINCI_OPENAI_API_KEY / _BASE)."""
    _apply_openai_credentials('DAVINCI_OPENAI_API_KEY', 'DAVINCI_OPENAI_API_BASE')
def get_hf_predictor(query_model):
    """Wrap a Hugging Face Hub text-generation model in an LLMPredictor.

    Also switches the openai client to the local endpoint (no embeddings for now).

    Args:
        query_model: Hub repo id, e.g. 'tiiuae/falcon-7b-instruct'.
    """
    set_openai_local()
    generation_kwargs = {"temperature": 0.01, "max_length": MAX_LENGTH}
    hub_llm = HuggingFaceHub(
        repo_id=query_model,
        task="text-generation",
        model_kwargs=generation_kwargs,
        huggingfacehub_api_token=hf_api_key,
    )
    return LLMPredictor(hub_llm)
def get_cohere_predictor(query_model):
    """Wrap a Cohere completion model in an LLMPredictor.

    Also switches the openai client to the local endpoint (no embeddings for now).

    Args:
        query_model: a Cohere model name, optionally provider-prefixed
            ('cohere/command' or 'command'). Only the part after the last '/'
            is sent to Cohere; an empty label falls back to 'command'.
    """
    set_openai_local()
    # Previously 'command' was hard-coded and query_model ignored; deriving it
    # gives the same model for the existing 'cohere/command' caller.
    model_name = query_model.rpartition('/')[2] or 'command'
    llm = Cohere(model=model_name, temperature=0.01, cohere_api_key=ch_api_key)
    return LLMPredictor(llm)
def get_baseten_predictor(query_model):
    """Wrap the Baseten-hosted Camel model in an LLMPredictor.

    Also switches the openai client to the local endpoint (no embeddings for now).

    Args:
        query_model: accepted for parity with the other get_*_predictor helpers;
            not used — the Baseten model id is fixed below.
    """
    set_openai_local()
    # The Cohere API key used to be passed here (copy-paste from the Cohere
    # helper). It is unrelated to Baseten and is no longer handed to this LLM;
    # Baseten auth presumably comes from baseten.login in set_baseten_key.
    llm = KronBasetenCamelLLM(
        model='3yd1ke3',  # NOTE(review): Baseten model/version id — confirm
        temperature=0.01,
        model_kwargs={"temperature": 0.01, "max_length": MAX_LENGTH, 'frequency_penalty': 1},
    )
    return LLMPredictor(llm)
def get_kron_openai_predictor(query_model):
    """Build a KronLLMPredictor around an OpenAI-compatible completion model.

    Credentials are not configured here; the call sites run set_openai() or
    set_openai_local() first.

    Args:
        query_model: model name sent to the endpoint.
    """
    openai_llm = KronOpenAI(temperature=0.01, model=query_model)
    # Completion budget shared with the other backends.
    openai_llm.max_tokens = MAX_LENGTH
    return KronLLMPredictor(openai_llm)
def get_servce_context(llm_predictor):
    """Return a ServiceContext using ``llm_predictor`` and sentence-based chunking.

    Chunks use chunk_size=192 and chunk_overlap=48, with a single newline
    treated as the paragraph separator.
    """
    splitter = SentenceSplitter(chunk_size=192, chunk_overlap=48, paragraph_separator='\n')
    return ServiceContext.from_defaults(
        llm_predictor=llm_predictor,
        node_parser=SimpleNodeParser(text_splitter=splitter),
    )
def get_index(service_context, persist_path):
    """Load the index persisted under ``persist_path`` using ``service_context``.

    max_triplets_per_chunk and show_progress are forwarded by
    load_index_from_storage to the index class (presumably a knowledge-graph
    index, given the triplets option — confirm).
    """
    print(f'Loading index from {persist_path}')
    return load_index_from_storage(
        storage_context=StorageContext.from_defaults(persist_dir=persist_path),
        service_context=service_context,
        max_triplets_per_chunk=2,
        show_progress=False,
    )
def get_query_engine(index):
    """Return a query engine over ``index`` in 'accumulate' response mode.

    writer/camel does not understand the refine prompt, so per-chunk answers
    are accumulated rather than refined.
    """
    return index.as_query_engine(response_mode='accumulate')
def load_query_engine(llm_predictor, persist_path):
    """Load the index at ``persist_path`` with ``llm_predictor`` and wrap it in a query engine."""
    index = get_index(get_servce_context(llm_predictor), persist_path)
    print(f'No query engine for {persist_path}; creating')
    return get_query_engine(index)
@st.cache_resource
def build_kron_query_engine(query_model, persist_path):
    """Cached query engine backed by an OpenAI-compatible endpoint (KronOpenAI).

    One engine is kept per (query_model, persist_path) pair across reruns.
    """
    return load_query_engine(get_kron_openai_predictor(query_model), persist_path)
@st.cache_resource
def build_hf_query_engine(query_model, persist_path):
    """Cached query engine backed by a Hugging Face Hub model.

    One engine is kept per (query_model, persist_path) pair across reruns.
    """
    return load_query_engine(get_hf_predictor(query_model), persist_path)
@st.cache_resource
def build_cohere_query_engine(query_model, persist_path):
    """Cached query engine backed by Cohere.

    One engine is kept per (query_model, persist_path) pair across reruns.
    """
    return load_query_engine(get_cohere_predictor(query_model), persist_path)
@st.cache_resource
def build_baseten_query_engine(query_model, persist_path):
    """Cached query engine backed by the Baseten Camel deployment.

    One engine is kept per (query_model, persist_path) pair across reruns.
    """
    return load_query_engine(get_baseten_predictor(query_model), persist_path)
# Any run of two or more hyphens.
_DASH_RUN_RE = re.compile(r'-{2,}')


def format_response(answer):
    """Strip runs of two or more hyphens from a query response and return its text.

    ``answer.response`` is rewritten in place with the cleaned text, so a
    response object cached in session state stays cleaned on later reruns.
    Single hyphens are kept.

    Args:
        answer: object with a ``response`` attribute holding a str or None.
    Returns:
        The cleaned text, or the string "None" when the response is None or
        empty after cleaning.
    """
    # Previously re.sub raised TypeError on a None response, even though the
    # return statement was already written to handle a falsy response.
    if answer.response is None:
        return "None"
    # Unbounded run length: the old {2,50} left a stray '-' on runs of 51+.
    answer.response = _DASH_RUN_RE.sub('', answer.response)
    return answer.response or "None"
def clear_question(query_model):
    """Reset question/answer session state when the selected model changes.

    Does nothing when ``query_model`` equals the model recorded on the previous
    run. Otherwise moves any pending text from the input box into
    ``question``, empties the box, resets the answer and rating, and records
    ``query_model`` as the current model.
    """
    state = st.session_state
    has_previous = 'prev_model' in state
    if has_previous and state.prev_model == query_model:
        return
    previous = state.prev_model if has_previous else None
    print(f'clearing question {previous} {query_model}')
    if 'question_input' in state:
        state.question = state.question_input
        state.question_input = ''
    state.question_answered = False
    state.answer = ''
    state.answer_rating = 3
    state.prev_model = query_model
initial_query = ''
# Make sure a 'question' slot exists before anything reads it.
if not 'question' in st.session_state:
    st.session_state['question'] = ''
# The hosted deployment and a local run offer different model lists.
if __spaces__:
    # TODO start hf inference container on demand
    model_options = ('baseten/Camel-5b', 'cohere/command', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
else:
    model_options = ('Local-Camel', 'HF-TKI', 'hf/tiiuae/falcon-7b-instruct', 'openai/text-davinci-003')
answer_model = st.radio("Choose the model used for inference:", model_options)
# Radio label -> (model id passed to the builder,
#                 optional OpenAI credential switch to run first,
#                 cached query-engine builder).
_MODEL_CHOICES = {
    'openai/text-davinci-003': ('text-davinci-003', set_openai, build_kron_query_engine),
    'hf/tiiuae/falcon-7b-instruct': ('tiiuae/falcon-7b-instruct', None, build_hf_query_engine),
    'cohere/command': ('cohere/command', None, build_cohere_query_engine),
    'baseten/Camel-5b': ('baseten/Camel-5b', None, build_baseten_query_engine),
    'Local-Camel': ('Writer/camel-5b-hf', set_openai_local, build_kron_query_engine),
    'HF-TKI': ('allenai/tk-instruct-3b-def-pos-neg-expl', None, build_hf_query_engine),
}

if answer_model not in _MODEL_CHOICES:
    # Only reachable if the radio options drift out of sync with the table
    # above. Previously this just printed and fell through, so the code below
    # failed with a NameError on query_model / query_engine; stop the run instead.
    print('This is a bug.')
    st.error(f'Unknown model selection: {answer_model}')
    st.stop()

query_model, configure_openai, build_query_engine = _MODEL_CHOICES[answer_model]
print(answer_model)  # now logged for every choice (HF-TKI used to skip it)
clear_question(query_model)
if configure_openai is not None:
    configure_openai()
query_engine = build_query_engine(query_model, persist_path)
# to clear input box
def submit():
    """on_change callback for the question box.

    Promotes the typed text to the active question, empties the box, and marks
    the new question as not yet answered.
    """
    state = st.session_state
    state.question, state.question_input = state.question_input, ''
    state.question_answered = False
st.write(f'Model, question, answer and rating are logged to help with the improvement of this application.')
# submit() (on_change) copies the typed text into st.session_state.question and empties this box.
question = st.text_input("Enter a question, e.g. What benchmarks can we use for QA?", key='question_input', on_change=submit )
# Only show an answer once a non-empty question has been submitted.
if(st.session_state.question):
    col1, col2 = st.columns([2, 2])
    with col1:
        st.write(f'Answering: {st.session_state.question} with {query_model}.')
    try :
        # Query the engine once per submitted question. Streamlit re-executes this
        # script on widget interactions, so later reruns reuse the stored answer
        # instead of querying again.
        if not st.session_state.question_answered:
            answer = query_engine.query(st.session_state.question)
            st.session_state.answer = answer
            st.session_state.question_answered = True
        else:
            answer = st.session_state.answer
        answer_str = format_response(answer)
        st.write(answer_str)
        with col1:
            if answer_str:
                st.write(f' Please rate this answer.')
        with col2:
            from streamlit_star_rating import st_star_rating
            # key="answer_rating" ties the widget value to st.session_state.answer_rating,
            # which is what gets logged in the finally clause.
            stars = st_star_rating("", maxValue=5, defaultValue=3, key="answer_rating")
            #print(f"------stars {stars}")
    except Exception as e:
        #print(f'{type(e)}, {e}')
        # The error text stands in for the answer, and rating -1 flags the failure in the log entry.
        answer_str = f'{type(e)}, {e}'
        st.session_state.answer_rating = -1
        st.write(f'An error occured, please try again. \n{answer_str}')
    finally:
        # Runs on every rerun that reaches this point, success or failure; the
        # entry is written only when __spaces__ is set (request_log exists only then).
        if 'question' in st.session_state:
            req = st.session_state.question
            if(__spaces__):
                st.session_state.request_log.add_request_log_entry(query_model, req, answer_str, st.session_state.answer_rating)