Annaamalai committed
Commit 5e28cac
1 Parent: 03ec5b1

Update app.py

Files changed (1):
  1. app.py +19 -9
app.py CHANGED
@@ -1,15 +1,25 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Load the text generation pipeline
-pipe = pipeline("text-generation")
+# Load the model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+model = AutoModelForCausalLM.from_pretrained("google/gemma-7b").to("cuda" if st.sidebar.checkbox("Use GPU", True) else "cpu")
 
-# Streamlit UI
-st.title("Demo with Hugging Face")
+# Function to generate text based on user input
+def generate_text(prompt):
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+    outputs = model.generate(input_ids, max_length=100)
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text
 
-prompt = st.text_area("Enter your text prompt:")
+# Streamlit app
+st.title("Text Generation with Google Gemma 7b")
 
+prompt = st.text_area("Enter your prompt here:", "")
 if st.button("Generate Text"):
-    # Generate text based on the prompt
-    generated_text = pipe(prompt, max_length=100, do_sample=True)[0]["generated_text"]
-    st.write(generated_text)
+    if prompt:
+        generated_text = generate_text(prompt)
+        st.write("Generated Text:")
+        st.write(generated_text)
+    else:
+        st.warning("Please enter a prompt.")
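
Since a Streamlit script re-runs from the top on every widget interaction, the module-level from_pretrained call above reloads the 7B weights each time the checkbox or button changes. A minimal sketch of a cached variant, not part of the commit: load_model is a hypothetical helper, st.cache_resource memoizes it across reruns, and torch is assumed to be installed alongside transformers.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache_resource  # memoized across reruns, keyed on the device argument
def load_model(device: str):
    # Hypothetical helper, not from the commit: load once per device.
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-7b").to(device)
    return tokenizer, model

# Fall back to CPU when CUDA is not actually available.
use_gpu = st.sidebar.checkbox("Use GPU", True) and torch.cuda.is_available()
tokenizer, model = load_model("cuda" if use_gpu else "cpu")

def generate_text(prompt):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # max_new_tokens caps the completion itself; max_length would also
    # count the prompt tokens toward the limit.
    outputs = model.generate(input_ids, max_new_tokens=100)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

Note that google/gemma-7b is a gated checkpoint on the Hub, so running this in a Space also requires accepting the model license and providing an access token.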