ag-mach committed on
Commit d8369f5
1 Parent(s): abe1d26

initial commit

Files changed (3)
  1. .gitignore +7 -0
  2. app.py +116 -0
  3. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ *$py.class
+
+ *.so
+ .env
+ .~env
+ venv/
app.py ADDED
@@ -0,0 +1,116 @@
+ # UI comes here
+ import streamlit as st
+
+ import torch
+ from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     pipeline,
+ )
+ from langchain_community.llms import HuggingFacePipeline
+
+
+ # Note: these two constants are never used anywhere in this app.
+ gpt_model = 'gpt-4-1106-preview'
+ embedding_model = 'text-embedding-3-small'
+
+ def init():
+     if "conversation" not in st.session_state:
+         st.session_state.conversation = None
+     if "chat_history" not in st.session_state:
+         st.session_state.chat_history = None
+
+ def init_llm_pipeline():
+     # Load StarCoder2 once per session, 4-bit quantized via bitsandbytes.
+     if "llm" not in st.session_state:
+         model_id = "bigcode/starcoder2-15b"
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16
+         )
+
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             quantization_config=quantization_config,
+             device_map="auto",
+         )
+         tokenizer.add_eos_token = True
+         tokenizer.pad_token_id = 0
+         tokenizer.padding_side = "left"
+
+         text_generation_pipeline = pipeline(
+             model=model,
+             tokenizer=tokenizer,
+             task="text-generation",
+             do_sample=True,  # required for temperature to take effect
+             temperature=0.7,
+             repetition_penalty=1.1,
+             return_full_text=True,
+             max_new_tokens=300,
+         )
+         st.session_state.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
+
+ def get_text(docs):
+     # st.file_uploader with accept_multiple_files=True returns a list,
+     # so decode and join every uploaded file.
+     return "\n\n".join(doc.getvalue().decode("utf-8") for doc in docs)
+
+ def get_vectorstore(text):
+     python_splitter = RecursiveCharacterTextSplitter.from_language(
+         language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
+     )
+     # create_documents() takes raw strings; split_documents() would expect Document objects.
+     texts = python_splitter.create_documents([text])
+
+     embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+     db = FAISS.from_documents(texts, embeddings)
+     retriever = db.as_retriever(
+         search_type="mmr",  # Also test "similarity"
+         search_kwargs={"k": 8},
+     )
+     return retriever
+
+ def get_conversation(retriever):
+     memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
+     conversation_chain = ConversationalRetrievalChain.from_llm(
+         llm=st.session_state.llm,
+         retriever=retriever,
+         memory=memory,
+     )
+     return conversation_chain
+
+ def handle_user_input(question):
+     # Guard against questions asked before any code has been uploaded.
+     if st.session_state.conversation is None:
+         st.warning("Please upload and analyze your code first.")
+         return
+     response = st.session_state.conversation({'question': question})
+     st.session_state.chat_history = response['chat_history']
+     # Memory alternates user/assistant messages, so even indices are the user's.
+     for i, message in enumerate(st.session_state.chat_history):
+         if i % 2 == 0:
+             with st.chat_message("user"):
+                 st.write(message.content)
+         else:
+             with st.chat_message("assistant"):
+                 st.write(message.content)
+
+ def main():
+     # load_dotenv()
+     st.set_page_config(page_title="Coding Assistant", page_icon=":books:")
+     init()
+
+     st.header(":books: Coding Assistant")
+     user_input = st.chat_input("Ask your question here")
+     if user_input:
+         with st.spinner("Running query ..."):
+             handle_user_input(user_input)
+
+     with st.sidebar:
+         st.subheader("Code Upload")
+         upload_docs = st.file_uploader("Upload documents here", accept_multiple_files=True)
+         if st.button("Upload"):
+             with st.spinner("Analyzing documents ..."):
+                 init_llm_pipeline()
+                 raw_text = get_text(upload_docs)
+                 retriever = get_vectorstore(raw_text)
+                 st.session_state.conversation = get_conversation(retriever)
+
+
+ if __name__ == "__main__":
+     main()
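
The retriever returned by `get_vectorstore` is hard-coded to MMR search, with an in-code note to also test plain similarity. Below is a minimal standalone sketch, not part of the commit, for comparing the two modes on a toy index; the snippets and query are invented for illustration, and the embedding model matches the one used in app.py:

```python
# Hypothetical comparison of the two retrieval modes considered in app.py.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

snippets = [  # invented toy corpus
    "def add(a, b):\n    return a + b",
    "def subtract(a, b):\n    return a - b",
    "def multiply(a, b):\n    return a * b",
]
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
db = FAISS.from_texts(snippets, embeddings)

for search_type in ("mmr", "similarity"):
    retriever = db.as_retriever(search_type=search_type, search_kwargs={"k": 2})
    docs = retriever.invoke("how do I add two numbers?")
    print(search_type, [d.page_content.splitlines()[0] for d in docs])
```

MMR trades a little raw relevance for diversity, which tends to help when a codebase yields many near-identical chunks.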
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ streamlit
+ langchain
+ langchain-community
+ python-dotenv
+ faiss-cpu
+ huggingface-hub
+ accelerate
+ bitsandbytes
+ torch
+ langchain-text-splitters
+ sentence_transformers
+ git+https://github.com/huggingface/transformers.git
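
To run the app locally, `pip install -r requirements.txt` and then `streamlit run app.py`. Note that the 4-bit load of `bigcode/starcoder2-15b` via bitsandbytes assumes a CUDA-capable GPU (roughly 10 GB of VRAM for the quantized weights alone), and that transformers is installed from source, presumably because StarCoder2 support had not yet landed in a tagged release at commit time.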