|
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
import streamlit as st |
|
|
import os |
|
|
import google.generativeai as genai |
|
|
import replicate |
|
|
from PIL import Image |
|
|
import time |
|
|
|
|
|
|
|
|
genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) |
|
|
|
|
|
|
|
|
|
|
|
# Multimodal model used whenever an image is part of the request.
vis_model = genai.GenerativeModel("gemini-pro-vision")

# Text-only model used for plain chat prompts.
# NOTE(review): "langugae_model" is a typo for "language_model", but the
# misspelled name is referenced elsewhere in this file — rename at every
# call site at once, not here.
langugae_model = genai.GenerativeModel("gemini-pro")
|
|
|
|
|
|
|
|
def get_gemini_response(input, image):
    """
    Route a request to the appropriate Gemini model and return its text.

    Args:
        input: The user's text prompt; "" means "no prompt".
        image: A PIL.Image to summarize; "" is the "no image" sentinel used
            by the callers in this file.

    Returns:
        str: The generated response text.

    Raises:
        ValueError: If both ``input`` and ``image`` are empty.  (Previously
            this case fell through to ``vis_model.generate_content("")`` and
            failed with an opaque SDK error.)
    """
    # NOTE: `input` shadows the builtin, but the parameter name is part of
    # the public signature, so it is kept for compatibility.
    if input and image:
        # Text + image -> vision model handles the multimodal pair.
        response = vis_model.generate_content([input, image])
    elif input:
        # Text only -> the plain language model.
        response = langugae_model.generate_content(input)
    elif image:
        # Image only -> let the vision model describe/summarize it.
        response = vis_model.generate_content(image)
    else:
        raise ValueError("Either a text prompt or an image is required.")

    return response.text
|
|
|
|
|
def stream_data(prompt, image):
    """
    Yield the Gemini response word by word, for use with st.write_stream.

    Fix: the previous implementation split the text on ". " and then on
    whitespace, which silently dropped the ". " separators — streamed answers
    lost their sentence-ending periods.  Splitting the full text on
    whitespace keeps punctuation attached to each word.

    Args:
        prompt: The user's text prompt ("" for none).
        image: A PIL.Image or "" (the no-image sentinel).

    Yields:
        str: The next word followed by a single space.
    """
    text = get_gemini_response(prompt, image)

    for word in text.split():
        yield word + " "
        # Small delay so the output visibly "types" in the UI.
        time.sleep(0.02)
|
|
|
|
|
|
|
|
# Page-level configuration; set_page_config must be the first Streamlit
# command executed in the script.
st.set_page_config(
    page_title="Multimodal Content Generation",
    page_icon="β‘οΈ",
    layout="wide"
)


# Sidebar header and attribution.
st.sidebar.title(":rainbow[MULTIMODAL CONTENT GENERATION]")
st.sidebar.write("Built by [jaiminjariwala](https://github.com/jaiminjariwala) π")
st.sidebar.divider()


# Mode selector: this value drives which of the two branches below runs.
multimodal_options = st.sidebar.radio(
    "**Select What To Do...**",
    options=["Chat and Image Summarization", "Text 2 Image"],
    index=0,
    horizontal=False,
)
st.sidebar.divider()
|
|
|
|
|
|
|
|
|
|
|
if multimodal_options == "Chat and Image Summarization":

    # Reset the conversation and redraw an empty page.
    # NOTE(review): st.experimental_rerun is deprecated in newer Streamlit
    # releases (st.rerun) — kept as-is since the installed version is unknown.
    if st.sidebar.button("Get **New Chat** Fresh Page"):
        st.session_state["messages"] = []
        st.experimental_rerun()

    with st.expander("**Wanna Upload an Image?**"):
        uploaded_file = st.file_uploader(
            "Choose an image for **Image Summarizer** task...",
            type=["jpg", "jpeg", "png"])

        # "" is the "no image" sentinel expected by get_gemini_response().
        image = ""
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image.", use_column_width=True)

    chat_container = st.container()

    # Initialize chat history on first run.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the stored conversation.
    for message in st.session_state.messages:
        with chat_container:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    if prompt := st.chat_input("Type here..."):

        # Echo the user's message and record it.
        with chat_container:
            with st.chat_message("user"):
                st.markdown(prompt)

        st.session_state.messages.append({"role": "user",
                                          "content": prompt})

        with chat_container:
            with st.chat_message("assistant"):

                # Heuristic: code-related prompts are rendered in a code block.
                should_format_as_code = any(
                    keyword in prompt.lower()
                    for keyword in ["code", "python", "java", "javascript",
                                    "c++", "c", "program", "react", "reactjs",
                                    "node", "nodejs", "html", "css", "js"])

                # Call the model exactly ONCE and reuse the result for both
                # display and history.  (Previously the history append below
                # made a second, separately-billed API call whose answer could
                # differ from the one shown on screen.)
                if should_format_as_code:
                    response_text = get_gemini_response(prompt, image)
                    st.code(response_text)
                else:
                    # st.write_stream returns the full concatenated text of
                    # the streamed generator.
                    response_text = st.write_stream(stream_data(prompt, image))

        st.session_state.messages.append({"role": "assistant",
                                          "content": response_text})
|
|
|
|
|
|
|
|
|
|
|
def generate_and_display_image(submitted: bool, width: int, height: int, num_outputs: int, scheduler: str, num_inference_steps: int, prompt_strength: float, prompt: str):
    """
    Generate image(s) with Replicate's SDXL model and display them.

    Runs only when a valid-looking token is present, the Generate button was
    pressed, and a prompt was supplied.  Results are stored in
    ``st.session_state.generated_image`` / ``st.session_state.all_images``.

    Args:
        submitted: Whether the Generate button was pressed.
        width / height: Output image dimensions in pixels.
        num_outputs: Number of images to generate.
        scheduler: Sampler name chosen in refine_output().
        num_inference_steps: Number of denoising steps.
        prompt_strength: img2img/inpaint prompt strength (0.0-1.0).
        prompt: The text prompt describing the image.

    Fixes:
        - The ``input`` dict key was misspelled "prompt_stregth", so the
          user's slider value was silently ignored by the API.
        - ``scheduler`` was accepted but ignored (hard-coded "K_EULER").
    """
    # Reads the module-level REPLICATE_API_TOKEN set in the Text 2 Image branch.
    if REPLICATE_API_TOKEN.startswith('r8_') and submitted and prompt:
        with st.status('Generating your image...', expanded=True):
            try:
                all_images = []

                output = replicate.run(
                    "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
                    input={
                        "prompt": prompt,
                        "width": width,
                        "height": height,
                        "num_outputs": num_outputs,
                        "scheduler": scheduler,  # was hard-coded "K_EULER"
                        "num_inference_steps": num_inference_steps,
                        "guidance_scale": 7.5,
                        "prompt_strength": prompt_strength,  # key was misspelled
                        "negative_prompt": "the absolute worst quality, distorted features",
                        "refine": "expert_ensemble_refiner",
                        "high_noise_frac": 0.8
                    }
                )

                if output:
                    st.toast(
                        'Your image has been generated!', icon='π')

                    st.session_state.generated_image = output

                    for image in st.session_state.generated_image:
                        with st.container():
                            st.image(image, caption="Generated Image βοΈ",
                                     use_column_width=True)

                        all_images.append(image)

                    # Keep the full batch for later reuse/download.
                    st.session_state.all_images = all_images

            except replicate.exceptions.ReplicateError as e:
                st.error(f"Error generating image: {e}")

    elif not prompt:
        st.toast("Please input some prompt!", icon="β οΈ")
|
|
|
|
|
|
|
|
def refine_output():
    """
    Render the generation controls for Text 2 Image mode and collect the
    user's choices.

    Returns:
        tuple: ``(submitted, width, height, num_outputs, scheduler,
        num_inference_steps, prompt_strength, prompt)`` — note the button
        state comes FIRST in the tuple.
    """
    # Advanced parameters live in a collapsible expander.
    with st.expander("**Refine your output if you want...**"):
        out_width = st.number_input("Width of output image", value=1024)
        out_height = st.number_input("Height of output image", value=1024)

        batch_size = st.slider("Number of images to output", value=1, min_value=1, max_value=4)

        sampler = st.selectbox('Scheduler', ('DDIM', 'DPMSolverMultistep', 'HeunDiscrete', 'KarrasDPM', 'K_EULER_ANCESTRAL', 'K_EULER', 'PNDM'))

        denoise_steps = st.slider(
            "Number of denoising steps", value=50, min_value=1, max_value=500)

        strength = st.slider(
            "Prompt strength when using img2img/inpaint (1.0 corresponds to full destruction of information in image)", value=0.8, max_value=1.0, step=0.1)

    # The prompt and trigger button are always visible.
    user_prompt = st.text_input("Enter your prompt for the image:",
                                value="Dog and cat dancing on moon")

    pressed = st.button("Generate")

    return pressed, out_width, out_height, batch_size, sampler, denoise_steps, strength, user_prompt
|
|
|
|
|
|
|
|
|
|
|
if multimodal_options == "Text 2 Image":

    # The Replicate token is entered per-session and exported for the SDK.
    REPLICATE_API_TOKEN = st.sidebar.text_input(
        "Enter your REPLICATE API TOKEN",
        placeholder="Paste token here...",
        type="password"
    )
    os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN

    # Replicate user tokens start with "r8_"; anything else gets a nudge.
    if not REPLICATE_API_TOKEN.startswith('r8_'):
        st.warning('Please enter your REPLICATE API KEY in **Sidebar**!', icon='β ')

    # refine_output() returns (submitted, width, height, num_outputs,
    # scheduler, num_inference_steps, prompt_strength, prompt) — unpack and
    # forward in that exact order.  (The previous code unpacked the tuple
    # under rotated names and re-passed it rotated back: two mistakes that
    # cancelled at runtime but made every local name hold the wrong value.)
    submitted, width, height, num_outputs, scheduler, num_inference_steps, prompt_strength, prompt = refine_output()

    generate_and_display_image(submitted, width, height, num_outputs, scheduler, num_inference_steps, prompt_strength, prompt)
|
|
|
|
|
|