# GemmaTest / src/streamlit_app.py
# daniloedu — Update src/streamlit_app.py (commit 10bc8c4, verified)
import streamlit as st
from transformers import pipeline
from PIL import Image
import torch
import os

# Redirect HF / Transformers caches to an app-local directory to avoid
# permission issues in the hosting container (default cache is not writable).
os.environ["TRANSFORMERS_CACHE"] = "/app/cache/transformers"
os.environ["HF_HOME"] = "/app/cache/hf"
os.environ["HF_HUB_CACHE"] = "/app/cache/hf"

# Propagate the HF token (set as a Space secret) so huggingface_hub picks it
# up for gated-model downloads.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token

# Page config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Gemma-3n E4B Vision-Language Model",
    page_icon="πŸ€–",
    layout="wide"
)
@st.cache_resource
def load_model():
    """Load the model pipeline with caching.

    Returns the transformers image-text-to-text pipeline on success, or
    ``None`` (after surfacing a Streamlit error) when the HF token is
    missing or loading fails. ``@st.cache_resource`` keeps one pipeline
    instance alive across reruns/sessions.
    """
    try:
        # The model is gated, so bail out early without a token.
        if not hf_token:
            st.error("HF_TOKEN not found in environment variables")
            return None
        # Use pipeline approach which is more compatible
        pipe = pipeline(
            "image-text-to-text",
            model="google/gemma-3n-E4B-it",
            # fp16 only on GPU; CPU inference stays in fp32.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else "cpu",
            token=hf_token  # Pass token directly to pipeline
        )
        return pipe
    except Exception as e:
        # Broad catch is deliberate: any load failure (auth, network, OOM)
        # is reported in the UI instead of crashing the app.
        st.error(f"Error loading model: {str(e)}")
        st.error("Make sure you have access to the model and your token is valid.")
        return None
def generate_response(pipe, image, text_prompt, max_tokens=100):
    """Generate a text response from the model for one image + prompt.

    Args:
        pipe: a callable transformers pipeline accepting ``(messages,
            max_new_tokens=...)``.
        image: the image to describe (PIL image or any format the
            pipeline accepts).
        text_prompt: the user's question about the image.
        max_tokens: cap on the number of newly generated tokens.

    Returns:
        The generated text on success, or an ``"Error generating
        response: ..."`` string on failure (never raises).
    """
    try:
        # Single-turn chat message in the multimodal format the
        # image-text-to-text pipeline expects.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_prompt},
                ],
            }
        ]
        response = pipe(messages, max_new_tokens=max_tokens)

        # Unwrap the pipeline output defensively, since the exact shape
        # varies across transformers versions.
        if isinstance(response, list) and response:
            first = response[0]
            if isinstance(first, dict) and 'generated_text' in first:
                generated = first['generated_text']
                # Chat pipelines may return the whole conversation as a
                # list of {"role", "content"} messages; the model's reply
                # is the final entry. The original code returned the raw
                # list here, which rendered as a Python structure.
                if isinstance(generated, list) and generated:
                    last = generated[-1]
                    if isinstance(last, dict):
                        return last.get('content', str(last))
                    return str(last)
                return generated
            elif isinstance(first, str):
                return first
        # Unknown shape: fall back to a printable representation.
        return str(response)
    except Exception as e:
        # Never propagate — the caller displays whatever string we return.
        return f"Error generating response: {str(e)}"
def main():
    """Render the app: token check, model load, image upload + prompt form,
    and model-response display."""
    st.title("πŸ€– Gemma-3n E4B Vision-Language Model")
    st.markdown("Upload an image and ask questions about it!")

    # Check if token is available; without one the gated model cannot be
    # downloaded, so show setup instructions and stop rendering.
    if not hf_token:
        st.error("❌ HuggingFace token not found in environment variables.")
        st.markdown("""
**To fix this:**
1. Go to your Space settings (βš™οΈ icon)
2. Navigate to "Repository secrets"
3. Add a secret with name: `HF_TOKEN`
4. Value: Your HuggingFace token
5. Restart the Space
""")
        return
    else:
        st.success("βœ… HuggingFace token found!")

    # Static setup checklist in the sidebar (token presence verified above).
    st.sidebar.markdown("### πŸ“‹ Setup Status")
    st.sidebar.markdown(f"""
βœ… **Token**: Found in environment
Make sure you have:
1. βœ… Access to the gated model
2. βœ… Added your HF token to Space secrets
3. βœ… Token has proper permissions
""")

    # Load model (cached via st.cache_resource, so only slow on first run).
    with st.spinner("Loading model... This may take a few minutes on first run."):
        pipe = load_model()
    if pipe is None:
        st.error("Failed to load model. Please check your setup and try again.")
        return
    st.success("Model loaded successfully!")

    # Create two columns: inputs on the left, image preview + output on the right.
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("πŸ“€ Input")
        # Image upload
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=['png', 'jpg', 'jpeg', 'gif', 'bmp'],
            help="Upload an image to analyze"
        )
        # Text input: the question to ask about the image
        text_prompt = st.text_area(
            "Ask a question about the image:",
            placeholder="What do you see in this image?",
            height=100
        )
        # Generation parameters
        max_tokens = st.slider(
            "Max tokens to generate:",
            min_value=10,
            max_value=200,
            value=100,
            help="Maximum number of tokens to generate"
        )
        # Generate button
        generate_btn = st.button("πŸš€ Generate Response", type="primary")

    with col2:
        st.subheader("πŸ“€ Output")
        if uploaded_file is not None:
            # Display uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded image", use_column_width=True)
            # Generate response when button is clicked
            if generate_btn:
                if not text_prompt.strip():
                    st.warning("Please enter a question about the image.")
                else:
                    with st.spinner("Generating response..."):
                        response = generate_response(
                            pipe, image, text_prompt, max_tokens
                        )
                    st.subheader("πŸ€– Model Response:")
                    st.write(response)
        else:
            st.info("πŸ‘† Please upload an image to get started")

    # Example section
    st.markdown("---")
    st.subheader("πŸ’‘ Example Questions to Try:")
    st.markdown("""
- What objects do you see in this image?
- Describe the scene in detail
- What colors are present in the image?
- What is the main subject of this image?
- Can you identify any text in this image?
""")

    # Footer
    st.markdown("---")
    st.markdown(
        "Built with ❀️ using [Streamlit](https://streamlit.io) and "
        "[Hugging Face Transformers](https://huggingface.co/transformers/)"
    )
# Script entry point.
if __name__ == "__main__":
    main()