Ravi theja K committed
Commit 3f8d992
1 Parent(s): 7643658

Create simple_app.py

Files changed (1)
  1. simple_app.py +95 -0
simple_app.py ADDED
@@ -0,0 +1,95 @@
+ import streamlit as st
+ import replicate
+ import os
+ from transformers import AutoTokenizer
+
+ # Set assistant icon to Snowflake logo
+ icons = {"assistant": "./Snowflake_Logomark_blue.svg", "user": "⛷️"}
+
+ # App title
+ st.set_page_config(page_title="Snowflake Arctic")
+
+ # Replicate Credentials
+ with st.sidebar:
+     st.title('Snowflake Arctic')
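+     # Look for the token in Streamlit secrets first; fall back to asking for it in the sidebar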
+     if 'REPLICATE_API_TOKEN' in st.secrets:
+         replicate_api = st.secrets['REPLICATE_API_TOKEN']
+     else:
+         replicate_api = st.text_input('Enter Replicate API token:', type='password')
+         if not (replicate_api.startswith('r8_') and len(replicate_api)==40):
+             st.warning('Please enter your Replicate API token.', icon='⚠️')
+             st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")
+
+     os.environ['REPLICATE_API_TOKEN'] = replicate_api
+     st.subheader("Adjust model parameters")
+     temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.3, step=0.01)
+     top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
+
+ # Store LLM-generated responses
+ if "messages" not in st.session_state.keys():
+     st.session_state.messages = [{"role": "assistant", "content": "Hi. I'm Arctic, a new, efficient, intelligent, and truly open language model created by Snowflake AI Research. Ask me anything."}]
+
+ # Display or clear chat messages
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"], avatar=icons[message["role"]]):
+         st.write(message["content"])
+
+ def clear_chat_history():
+     st.session_state.messages = [{"role": "assistant", "content": "Hi. I'm Arctic, a new, efficient, intelligent, and truly open language model created by Snowflake AI Research. Ask me anything."}]
+
+ st.sidebar.button('Clear chat history', on_click=clear_chat_history)
+ st.sidebar.caption('Built by [Snowflake](https://snowflake.com/) to demonstrate [Snowflake Arctic](https://www.snowflake.com/blog/arctic-open-and-efficient-foundation-language-models-snowflake). App hosted on [Streamlit Community Cloud](https://streamlit.io/cloud). Model hosted by [Replicate](https://replicate.com/snowflake/snowflake-arctic-instruct).')
+ st.sidebar.caption('Build your own app powered by Arctic and [enter to win](https://arctic-streamlit-hackathon.devpost.com/) $10k in prizes.')
+
+ @st.cache_resource(show_spinner=False)
+ def get_tokenizer():
+     """Get a tokenizer to make sure we're not sending too much text
+     to the model. Eventually we will replace this with ArcticTokenizer.
+     """
+     return AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+
+ def get_num_tokens(prompt):
+     """Get the number of tokens in a given prompt"""
+     tokenizer = get_tokenizer()
+     tokens = tokenizer.tokenize(prompt)
+     return len(tokens)
+
+ # Function for generating Snowflake Arctic response
+ def generate_arctic_response():
+     prompt = []
+     for dict_message in st.session_state.messages:
+         if dict_message["role"] == "user":
+             prompt.append("<|im_start|>user\n" + dict_message["content"] + "<|im_end|>")
+         else:
+             prompt.append("<|im_start|>assistant\n" + dict_message["content"] + "<|im_end|>")
+
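+     # Leave an open assistant turn at the end so the model writes the next reply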
+     prompt.append("<|im_start|>assistant")
+     prompt.append("")
+     prompt_str = "\n".join(prompt)
+
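+     # Rough context-length guard using the stand-in tokenizer above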
+     if get_num_tokens(prompt_str) >= 3072:
+         st.error("Conversation length too long. Please keep it under 3072 tokens.")
+         st.button('Clear chat history', on_click=clear_chat_history, key="clear_chat_history")
+         st.stop()
+
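+     # Stream the completion from Replicate; prompt_template is the raw prompt so the ChatML tags pass through unchanged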
+     for event in replicate.stream("snowflake/snowflake-arctic-instruct",
+                                   input={"prompt": prompt_str,
+                                          "prompt_template": r"{prompt}",
+                                          "temperature": temperature,
+                                          "top_p": top_p,
+                                          }):
+         yield str(event)
+
+ # User-provided prompt
+ if prompt := st.chat_input(disabled=not replicate_api):
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user", avatar="⛷️"):
+         st.write(prompt)
+
+ # Generate a new response if last message is not from assistant
+ if st.session_state.messages[-1]["role"] != "assistant":
+     with st.chat_message("assistant", avatar="./Snowflake_Logomark_blue.svg"):
+         response = generate_arctic_response()
+         full_response = st.write_stream(response)
+     message = {"role": "assistant", "content": full_response}
+     st.session_state.messages.append(message)