Annaamalai committed
Commit e275bf6
1 Parent(s): 5e28cac

Update app.py

Files changed (1)
  app.py +27 -17
app.py CHANGED
@@ -1,25 +1,35 @@
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # Load the model and tokenizer
-tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
-model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device="cuda" if st.sidebar.checkbox("Use GPU", True) else "cpu")
+model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, device="cuda" if st.sidebar.checkbox("Use GPU", True) else "cpu")
 
-# Function to generate text based on user input
-def generate_text(prompt):
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    outputs = model.generate(input_ids, max_length=100)
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return generated_text
+# Function to generate responses based on user messages
+def generate_response(messages):
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
+    outputs = model.generate(inputs, max_new_tokens=100)
+    generated_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_response
 
 # Streamlit app
-st.title("Text Generation with Google Gemma 7b")
+st.title("Mixtral Chatbot")
 
-prompt = st.text_area("Enter your prompt here:", "")
-if st.button("Generate Text"):
-    if prompt:
-        generated_text = generate_text(prompt)
-        st.write("Generated Text:")
-        st.write(generated_text)
+messages = []
+user_input = st.text_input("You:", "")
+
+if st.button("Send"):
+    if user_input:
+        messages.append({"role": "user", "content": user_input})
+        bot_response = generate_response(messages)
+        messages.append({"role": "assistant", "content": bot_response})
     else:
-        st.warning("Please enter a prompt.")
+        st.warning("Please enter a message.")
+
+# Display conversation
+for message in messages:
+    if message["role"] == "user":
+        st.text_input("You:", value=message["content"], disabled=True)
+    elif message["role"] == "assistant":
+        st.text_area("Mixtral:", value=message["content"], disabled=True)
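
Note on the committed code: `AutoModelForCausalLM.from_pretrained` has no `device=` keyword for placing the model on a GPU (placement is normally done with `device_map=` via accelerate, or a later `.to("cuda")`), and because Streamlit reruns the script on every interaction, `messages = []` is rebuilt each time, so the displayed conversation will not persist. A minimal sketch of one way to handle both points, assuming a recent Streamlit with `st.cache_resource`, transformers with accelerate installed, and enough memory for the Mixtral checkpoint; the `load_model` helper and the display formatting are illustrative, not part of the commit:

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

@st.cache_resource  # load the weights once per process instead of on every rerun
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # device_map="auto" places the model via accelerate; from_pretrained has no device= kwarg
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.float16, device_map="auto"
    )
    return tokenizer, model

tokenizer, model = load_model()

st.title("Mixtral Chatbot")

# Keep the conversation in session state so it survives Streamlit reruns
if "messages" not in st.session_state:
    st.session_state.messages = []

user_input = st.text_input("You:", "")

if st.button("Send"):
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        inputs = tokenizer.apply_chat_template(
            st.session_state.messages, return_tensors="pt"
        ).to(model.device)
        outputs = model.generate(inputs, max_new_tokens=100)
        # decode only the newly generated tokens, not the echoed prompt
        reply = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
        st.session_state.messages.append({"role": "assistant", "content": reply})
    else:
        st.warning("Please enter a message.")

# Display the conversation so far
for message in st.session_state.messages:
    speaker = "You" if message["role"] == "user" else "Mixtral"
    st.markdown(f"**{speaker}:** {message['content']}")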