import io
import os
import random

import numpy as np
import requests
import streamlit as st
import torch
from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageOps


# Path (or Hugging Face Hub repo id) of the fine-tuned InstructPix2Pix
# checkpoint. Set the PIX2PIX_MODEL_ID environment variable to use another
# checkpoint; the default is the original hard-coded local directory.
model_id = os.environ.get(
    "PIX2PIX_MODEL_ID",
    "/home/gopinath28031995/yashwanth/projects/watermark_env/instruction-tuned-sd/woman-avatar-gen",
)


@st.cache_resource
def _load_pipeline(model_path):
    """Load the InstructPix2Pix pipeline onto the GPU once and reuse it.

    Streamlit re-executes this whole script on every widget interaction.
    Without ``st.cache_resource`` the model would be re-read from disk and
    copied to CUDA on each rerun. The cache is keyed on ``model_path``.

    Args:
        model_path: Directory or repo id accepted by ``from_pretrained``.

    Returns:
        A float16 ``StableDiffusionInstructPix2PixPipeline`` on CUDA that uses
        an Euler-ancestral scheduler built from its original scheduler config.
    """
    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        model_path, torch_dtype=torch.float16
    )
    pipeline.to("cuda")
    # Swap in Euler-ancestral sampling while keeping the checkpoint's
    # scheduler configuration (beta schedule, timestep count, ...).
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config
    )
    return pipeline


pipe = _load_pipeline(model_id)


def download_image(url, timeout=30):
    """Download an image and return it upright, in RGB mode.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait on the server (connect and read) before
            giving up, so a stalled host cannot hang the app indefinitely.

    Returns:
        A ``PIL.Image.Image`` in RGB mode with its EXIF orientation applied.

    Raises:
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP status.
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 404/500 instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # `.content` is decoded (gzip/deflate) and fully buffered, unlike `.raw`.
    image = Image.open(io.BytesIO(response.content))
    # exif_transpose covers all eight EXIF orientations, including the
    # mirrored ones, using Pillow's public EXIF API.
    image = ImageOps.exif_transpose(image)
    # The pipeline expects 3-channel input; drop alpha/palette/greyscale modes.
    return image.convert("RGB")


st.title("Instruct Pix2Pix Image Generation")

# Source image inputs: a file upload or an image URL.
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
image_url = st.text_input("Enter image URL")

# Editing instruction given to the model as its text prompt.
prompt = st.text_input(
    "Enter prompt",
    "Generate a fantasy version, retain hair and facial features, 8k",
)

# Sampling controls. Integer value/step arguments make these widgets return ints.
seed = st.number_input("Seed", value=42, step=1)
# At least one denoising step is required; 0 steps is not a valid run.
num_inference_steps = st.number_input(
    "Number of Inference Steps", value=300, step=10, min_value=1
)
# NOTE(review): diffusers' InstructPix2Pix defaults are guidance_scale=7.5 (text)
# and image_guidance_scale=1.5; these defaults look swapped relative to that.
# Confirm they were deliberately tuned for this checkpoint.
text_cfg_scale = st.number_input("Text CFG Scale", value=3.0, step=0.1, min_value=0.0)
image_cfg_scale = st.number_input("Image CFG Scale", value=7.5, step=0.1, min_value=0.0)


# Source image for the edit. Stays None when nothing could be loaded; the
# Generate handler checks for that instead of raising NameError.
image = None

if uploaded_file is not None:
    try:
        # Apply EXIF orientation (e.g. phone photos) and force 3-channel RGB,
        # so RGBA/palette PNG uploads are not passed to the pipeline as-is.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except Exception as exc:  # top-level UI boundary: report, don't crash
        st.error(f"Could not read the uploaded file as an image ({exc}).")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)
elif image_url:
    try:
        image = download_image(image_url)
    except Exception as exc:  # top-level UI boundary: report, don't crash
        st.error(f"Error downloading image from URL. Please make sure the URL is correct. ({exc})")
    else:
        st.image(image, caption="Image from URL", use_column_width=True)
else:
    url = "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg"
    st.write("Using default image.")
    try:
        image = download_image(url)
    except Exception as exc:  # top-level UI boundary: report, don't crash
        st.error(f"Could not download the default image ({exc}).")
    else:
        st.image(image, caption="Default Image", use_column_width=True)


if st.button("Generate"):
    if image is None:
        st.error("No input image available. Upload an image or enter a valid image URL.")
    else:
        # The pipeline has no `seed` argument; reproducibility comes from a
        # seeded torch.Generator on the same device as the pipeline.
        generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
        with st.spinner("Generating..."):
            result = pipe(
                prompt,
                image=image,
                num_inference_steps=int(num_inference_steps),
                # diffusers names: text CFG -> guidance_scale,
                # image CFG -> image_guidance_scale.
                guidance_scale=text_cfg_scale,
                image_guidance_scale=image_cfg_scale,
                generator=generator,
            )
        # The pipeline returns an output object whose `.images` holds the
        # generated PIL images (default output_type="pil").
        st.image(result.images[0], caption="Generated Image", use_column_width=True)