Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Commit 
							
							·
						
						6c6516f
	
1
								Parent(s):
							
							99ff44d
								
improved dependancies
Browse files- app.py +27 -24
- mistral7b.py +1 -1
- requirements.txt +42 -57
    	
        app.py
    CHANGED
    
    | @@ -5,51 +5,54 @@ from mistral7b import mistral | |
| 5 | 
             
            import time
         | 
| 6 |  | 
| 7 |  | 
| 8 | 
            -
             | 
| 9 | 
             
            if "messages" not in st.session_state:
         | 
| 10 | 
             
                st.session_state.messages = []
         | 
| 11 |  | 
| 12 | 
            -
            if "tokens_used" | 
| 13 | 
             
                st.session_state.tokens_used = 0
         | 
| 14 |  | 
| 15 | 
            -
            if " | 
| 16 | 
             
                st.session_state.inference_time = [0.00]
         | 
| 17 |  | 
| 18 |  | 
| 19 | 
            -
            if "temp" not in st.session_state | 
| 20 | 
             
                st.session_state.temp = 0.8
         | 
| 21 |  | 
| 22 | 
            -
            if "model_settings" not in st.session_state | 
| 23 | 
             
                st.session_state.model_settings = {
         | 
| 24 | 
            -
                    "temp" | 
| 25 | 
            -
                    "max_tokens" | 
| 26 | 
             
                }
         | 
| 27 |  | 
| 28 | 
            -
            if "history" not in st.session_state | 
| 29 | 
             
                st.session_state.history = []
         | 
| 30 |  | 
| 31 | 
            -
            if "top_k" not in st.session_state | 
| 32 | 
             
                st.session_state.top_k = 5
         | 
| 33 |  | 
| 34 | 
             
            with st.sidebar:
         | 
| 35 | 
             
                st.markdown("# Model Analytics")
         | 
| 36 | 
             
                st.write("Tokens used :", st.session_state['tokens_used'])
         | 
| 37 |  | 
| 38 | 
            -
                st.write("Average Inference Time: ", round(sum( | 
| 39 | 
            -
             | 
|  | |
|  | |
| 40 |  | 
| 41 | 
             
                st.markdown("---")
         | 
| 42 | 
             
                st.markdown("# Retrieval Settings")
         | 
| 43 | 
            -
                st.slider(label="Documents to retrieve", | 
|  | |
| 44 | 
             
                st.markdown("---")
         | 
| 45 | 
             
                st.markdown("# Model Settings")
         | 
| 46 | 
            -
                selected_model = st.sidebar.radio( | 
| 47 | 
            -
             | 
|  | |
|  | |
| 48 | 
             
                st.write(" ")
         | 
| 49 | 
             
                st.info("**2023 ©️ Pragnesh Barik**")
         | 
| 50 |  | 
| 51 |  | 
| 52 | 
            -
             | 
| 53 | 
             
            st.image("ikigai.svg")
         | 
| 54 | 
             
            st.title("Ikigai Chat")
         | 
| 55 |  | 
| @@ -67,21 +70,21 @@ for message in st.session_state.messages: | |
| 67 | 
             
            if prompt := st.chat_input("Chat with Ikigai Docs?"):
         | 
| 68 | 
             
                st.chat_message("user").markdown(prompt)
         | 
| 69 | 
             
                st.session_state.messages.append({"role": "user", "content": prompt})
         | 
| 70 | 
            -
             | 
| 71 | 
             
                tick = time.time()
         | 
| 72 | 
            -
                response = mistral(prompt, st.session_state.history, | 
|  | |
| 73 | 
             
                tock = time.time()
         | 
| 74 |  | 
| 75 | 
            -
                
         | 
| 76 | 
             
                st.session_state.inference_time.append(tock - tick)
         | 
| 77 | 
             
                response = response.replace("</s>", "")
         | 
| 78 | 
             
                len_response = len(response.split())
         | 
| 79 |  | 
| 80 | 
            -
                st.session_state["tokens_used"] = | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
|  | |
| 83 | 
             
                    st.markdown(response)
         | 
| 84 | 
             
                st.session_state.history.append([prompt, response])
         | 
| 85 | 
            -
                st.session_state.messages.append( | 
| 86 | 
            -
             | 
| 87 | 
            -
             | 
|  | |
| 5 | 
             
            import time
         | 
| 6 |  | 
| 7 |  | 
|  | |
| 8 | 
             
            if "messages" not in st.session_state:
         | 
| 9 | 
             
                st.session_state.messages = []
         | 
| 10 |  | 
| 11 | 
            +
            if "tokens_used" not in st.session_state:
         | 
| 12 | 
             
                st.session_state.tokens_used = 0
         | 
| 13 |  | 
| 14 | 
            +
            if "inference_tipipme" not in st.session_state:
         | 
| 15 | 
             
                st.session_state.inference_time = [0.00]
         | 
| 16 |  | 
| 17 |  | 
| 18 | 
            +
            if "temp" not in st.session_state:
         | 
| 19 | 
             
                st.session_state.temp = 0.8
         | 
| 20 |  | 
| 21 | 
            +
            if "model_settings" not in st.session_state:
         | 
| 22 | 
             
                st.session_state.model_settings = {
         | 
| 23 | 
            +
                    "temp": 0.9,
         | 
| 24 | 
            +
                    "max_tokens": 512,
         | 
| 25 | 
             
                }
         | 
| 26 |  | 
| 27 | 
            +
            if "history" not in st.session_state:
         | 
| 28 | 
             
                st.session_state.history = []
         | 
| 29 |  | 
| 30 | 
            +
            if "top_k" not in st.session_state:
         | 
| 31 | 
             
                st.session_state.top_k = 5
         | 
| 32 |  | 
| 33 | 
             
            with st.sidebar:
         | 
| 34 | 
             
                st.markdown("# Model Analytics")
         | 
| 35 | 
             
                st.write("Tokens used :", st.session_state['tokens_used'])
         | 
| 36 |  | 
| 37 | 
            +
                st.write("Average Inference Time: ", round(sum(
         | 
| 38 | 
            +
                    st.session_state["inference_time"]) / len(st.session_state["inference_time"]), 3))
         | 
| 39 | 
            +
                st.write("Cost Incured :", round(
         | 
| 40 | 
            +
                    0.033 * st.session_state['tokens_used'] / 1000, 3), "INR")
         | 
| 41 |  | 
| 42 | 
             
                st.markdown("---")
         | 
| 43 | 
             
                st.markdown("# Retrieval Settings")
         | 
| 44 | 
            +
                st.slider(label="Documents to retrieve",
         | 
| 45 | 
            +
                          min_value=1, max_value=10, value=3)
         | 
| 46 | 
             
                st.markdown("---")
         | 
| 47 | 
             
                st.markdown("# Model Settings")
         | 
| 48 | 
            +
                selected_model = st.sidebar.radio(
         | 
| 49 | 
            +
                    'Select one:', ["Mistral 7B", "GPT 3.5 Turbo", "GPT 4",  "Llama 7B"])
         | 
| 50 | 
            +
                selected_temperature = st.slider(
         | 
| 51 | 
            +
                    label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.5)
         | 
| 52 | 
             
                st.write(" ")
         | 
| 53 | 
             
                st.info("**2023 ©️ Pragnesh Barik**")
         | 
| 54 |  | 
| 55 |  | 
|  | |
| 56 | 
             
            st.image("ikigai.svg")
         | 
| 57 | 
             
            st.title("Ikigai Chat")
         | 
| 58 |  | 
|  | |
| 70 | 
             
            if prompt := st.chat_input("Chat with Ikigai Docs?"):
         | 
| 71 | 
             
                st.chat_message("user").markdown(prompt)
         | 
| 72 | 
             
                st.session_state.messages.append({"role": "user", "content": prompt})
         | 
| 73 | 
            +
             | 
| 74 | 
             
                tick = time.time()
         | 
| 75 | 
            +
                response = mistral(prompt, st.session_state.history,
         | 
| 76 | 
            +
                                   temperature=st.session_state.model_settings["temp"], max_new_tokens=st.session_state.model_settings["max_tokens"])
         | 
| 77 | 
             
                tock = time.time()
         | 
| 78 |  | 
|  | |
| 79 | 
             
                st.session_state.inference_time.append(tock - tick)
         | 
| 80 | 
             
                response = response.replace("</s>", "")
         | 
| 81 | 
             
                len_response = len(response.split())
         | 
| 82 |  | 
| 83 | 
            +
                st.session_state["tokens_used"] = len_response + \
         | 
| 84 | 
            +
                    st.session_state["tokens_used"]
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                with st.chat_message("assistant"):
         | 
| 87 | 
             
                    st.markdown(response)
         | 
| 88 | 
             
                st.session_state.history.append([prompt, response])
         | 
| 89 | 
            +
                st.session_state.messages.append(
         | 
| 90 | 
            +
                    {"role": "assistant", "content": response})
         | 
|  | 
    	
        mistral7b.py
    CHANGED
    
    | @@ -42,6 +42,6 @@ def mistral( | |
| 42 |  | 
| 43 | 
             
                for response in stream:
         | 
| 44 | 
             
                    # print(response)
         | 
| 45 | 
            -
                    output += response.token | 
| 46 | 
             
                    # yield output
         | 
| 47 | 
             
                return output
         | 
|  | |
| 42 |  | 
| 43 | 
             
                for response in stream:
         | 
| 44 | 
             
                    # print(response)
         | 
| 45 | 
            +
                    output += response.token.text
         | 
| 46 | 
             
                    # yield output
         | 
| 47 | 
             
                return output
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,91 +1,76 @@ | |
| 1 | 
             
            altair==5.1.2
         | 
|  | |
| 2 | 
             
            attrs==23.1.0
         | 
| 3 | 
            -
             | 
| 4 | 
            -
            bitarray==2.8.1
         | 
| 5 | 
             
            blinker==1.6.3
         | 
| 6 | 
             
            cachetools==5.3.1
         | 
| 7 | 
            -
            huggingface-hub==0.16.4
         | 
| 8 | 
             
            certifi==2023.7.22
         | 
| 9 | 
            -
            charset-normalizer==3. | 
| 10 | 
             
            click==8.1.7
         | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
            future==0.18.3
         | 
| 19 | 
             
            gitdb==4.0.10
         | 
| 20 | 
             
            GitPython==3.1.37
         | 
| 21 | 
            -
             | 
| 22 | 
            -
            greenlet==2.0.2
         | 
| 23 | 
            -
            huggingface-hub==0.16.4
         | 
| 24 | 
            -
            humanfriendly==10.0
         | 
| 25 | 
             
            idna==3.4
         | 
| 26 | 
             
            importlib-metadata==6.8.0
         | 
| 27 | 
            -
             | 
|  | |
|  | |
| 28 | 
             
            Jinja2==3.1.2
         | 
| 29 | 
            -
            joblib==1.3.2
         | 
| 30 | 
             
            jsonschema==4.19.1
         | 
| 31 | 
             
            jsonschema-specifications==2023.7.1
         | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
             
            markdown-it-py==3.0.0
         | 
| 35 | 
             
            MarkupSafe==2.1.3
         | 
|  | |
| 36 | 
             
            mdurl==0.1.2
         | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
            networkx==3.1
         | 
| 41 | 
            -
            nh3==0.2.14
         | 
| 42 | 
            -
            nltk==3.8.1
         | 
| 43 | 
            -
            numba==0.57.1
         | 
| 44 | 
            -
            numpy==1.24.4
         | 
| 45 | 
            -
            onnxruntime==1.15.1
         | 
| 46 | 
            -
            openai-whisper==20230314
         | 
| 47 | 
             
            pandas==2.1.1
         | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 51 | 
             
            pyarrow==13.0.0
         | 
| 52 | 
             
            pydeck==0.8.1b0
         | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
             
            python-dotenv==1.0.0
         | 
| 56 | 
             
            pytz==2023.3.post1
         | 
| 57 | 
            -
             | 
|  | |
|  | |
| 58 | 
             
            referencing==0.30.2
         | 
| 59 | 
            -
            regex==2023.8.8
         | 
| 60 | 
             
            requests==2.31.0
         | 
| 61 | 
            -
             | 
| 62 | 
            -
            rfc3986==2.0.0
         | 
| 63 | 
            -
            rich==13.5.2
         | 
| 64 | 
             
            rpds-py==0.10.4
         | 
| 65 | 
            -
             | 
| 66 | 
            -
            scikit-learn==1.3.0
         | 
| 67 | 
            -
            scipy==1.11.2
         | 
| 68 | 
            -
            sentence-transformers==2.2.2
         | 
| 69 | 
            -
            sentencepiece==0.1.99
         | 
| 70 | 
             
            smmap==5.0.1
         | 
| 71 | 
            -
             | 
| 72 | 
             
            streamlit==1.27.2
         | 
| 73 | 
            -
            sympy==1.12
         | 
| 74 | 
             
            tenacity==8.2.3
         | 
| 75 | 
            -
            threadpoolctl==3.2.0
         | 
| 76 | 
            -
            tiktoken==0.3.1
         | 
| 77 | 
            -
            tokenizers==0.13.3
         | 
| 78 | 
             
            toml==0.10.2
         | 
| 79 | 
             
            toolz==0.12.0
         | 
| 80 | 
            -
             | 
| 81 | 
            -
            torchvision==0.15.2
         | 
| 82 | 
             
            tqdm==4.66.1
         | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
| 85 | 
            -
            typing_extensions==4.7.1
         | 
| 86 | 
             
            tzdata==2023.3
         | 
| 87 | 
             
            tzlocal==5.1
         | 
| 88 | 
            -
            urllib3==2.0. | 
| 89 | 
             
            validators==0.22.0
         | 
| 90 | 
             
            watchdog==3.0.0
         | 
| 91 | 
            -
             | 
|  | 
|  | |
| 1 | 
             
            altair==5.1.2
         | 
| 2 | 
            +
            asttokens==2.2.1
         | 
| 3 | 
             
            attrs==23.1.0
         | 
| 4 | 
            +
            backcall==0.2.0
         | 
|  | |
| 5 | 
             
            blinker==1.6.3
         | 
| 6 | 
             
            cachetools==5.3.1
         | 
|  | |
| 7 | 
             
            certifi==2023.7.22
         | 
| 8 | 
            +
            charset-normalizer==3.3.0
         | 
| 9 | 
             
            click==8.1.7
         | 
| 10 | 
            +
            colorama==0.4.6
         | 
| 11 | 
            +
            comm==0.1.3
         | 
| 12 | 
            +
            debugpy==1.6.7
         | 
| 13 | 
            +
            decorator==5.1.1
         | 
| 14 | 
            +
            executing==1.2.0
         | 
| 15 | 
            +
            filelock==3.12.4
         | 
| 16 | 
            +
            fsspec==2023.9.2
         | 
|  | |
| 17 | 
             
            gitdb==4.0.10
         | 
| 18 | 
             
            GitPython==3.1.37
         | 
| 19 | 
            +
            huggingface-hub==0.18.0
         | 
|  | |
|  | |
|  | |
| 20 | 
             
            idna==3.4
         | 
| 21 | 
             
            importlib-metadata==6.8.0
         | 
| 22 | 
            +
            ipykernel==6.23.3
         | 
| 23 | 
            +
            ipython==8.14.0
         | 
| 24 | 
            +
            jedi==0.18.2
         | 
| 25 | 
             
            Jinja2==3.1.2
         | 
|  | |
| 26 | 
             
            jsonschema==4.19.1
         | 
| 27 | 
             
            jsonschema-specifications==2023.7.1
         | 
| 28 | 
            +
            jupyter_client==8.3.0
         | 
| 29 | 
            +
            jupyter_core==5.3.1
         | 
| 30 | 
             
            markdown-it-py==3.0.0
         | 
| 31 | 
             
            MarkupSafe==2.1.3
         | 
| 32 | 
            +
            matplotlib-inline==0.1.6
         | 
| 33 | 
             
            mdurl==0.1.2
         | 
| 34 | 
            +
            nest-asyncio==1.5.6
         | 
| 35 | 
            +
            numpy==1.26.0
         | 
| 36 | 
            +
            packaging==23.1
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 37 | 
             
            pandas==2.1.1
         | 
| 38 | 
            +
            parso==0.8.3
         | 
| 39 | 
            +
            pickleshare==0.7.5
         | 
| 40 | 
            +
            Pillow==10.0.1
         | 
| 41 | 
            +
            platformdirs==3.8.0
         | 
| 42 | 
            +
            prompt-toolkit==3.0.38
         | 
| 43 | 
            +
            protobuf==4.24.4
         | 
| 44 | 
            +
            psutil==5.9.5
         | 
| 45 | 
            +
            pure-eval==0.2.2
         | 
| 46 | 
             
            pyarrow==13.0.0
         | 
| 47 | 
             
            pydeck==0.8.1b0
         | 
| 48 | 
            +
            Pygments==2.15.1
         | 
| 49 | 
            +
            python-dateutil==2.8.2
         | 
| 50 | 
             
            python-dotenv==1.0.0
         | 
| 51 | 
             
            pytz==2023.3.post1
         | 
| 52 | 
            +
            pywin32==306
         | 
| 53 | 
            +
            PyYAML==6.0.1
         | 
| 54 | 
            +
            pyzmq==25.1.0
         | 
| 55 | 
             
            referencing==0.30.2
         | 
|  | |
| 56 | 
             
            requests==2.31.0
         | 
| 57 | 
            +
            rich==13.6.0
         | 
|  | |
|  | |
| 58 | 
             
            rpds-py==0.10.4
         | 
| 59 | 
            +
            six==1.16.0
         | 
|  | |
|  | |
|  | |
|  | |
| 60 | 
             
            smmap==5.0.1
         | 
| 61 | 
            +
            stack-data==0.6.2
         | 
| 62 | 
             
            streamlit==1.27.2
         | 
|  | |
| 63 | 
             
            tenacity==8.2.3
         | 
|  | |
|  | |
|  | |
| 64 | 
             
            toml==0.10.2
         | 
| 65 | 
             
            toolz==0.12.0
         | 
| 66 | 
            +
            tornado==6.3.2
         | 
|  | |
| 67 | 
             
            tqdm==4.66.1
         | 
| 68 | 
            +
            traitlets==5.9.0
         | 
| 69 | 
            +
            typing_extensions==4.8.0
         | 
|  | |
| 70 | 
             
            tzdata==2023.3
         | 
| 71 | 
             
            tzlocal==5.1
         | 
| 72 | 
            +
            urllib3==2.0.6
         | 
| 73 | 
             
            validators==0.22.0
         | 
| 74 | 
             
            watchdog==3.0.0
         | 
| 75 | 
            +
            wcwidth==0.2.6
         | 
| 76 | 
            +
            zipp==3.17.0
         |