cha0smagick commited on
Commit
8fcc99d
1 Parent(s): 395b936

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -31
app.py CHANGED
@@ -1,41 +1,56 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForImageCaptioning
3
- import requests
4
  from PIL import Image
5
- import numpy as np
 
6
 
7
- # Initialize the tokenizer and model
8
- tokenizer = AutoTokenizer.from_pretrained("microsoft/beit-base-patch16-224-in21k")
9
- model = AutoModelForImageCaptioning.from_pretrained("microsoft/beit-base-patch16-224-in21k")
 
10
 
11
- def generate_caption(image_url):
12
- # Get the image from the URL
13
- image = Image.open(requests.get(image_url, stream=True).raw)
14
-
15
- # Preprocess the image
16
- input_array = np.array(image) / 255.0
17
- input_array = np.transpose(input_array, (2, 0, 1))
18
- input_ids = tokenizer(image_url, return_tensors="pt").input_ids
19
-
20
- # Generate the caption
21
- output = model.generate(input_ids, max_length=20)
22
- caption = tokenizer.batch_decode(output, skip_special_tokens=True)
23
 
24
- return caption[0]
25
 
 
26
  def main():
27
- # Create a sidebar for the user to input the image URL
28
- st.sidebar.header("Image Caption Generator")
29
- image_url = st.sidebar.text_input("Enter the URL of an image:")
30
-
31
- # Generate the caption if the user clicks the button
32
- if st.sidebar.button("Generate Caption"):
33
- if image_url != "":
34
- caption = generate_caption(image_url)
35
- st.success(f"Caption: {caption}")
36
- else:
37
- st.error("Please enter a valid image URL.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Run the main function
40
  if __name__ == "__main__":
41
  main()
 
1
  import streamlit as st
 
 
2
  from PIL import Image
3
+ import textwrap
4
+ import google.generativeai as genai
5
 
6
+ # Function to display formatted Markdown text
7
+ def to_markdown(text):
8
+ text = text.replace('•', ' *')
9
+ return textwrap.indent(text, '> ', predicate=lambda _: True)
10
 
11
+ # Function to generate content using Gemini API
12
+ def generate_gemini_content(prompt, model_name='gemini-pro', image=None):
13
+ model = genai.GenerativeModel(model_name)
14
+ if image:
15
+ response = model.generate_content([prompt, image])
16
+ else:
17
+ response = model.generate_content(prompt)
 
 
 
 
 
18
 
19
+ return response
20
 
21
+ # Streamlit app
22
  def main():
23
+ st.title("Gemini API Demo with Streamlit")
24
+
25
+ # Get Gemini API key from user input
26
+ api_key = st.text_input("Enter your Gemini API key:")
27
+ genai.configure(api_key=api_key)
28
+
29
+ # Choose a model
30
+ model_name = st.selectbox("Select a Gemini model", ["gemini-pro", "gemini-pro-vision"])
31
+
32
+ # Get user input prompt
33
+ prompt = st.text_area("Enter your prompt:")
34
+
35
+ # Get optional image input
36
+ image_file = st.file_uploader("Upload an image (if applicable):", type=["jpg", "jpeg", "png"])
37
+
38
+ # Display image if provided
39
+ if image_file:
40
+ st.image(image_file, caption="Uploaded Image", use_column_width=True)
41
+
42
+ # Generate content on button click
43
+ if st.button("Generate Content"):
44
+ st.markdown("### Generated Content:")
45
+ if image_file:
46
+ # If an image is provided, use gemini-pro-vision model
47
+ image = Image.open(image_file)
48
+ response = generate_gemini_content(prompt, model_name='gemini-pro-vision', image=image)
49
+ else:
50
+ response = generate_gemini_content(prompt, model_name=model_name)
51
+
52
+ # Display the generated content in Markdown format
53
+ st.markdown(to_markdown(response.text))
54
 
 
55
  if __name__ == "__main__":
56
  main()