awacke1 commited on
Commit
21993d9
·
verified ·
1 Parent(s): 46be881

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -16
app.py CHANGED
@@ -1,23 +1,59 @@
1
- # Load model directly
2
  import streamlit as st
3
- import transformers
4
 
5
- st.title("A Simple Interface for a Language Model")
 
6
 
7
- st.subheader("Input Text")
8
- input_text = st.text_area("Enter your text here", "Type something here...")
 
 
 
 
 
 
9
 
10
- if st.button("Generate Response"):
11
- # Initialize tokenizer and model
12
- tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/phi-2")
13
- model = transformers.AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
14
 
15
- # Encode input text
16
- inputs = tokenizer(input_text, return_tensors="pt")
 
 
 
 
 
 
17
 
18
- # Generate response
19
- response = model.generate(**inputs, max_length=100, do_sample=True)
20
 
21
- # Decode response
22
- st.subheader("Generated Response")
23
- st.write(tokenizer.decode(response[0]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from huggingface_hub import InferenceClient
3
 
4
+ # Initialize the client
5
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
6
 
7
+ # Function to format the prompt
8
+ def format_prompt(message, history):
9
+ prompt = "<s>"
10
+ for user_prompt, bot_response in history:
11
+ prompt += f"[INST] {user_prompt} [/INST]"
12
+ prompt += f" {bot_response}</s> "
13
+ prompt += f"[INST] {message} [/INST]"
14
+ return prompt
15
 
16
+ # Function to generate response
17
+ def generate(prompt, history, temperature=0.2, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
18
+ temperature = max(float(temperature), 1e-2)
19
+ top_p = float(top_p)
20
 
21
+ generate_kwargs = dict(
22
+ temperature=temperature,
23
+ max_new_tokens=max_new_tokens,
24
+ top_p=top_p,
25
+ repetition_penalty=repetition_penalty,
26
+ do_sample=True,
27
+ seed=42,
28
+ )
29
 
30
+ formatted_prompt = format_prompt(prompt, history)
 
31
 
32
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
33
+ output = ""
34
+ for response in stream:
35
+ output += response.token.text
36
+ return output
37
+
38
+ # Streamlit interface
39
+ st.title("Mistral 8x7b Chat")
40
+
41
+ # Chat history
42
+ if 'history' not in st.session_state:
43
+ st.session_state.history = []
44
+
45
+ # User input
46
+ user_input = st.text_input("Your message:", key="user_input")
47
+
48
+ # Generate response and update history
49
+ if st.button("Send"):
50
+ if user_input:
51
+ bot_response = generate(user_input, st.session_state.history)
52
+ st.session_state.history.append((user_input, bot_response))
53
+ st.session_state.user_input = ""
54
+
55
+ # Display conversation
56
+ chat_text = ""
57
+ for user_msg, bot_msg in st.session_state.history:
58
+ chat_text += f"You: {user_msg}\nBot: {bot_msg}\n\n"
59
+ st.text_area("Chat", value=chat_text, height=300, disabled=True)