# starcoder2-test / app.py
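"""Streamlit coding assistant.

Loads a Hugging Face text-generation model (StarCoder2 by default) and
answers questions in a chat UI; FAISS-based retrieval over uploaded code
files is scaffolded but currently disabled.
"""
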
import streamlit as st
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# gpt_model = 'gpt-4-1106-preview'
# embedding_model = 'text-embedding-3-small'
default_model_id = "bigcode/starcoder2-3b"
# default_model_id = "tiiuae/falcon-7b-instruct"


def init():
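    """Create the session-state entries the app relies on."""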
if "conversation" not in st.session_state:
st.session_state.conversation = None
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
def init_llm_pipeline(model_id):
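    """Load the model and tokenizer once, build a text-generation pipeline,
    and cache it in the session state wrapped as a LangChain LLM."""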
if "llm" not in st.session_state:
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_eos_token = True
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"
text_generation_pipeline = pipeline(
model=model,
tokenizer=tokenizer,
task="text-generation",
max_new_tokens=1024
)
st.session_state.llm = text_generation_pipeline
def get_retriever(files):
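    """Split the uploaded Python files, embed the chunks, and return a
    FAISS-backed retriever."""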
    documents = [doc.getvalue().decode("utf-8") for doc in files]
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
    )
    texts = python_splitter.create_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = FAISS.from_documents(texts, embeddings)
    retriever = db.as_retriever(
        search_type="mmr",  # also test "similarity"
        search_kwargs={"k": 8},
    )
    return retriever


def get_conversation(retriever):
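    """Build a retrieval-augmented conversation chain around the cached LLM."""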
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=st.session_state.llm,
        retriever=retriever,
        memory=memory,
    )
    return conversation_chain


def get_prompt(user_input):
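    """Wrap the user question in a simple USER/ASSISTANT prompt template."""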
prompt = f"You are a helpful assistant. Please answer the user question. USER: {user_input} ASSISTANT:"
return prompt
def handle_user_input(question):
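    """Append the question to the history, query the LLM, and re-render the chat."""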
    # Guard: the model is only loaded after the sidebar "Upload" button is used.
    if "llm" not in st.session_state:
        st.warning("Please load a model from the sidebar first.")
        return
    # += on a dict would append its keys; use append() to store the message dict.
    st.session_state.chat_history.append({"role": "user", "content": question})
    response = st.session_state.llm.invoke(get_prompt(question))
    st.session_state.chat_history.append({"role": "assistant", "content": response})
    for message in st.session_state.chat_history:
        # History entries are plain dicts, so use the stored role rather than
        # alternating on the index, and subscript access rather than attributes.
        with st.chat_message(message["role"]):
            st.write(message["content"])


def main():
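    """Entry point: page setup, chat input handling, and the sidebar."""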
    # set_page_config must run before any other Streamlit rendering call.
    st.set_page_config(page_title="Coding Assistant", page_icon=":books:")
    init()
    st.header(":books: Coding Assistant")
    user_input = st.chat_input("Ask your question here")
    if user_input:
        with st.spinner("Running query ..."):
            handle_user_input(user_input)
    with st.sidebar:
        st.subheader("Model selector")
        model_id = st.text_input("Model name on Hugging Face", default_model_id)
        st.subheader("Code Upload")
        upload_docs = st.file_uploader("Upload documents here", accept_multiple_files=True)
        if st.button("Upload"):
            with st.spinner("Analyzing documents ..."):
                init_llm_pipeline(model_id)
                # retriever = get_retriever(upload_docs)
                # st.session_state.conversation = get_conversation(retriever)


if __name__ == "__main__":
    main()