import streamlit as st
import torch
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
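

# Streamlit coding assistant: loads a Hugging Face causal LM into a local
# text-generation pipeline and (optionally) builds a FAISS retriever over
# uploaded Python files for retrieval-augmented answers.
# Assumed entry point (the file name is an assumption): streamlit run app.py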
# gpt_model = 'gpt-4-1106-preview'
# embedding_model = 'text-embedding-3-small'

default_model_id = "bigcode/starcoder2-3b"
# default_model_id = "tiiuae/falcon-7b-instruct"
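

# Create the session-state keys the app relies on so Streamlit reruns do not
# fail before a model or conversation has been set up.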
def init():
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
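

# Load the selected model and tokenizer once and cache the resulting
# transformers text-generation pipeline in session state. device_map="auto"
# places the weights on a GPU if one is available, otherwise on the CPU.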
def init_llm_pipeline(model_id):
    if "llm" not in st.session_state:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.add_eos_token = True
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "left"

        text_generation_pipeline = pipeline(
            model=model,
            tokenizer=tokenizer,
            task="text-generation",
            max_new_tokens=1024
        )
        st.session_state.llm = text_generation_pipeline
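

# Split the uploaded source files with a Python-aware splitter, embed the
# chunks with a sentence-transformers model, and expose the FAISS index as
# an MMR retriever.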
def get_retriever(files):
    documents = [doc.getvalue().decode("utf-8") for doc in files]
    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
    )
    texts = python_splitter.create_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = FAISS.from_documents(texts, embeddings)
    retriever = db.as_retriever(
        search_type="mmr",  # Also test "similarity"
        search_kwargs={"k": 8},
    )
    return retriever
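

# Wrap the raw transformers pipeline in a LangChain LLM so it can drive a
# ConversationalRetrievalChain over the FAISS retriever.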
def get_conversation(retriever):
    # memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    llm = HuggingFacePipeline(pipeline=st.session_state.llm)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever
    )
    return conversation_chain
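

# Build a simple instruction-style prompt around the user question; retrieval
# context is not injected here because the conversation chain is currently
# disabled in main().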
def getprompt(user_input):
    prompt = f"You are a helpful assistant. Please answer the user question. USER: {user_input} ASSISTANT:"
    return prompt
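

# Record the question, query the text-generation pipeline directly, and
# replay the full chat history with Streamlit's chat widgets.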
def handle_user_input(question):
    st.session_state.chat_history.append({"role": "user", "content": question})
    prompt = getprompt(question)
    # The pipeline returns a list of dicts; keep only the newly generated
    # text rather than the echoed prompt.
    result = st.session_state.llm(prompt, return_full_text=False)
    response = result[0]["generated_text"]
    st.session_state.chat_history.append({"role": "assistant", "content": response})
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])
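

# Page layout: chat input in the main pane, model selection and code upload
# in the sidebar. The pipeline is initialised when the upload button is
# pressed.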
def main():
    st.set_page_config(page_title="Coding Assistant", page_icon=":books:")
    init()
    st.header(":books: Coding Assistant")

    user_input = st.chat_input("Ask your question here")
    if user_input:
        if "llm" not in st.session_state:
            st.warning("Please load a model first via the sidebar.")
        else:
            with st.spinner("Processing request ..."):
                handle_user_input(user_input)

    with st.sidebar:
        st.subheader("Model selector")
        model_id = st.text_input("Model name on HuggingFace", default_model_id)

        st.subheader("Code Upload")
        upload_docs = st.file_uploader("Upload documents here", accept_multiple_files=True)
        if st.button("Upload"):
            with st.spinner("Analyzing documents ..."):
                init_llm_pipeline(model_id)
                # retriever = get_retriever(upload_docs)
                # st.session_state.conversation = get_conversation(retriever)


if __name__ == "__main__":
    main()