disham993 commited on
Commit
ec44ead
1 Parent(s): c04fddb

Gemini Streamlit Application.

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ **/*__pycache__
2
+ **/*.env
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Q&A Chatbot
2
+ from utils import *
3
+ from ui_files import *
4
+
5
+ # Headers of the app
6
+ initial_headers()
7
+
8
+ # Handle media upload
9
+ media_content, media_type = handle_media_upload()
10
+
11
+ # Handle JSON file upload for authentication
12
+ api_key = handle_credentials(media_type=media_type)
13
+
14
+ # Handle input fields
15
+ prompts = manage_input_fields()
16
+
17
+ submit = st.button(f"Tell me about the {media_type}")
18
+
19
+ # Configure generation and safety settings
20
+ generation_config, safety_settings = configure_generation_and_safety(
21
+ SAFETY_SETTINGS, THRESHOLD_OPTIONS
22
+ )
23
+
24
+ ## If ask button is clicked
25
+ if submit:
26
+ print(f"Response being generated...")
27
+ st.subheader("The Response as follows...")
28
+ start_time = time.time()
29
+
30
+ if media_type == "video":
31
+ final_safety_settings = {}
32
+ for setting in safety_settings:
33
+ final_safety_settings[
34
+ SAFETY_SETTINGS_VIDEO_LABELS[setting["category"]]
35
+ ] = THRESHOLD_OPTIONS_VIDEO_LABELS[setting["threshold"]]
36
+ else:
37
+ final_safety_settings = safety_settings
38
+
39
+ response = get_gemini_response(
40
+ prompts,
41
+ media_content=media_content,
42
+ generation_config=generation_config,
43
+ media_type=media_type,
44
+ safety_settings=final_safety_settings,
45
+ api_key=api_key,
46
+ )
47
+ for chunk in response:
48
+ print(chunk.text) # For Debugging
49
+ st.write(chunk.text)
50
+
51
+ if os.path.exists("tmp/json_data.json"):
52
+ os.remove("tmp/json_data.json")
53
+
54
+ st.write(f"Time taken to generate results: {time.time() - start_time:.2f} seconds.")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit==1.29.0
2
+ google-generativeai==0.3.0
3
+ google-ai-generativelanguage==0.4.0
4
+ python-dotenv==1.0.0
5
+ google-cloud-aiplatform==1.38.1
ui_files/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from ui_files.media_handling import *
2
+ from ui_files.authentication import *
3
+ from ui_files.settings import *
4
+ from ui_files.initial_headers import *
5
+ from ui_files.user_input_handler import *
ui_files/authentication.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from os.path import dirname as up
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
5
+
6
+ from utils import *
7
+
8
+
9
+ def handle_credentials(media_type: str = "image"):
10
+ if media_type == "image":
11
+ api_key = st.text_input(
12
+ "🔐 GOOGLE AI STUDIO API KEY - Required For Image.", key="api_key"
13
+ )
14
+ return api_key
15
+
16
+ elif media_type == "video":
17
+ uploaded_json = st.file_uploader(
18
+ "🔐 Upload a JSON file which includes Google Service Account Credentials - Required for Video.",
19
+ type=["json"],
20
+ )
21
+
22
+ if uploaded_json is not None:
23
+ json_data = json.load(uploaded_json)
24
+ os.makedirs("tmp", exist_ok=True)
25
+ json_path = os.path.join("tmp", "json_data.json")
26
+ with open(json_path, "w") as file:
27
+ json.dump(json_data, file)
28
+
29
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = json_path
30
+ service_account.Credentials.from_service_account_info(json_data)
31
+ st.success(
32
+ "Environment variable GOOGLE_APPLICATION_CREDENTIALS set from JSON file."
33
+ )
ui_files/initial_headers.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from os.path import dirname as up
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
5
+
6
+ from utils import *
7
+
8
+
9
+ def initial_headers():
10
+ st.set_page_config(page_title="Gemini Image & Video Demo")
11
+ st.header("Gemini Application - Image & Video Demo")
12
+
13
+ st.write(
14
+ "This app is to be used to ask questions on image and video that will be uploaded."
15
+ )
ui_files/media_handling.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from os.path import dirname as up
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
5
+
6
+ from utils import *
7
+
8
+
9
+ def handle_media_upload():
10
+ uploaded_file = st.file_uploader(
11
+ "**Drag and drop or upload an Image 🖼️ or a Video 📺**",
12
+ type=["jpg", "jpeg", "png", "mp4"],
13
+ )
14
+ media_content = ""
15
+ media_type = "image"
16
+
17
+ if uploaded_file is not None:
18
+ if uploaded_file.type.startswith("image/"):
19
+ media_content = Image.open(uploaded_file)
20
+ media_content = media_content.resize((500, 500))
21
+ st.image(media_content, caption="Uploaded Image.", use_column_width=True)
22
+
23
+ if uploaded_file.type.startswith("video/"):
24
+ file_bytes = uploaded_file.read()
25
+ data = base64.b64encode(file_bytes)
26
+ media_content = Part.from_data(
27
+ data=base64.b64decode(data), mime_type="video/mp4"
28
+ )
29
+ st.video(uploaded_file)
30
+ media_type = "video"
31
+
32
+ return media_content, media_type
ui_files/settings.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from os.path import dirname as up
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
5
+
6
+ from utils import *
7
+
8
+
9
+ def configure_generation_and_safety(safety_settings, threshold_options):
10
+ # Add sliders for temperature, top_p, top_k, and max_output_tokens
11
+ st.sidebar.header("Generation Configuration")
12
+ temperature = st.sidebar.slider(
13
+ "Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.01
14
+ )
15
+ top_p = st.sidebar.slider(
16
+ "Top P", min_value=0.0, max_value=1.0, value=0.9, step=0.01
17
+ )
18
+ top_k = st.sidebar.slider("Top K", min_value=0, max_value=100, value=40, step=1)
19
+ max_output_tokens = st.sidebar.slider(
20
+ "Max Output Tokens", min_value=1, max_value=4096, value=1024, step=1
21
+ )
22
+
23
+ generation_config = {
24
+ "temperature": temperature,
25
+ "top_p": top_p,
26
+ "top_k": top_k,
27
+ "max_output_tokens": max_output_tokens,
28
+ }
29
+
30
+ # Sidebar for safety settings
31
+ st.sidebar.header("Safety Settings")
32
+
33
+ # Create a dropdown for each category
34
+ for setting in safety_settings:
35
+ setting["threshold"] = st.sidebar.selectbox(
36
+ f"{setting['category']}",
37
+ threshold_options,
38
+ index=threshold_options.index(setting["threshold"]),
39
+ )
40
+
41
+ return generation_config, safety_settings
ui_files/user_input_handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from os.path import dirname as up
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
5
+
6
+ from utils import *
7
+
8
+
9
+ def manage_input_fields():
10
+ st.write("Enter a single or multiple prompts.")
11
+
12
+ # Initialize session state variables if they don't exist
13
+ if "input_list" not in st.session_state:
14
+ st.session_state.input_list = [""]
15
+
16
+ # Function to add a new input field
17
+ def add_input():
18
+ st.session_state.input_list.append("")
19
+
20
+ # Function to remove an input field
21
+ def remove_input(index):
22
+ st.session_state.input_list.pop(index)
23
+
24
+ # Display the input fields
25
+ for index, value in enumerate(st.session_state.input_list):
26
+ col1, col2 = st.columns([4, 1])
27
+ with col1:
28
+ st.session_state.input_list[index] = st.text_input(
29
+ f"Input Prompt: {index+1}", value=value
30
+ )
31
+ with col2:
32
+ st.button(
33
+ "Remove", key=f"remove_{index}", on_click=remove_input, args=(index,)
34
+ )
35
+
36
+ # Button to add new input field
37
+ st.button("Add new input", on_click=add_input)
38
+
39
+ return st.session_state.input_list
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from utils.helper import *
2
+ from utils.constants import *
utils/constants.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from os.path import dirname as up
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
5
+
6
+ from utils.helper import *
7
+
8
+ SAFETY_SETTINGS = [
9
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
10
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
11
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
12
+ {
13
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
14
+ "threshold": "BLOCK_ONLY_HIGH",
15
+ },
16
+ ]
17
+
18
+ # Threshold options
19
+ THRESHOLD_OPTIONS = [
20
+ "BLOCK_NONE",
21
+ "BLOCK_ONLY_HIGH",
22
+ "BLOCK_MEDIUM_AND_ABOVE",
23
+ "BLOCK_LOW_AND_ABOVE",
24
+ ]
25
+
26
+ SAFETY_SETTINGS_VIDEO_LABELS = {
27
+ "HARM_CATEGORY_HARASSMENT": HarmCategory.HARM_CATEGORY_HARASSMENT,
28
+ "HARM_CATEGORY_HATE_SPEECH": HarmCategory.HARM_CATEGORY_HATE_SPEECH,
29
+ "HARM_CATEGORY_SEXUALLY_EXPLICIT": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
30
+ "HARM_CATEGORY_DANGEROUS_CONTENT": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
31
+ }
32
+
33
+ THRESHOLD_OPTIONS_VIDEO_LABELS = {
34
+ "BLOCK_NONE": HarmBlockThreshold.BLOCK_NONE,
35
+ "BLOCK_ONLY_HIGH": HarmBlockThreshold.BLOCK_ONLY_HIGH,
36
+ "BLOCK_MEDIUM_AND_ABOVE": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
37
+ "BLOCK_LOW_AND_ABOVE": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
38
+ }
utils/helper.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from os.path import dirname as up
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
5
+
6
+ import streamlit as st
7
+ import os
8
+ import google.generativeai as genai
9
+ import pathlib
10
+ import textwrap
11
+ from PIL import Image
12
+
13
+ import json
14
+
15
+ from vertexai.preview.generative_models import (
16
+ GenerativeModel,
17
+ Part,
18
+ HarmCategory,
19
+ HarmBlockThreshold,
20
+ )
21
+ from google.oauth2 import service_account # importing auth using service_account
22
+ import json
23
+
24
+ import os
25
+ import base64
26
+
27
+ import time
28
+ from enum import Enum
29
+ from typing import Union, List, Any, Dict
30
+
31
+
32
+ ## Function to load OpenAI model and get respones
33
+ def get_gemini_response(
34
+ input: Union[str, List[str]],
35
+ media_content: Any,
36
+ generation_config: Dict,
37
+ safety_settings: Union[List[Dict], Dict],
38
+ media_type: str = "image",
39
+ api_key: str = None,
40
+ ):
41
+ print(f"Safety Settings: {safety_settings}")
42
+ print(f"Generation Config: {generation_config}") # -> For Debugging
43
+ if media_type == "video":
44
+ print(f"Media type is video.")
45
+
46
+ model = GenerativeModel(
47
+ model_name="gemini-pro-vision",
48
+ generation_config=generation_config,
49
+ safety_settings=safety_settings,
50
+ )
51
+ else:
52
+ print(f"Media type is image.")
53
+ genai.configure(api_key=api_key)
54
+ model = genai.GenerativeModel(
55
+ "gemini-pro-vision",
56
+ generation_config=generation_config,
57
+ safety_settings=safety_settings,
58
+ )
59
+
60
+ if input != "":
61
+ # For debugging
62
+ # with open("tmp/input.txt", "w") as f:
63
+ # f.write(str(media_content))
64
+ response = model.generate_content(input + [media_content], stream=True)
65
+ else:
66
+ response = model.generate_content(media_content, stream=True)
67
+
68
+ return response