File size: 3,898 Bytes
d8369f5
15027a6
262df6f
 
 
1caa52b
d8369f5
 
2de4fcc
d8369f5
 
 
15027a6
 
 
9cc5c95
d8369f5
 
 
 
 
 
 
bbdafb0
15027a6
 
d8369f5
 
15027a6
 
 
d8369f5
 
 
 
 
 
 
 
15027a6
d8369f5
9cc5c95
d8369f5
15027a6
 
d8369f5
 
 
15027a6
 
d8369f5
 
 
 
 
 
 
 
 
 
 
9cc5c95
 
d8369f5
 
9cc5c95
d8369f5
 
 
9cc5c95
 
 
 
d8369f5
9cc5c95
 
 
d8369f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbdafb0
 
d8369f5
 
 
 
bbdafb0
15027a6
 
d8369f5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import streamlit as st
import torch
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline



# Presumably leftovers from an OpenAI-based setup; not used by the code below.
# gpt_model = 'gpt-4-1106-preview'
# embedding_model = 'text-embedding-3-small'
# Hugging Face Hub model id pre-filled in the sidebar's model selector.
default_model_id: str = "bigcode/starcoder2-3b"
# Alternative model id (instruction-tuned):
#default_model_id = "tiiuae/falcon-7b-instruct"

def init():
    """Ensure the session-state keys used by the app exist.

    ``conversation`` starts as ``None`` until documents have been uploaded.
    ``chat_history`` starts as an empty list that collects
    ``{"role": ..., "content": ...}`` message dicts.
    """
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    # A list (not None) is required because messages are appended to it;
    # this also repairs a None value stored earlier in a live session.
    if st.session_state.get("chat_history") is None:
        st.session_state.chat_history = []

def init_llm_pipeline(model_id):
    """Load *model_id* into a text-generation pipeline cached in session state.

    The pipeline is stored as ``st.session_state.llm``, and the id it was
    built from as ``st.session_state.llm_model_id``. Because Streamlit
    re-executes the script on each interaction, the cache avoids reloading
    the weights every time. The model is (re)loaded only when nothing is
    cached yet or when *model_id* differs from the cached one, for example
    after the user edits the model name in the sidebar.

    Args:
        model_id: Hugging Face Hub model id, e.g. ``"bigcode/starcoder2-3b"``.
    """
    if "llm" in st.session_state and st.session_state.get("llm_model_id") == model_id:
        return

    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # add_eos_token is deliberately NOT enabled: tokenizers that honor it
    # append EOS to every prompt, which tells the model the text is already
    # finished before generation starts.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        # Only fill in a pad token when the tokenizer lacks one. Prefer EOS,
        # and fall back to id 0 (the value this code used to force).
        eos_id = tokenizer.eos_token_id
        tokenizer.pad_token_id = eos_id if eos_id is not None else 0

    st.session_state.llm = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
    )
    st.session_state.llm_model_id = model_id

def get_retriever(files):
    """Index uploaded source files in FAISS and return an MMR retriever.

    Each file is decoded as UTF-8 and split with the Python-aware splitter
    (2000-character chunks, 200 overlap). Every chunk carries a ``source``
    metadata entry holding the file's name, so retrieved context can be
    traced back to its origin.

    Args:
        files: Uploaded file objects (e.g. from ``st.file_uploader``). Each
            must provide ``getvalue()`` returning UTF-8 encoded bytes.

    Returns:
        A retriever that returns up to 8 chunks per query, selected with
        maximal marginal relevance ("mmr").

    Raises:
        ValueError: if no chunks could be produced (no files given, or all
            files empty); FAISS cannot build an index from zero documents.
        UnicodeDecodeError: if a file is not valid UTF-8.
    """
    files = list(files or [])
    documents = [f.getvalue().decode("utf-8") for f in files]
    metadatas = [{"source": getattr(f, "name", "")} for f in files]

    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
    )
    texts = python_splitter.create_documents(documents, metadatas=metadatas)
    if not texts:
        raise ValueError("Nothing to index: no files given or all files are empty.")

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = FAISS.from_documents(texts, embeddings)
    return db.as_retriever(
        search_type="mmr",  # "similarity" is the plain nearest-neighbour alternative
        search_kwargs={"k": 8},
    )
    
def get_conversation(retriever):
    """Build a conversational retrieval chain over *retriever*.

    ``st.session_state.llm`` holds a bare ``transformers`` text-generation
    pipeline. ``ConversationalRetrievalChain`` requires a LangChain language
    model, so the pipeline is wrapped in ``HuggingFacePipeline`` first. A
    value that is already a ``HuggingFacePipeline`` is used as-is.

    No memory is attached, so callers must supply ``chat_history`` alongside
    ``question`` on every call.

    Args:
        retriever: Retriever supplying context documents (see ``get_retriever``).

    Returns:
        The configured ``ConversationalRetrievalChain``.
    """
    llm = st.session_state.llm
    if not isinstance(llm, HuggingFacePipeline):
        llm = HuggingFacePipeline(pipeline=llm)

    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
    )

def getprompt(user_input):
    """Return *user_input* embedded in the fixed USER/ASSISTANT prompt template."""
    template = (
        "You are a helpful assistant. Please answer the user question. "
        "USER: {} ASSISTANT:"
    )
    return template.format(user_input)

def handle_user_input(question):
    """Answer *question*, record the exchange in the chat history, and render it.

    If a retrieval chain has been built from uploaded files
    (``st.session_state.conversation``), that chain answers the question so
    the uploaded code is used as context. Otherwise the raw text-generation
    pipeline in ``st.session_state.llm`` is prompted via ``getprompt``. If
    neither exists yet, a warning is shown and nothing is recorded.

    Args:
        question: The user's chat message.
    """
    if st.session_state.get("chat_history") is None:
        st.session_state.chat_history = []
    history = st.session_state.chat_history

    conversation = st.session_state.get("conversation")
    llm = st.session_state.get("llm")
    if conversation is None and llm is None:
        st.warning("Bitte laden Sie zuerst Dokumente hoch, damit das Modell geladen wird.")
        return

    if conversation is not None:
        # ConversationalRetrievalChain expects earlier turns in its
        # "chat_history" input; pass them as (question, answer) pairs. The
        # history alternates user/assistant because both entries are always
        # appended together below.
        past_turns = [
            (user_msg["content"], assistant_msg["content"])
            for user_msg, assistant_msg in zip(history[::2], history[1::2])
        ]
        result = conversation.invoke({"question": question, "chat_history": past_turns})
        answer = result["answer"]
    else:
        # A transformers text-generation pipeline returns a list of dicts
        # with a "generated_text" key. return_full_text=False keeps the
        # prompt itself out of the answer.
        outputs = llm(getprompt(question), return_full_text=False)
        answer = outputs[0]["generated_text"]

    # Record both turns only after an answer was produced, so a failed call
    # cannot leave an unanswered user message that breaks the alternation.
    history.append({"role": "user", "content": question})
    history.append({"role": "assistant", "content": answer})

    for message in history:
        with st.chat_message(message["role"]):
            st.write(message["content"])

def main():
    """Render the coding-assistant page: chat area plus model/upload sidebar.

    Pressing "Hochladen" loads the selected model (if not already cached),
    indexes the uploaded files, and stores the resulting retrieval chain in
    ``st.session_state.conversation``.
    """
    # Page configuration goes first, ahead of any other Streamlit call.
    st.set_page_config(page_title="Coding-Assistent", page_icon=":books:")
    init()

    st.header(":books: Coding-Assistent ")
    user_input = st.chat_input("Stellen Sie Ihre Frage hier")
    if user_input:
        with st.spinner("Führe Anfrage aus ..."):
            handle_user_input(user_input)

    with st.sidebar:
        st.subheader("Model selector")
        model_id = st.text_input("Modelname on HuggingFace", default_model_id)
        st.subheader("Code Upload")
        upload_docs = st.file_uploader("Dokumente hier hochladen", accept_multiple_files=True)
        if st.button("Hochladen"):
            # Indexing zero files fails inside FAISS; tell the user instead.
            if not upload_docs:
                st.warning("Bitte wählen Sie zuerst mindestens eine Datei aus.")
            else:
                with st.spinner("Analysiere Dokumente ..."):
                    # Stray whitespace in the text field would break the Hub lookup.
                    init_llm_pipeline(model_id.strip())
                    retriever = get_retriever(upload_docs)
                    st.session_state.conversation = get_conversation(retriever)


# Entry point when this file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()