|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
|
from diffusers import StableDiffusionInpaintPipeline,StableDiffusionPipeline |
|
from PIL import Image |
|
import requests |
|
|
|
import cv2 |
|
import torch |
|
import matplotlib.pyplot as plt |
|
|
|
import io |
|
import requests |
|
|
|
import os |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
|
|
import gradio as gr |
|
import gc |
|
|
|
|
|
|
|
# --- Model loading (runs once at import time) --------------------------------

# CLIPSeg: text-prompted segmentation; used by create_mask() to locate the
# object named by the user.
_clipseg_id = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_clipseg_id)
model = CLIPSegForImageSegmentation.from_pretrained(_clipseg_id)

# The diffusion pipelines are moved to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

IPmodel_path = "runwayml/stable-diffusion-inpainting"
# Hugging Face access token taken from Streamlit secrets; shared by both
# diffusion pipelines below.
_hf_token = st.secrets["AUTH_TOKEN"]
IPpipe = StableDiffusionInpaintPipeline.from_pretrained(
    IPmodel_path,
    use_auth_token=_hf_token,
).to(device)

# NOTE(review): trans_tokenizer / trans_model are not referenced anywhere else
# in the visible file — confirm they are needed before keeping this
# (comparatively large) download.
_nllb_id = "facebook/nllb-200-distilled-600M"
trans_tokenizer = AutoTokenizer.from_pretrained(_nllb_id)
trans_model = AutoModelForSeq2SeqLM.from_pretrained(_nllb_id)

# Plain text-to-image pipeline for requests without a source image.
SDpipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    use_auth_token=_hf_token,
).to(device)
|
|
|
|
|
def create_mask(image, prompt):
    """Segment ``prompt`` in ``image`` with CLIPSeg and return an inverted binary mask.

    Steps:
      1. The CLIPSeg logits are squashed with a sigmoid.
      2. They are rendered to a PNG with ``plt.imsave``, which applies
         matplotlib's default colormap.
      3. The PNG is read back, converted to grayscale and thresholded at 100.
      4. The binary image is inverted, so pixels that pass the threshold
         (roughly, those CLIPSeg scores high for ``prompt``) become 0 and
         all others become 255.

    Args:
        image: PIL image to segment.
        prompt: text describing the object to segment.

    Returns:
        A single-channel (mode ``"L"``) PIL image at the resolution of the
        CLIPSeg logits. Callers resize it as needed.
    """
    import tempfile  # only needed for the scratch PNG below

    inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # NOTE(review): plt.imsave needs a 2-D (or HxWx3/4) array; this assumes
    # the logits are 2-D for a single prompt/image pair — confirm.
    preds = outputs.logits

    # Use a private scratch directory rather than a shared "mask.png" in the
    # working directory: concurrent Gradio requests would otherwise overwrite
    # each other's intermediate files.
    with tempfile.TemporaryDirectory() as scratch_dir:
        filename = os.path.join(scratch_dir, "mask.png")
        plt.imsave(filename, torch.sigmoid(preds))
        gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)

    _, bw_image = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)

    # Invert: the segmented object becomes black, the rest white.
    mask = cv2.bitwise_not(bw_image)

    # Building the PIL image directly is pixel-identical to writing the
    # single-channel mask to PNG and re-opening it, and leaks no file handle.
    return Image.fromarray(mask)
|
|
|
|
|
|
|
|
|
def generate_image(image, product_name, target_name, guidance_scale=8, num_samples=4, seed=22):
    """Keep the object named ``product_name`` in ``image`` and repaint the rest.

    The mask comes from ``create_mask``, which returns the segmented object
    black and everything else white. The diffusers inpainting pipeline
    repaints white mask pixels and preserves black ones.

    Args:
        image: source PIL image. It is resized to 512x512 for inpainting.
        product_name: text passed to CLIPSeg to locate the object to keep.
        target_name: text prompt for the inpainting pipeline.
        guidance_scale: classifier-free guidance strength (default 8, as before).
        num_samples: number of images generated per call.
        seed: seed for the torch generator, so results are reproducible.

    Returns:
        The list of PIL images produced by ``IPpipe``.
    """
    mask = create_mask(image, product_name)

    # Image and mask must share the size the pipeline is run at.
    image = image.resize((512, 512))
    mask = mask.resize((512, 512))

    # Attention slicing is a pipeline setting, not a method of the returned
    # image list. Enable it before inference so it lowers peak memory.
    IPpipe.enable_attention_slicing()

    generator = torch.Generator(device=device).manual_seed(seed)

    return IPpipe(
        prompt=target_name,
        image=image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_samples,
        generator=generator,
    ).images
|
|
|
# Best-effort release of memory left over from loading the models above.
gc.collect()
torch.cuda.empty_cache()

# --- UI text ----------------------------------------------------------------
image_label = 'Please upload the image (optional)'
extract_label = 'Specify what needs to be extracted from the above image'
prompt_label = 'Specify the description of image to be generated'
button_label = "Proceed"
output_label = "Results"

# --- Optional photography modifiers -------------------------------------------
# Each selected choice is appended verbatim to the prompt by add_to_prompt(),
# so these strings must read naturally as prompt fragments.
shot_services = ['close-up', 'extreme-closeup', 'POV', 'medium', 'long']
shot_label = 'Choose the shot type'

style_services = ['polaroid', 'monochrome', 'long exposure', 'color splash', 'Tilt shift']
style_label = 'Choose the style type'

# 'ambient' (not 'ambivalent', which describes feelings, not light).
lighting_services = ['soft', 'ambient', 'ring', 'sun', 'cinematic']
lighting_label = 'Choose the lighting type'

context_services = ['indoor', 'outdoor', 'at night', 'in the park', 'on the beach', 'studio']
context_label = 'Choose the context'

lens_services = ['wide angle', 'telephoto', '24 mm', 'EF 70mm', 'Bokeh']
lens_label = 'Choose the lens type'

device_services = ['iphone', 'CCTV', 'Nikon ZFX', 'Canon', 'Gopro']
device_label = 'Choose the device type'
|
|
|
|
|
def add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Append the selected photography options to ``prompt_text``.

    Each selected option is appended as ``", <option>"``, in the order shot,
    style, lighting, context, lens, device. Options that are empty or
    ``None`` are skipped. Gradio may report an unselected ``gr.Radio`` as
    ``None``, and concatenating that would raise ``TypeError``.

    Args:
        prompt_text: base prompt typed by the user.
        shot_radio, style_radio, lighting_radio, context_radio, lens_radio,
        device_radio: selected option strings, or ``""``/``None`` if unset.

    Returns:
        The extended prompt string.
    """
    options = (shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
    return prompt_text + "".join(", " + option for option in options if option)
|
|
|
def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Gradio click handler: render images from the form inputs.

    Modes:
      * Guided: runs when both an image and the text of the object to
        extract are given. The object is kept and the rest is repainted via
        ``generate_image``.
      * Free: runs in every other case. It is plain text-to-image with
        ``SDpipe`` (4 images at 512x512).
    Both modes need a prompt, which is extended with the selected radio
    options via ``add_to_prompt``.

    Args:
        input_file: uploaded image as an array, or ``None`` when nothing was
            uploaded (``gr.Image`` does not send ``""``).
        extract_text: text naming the object to keep (guided mode).
        prompt_text: description of the image to generate.
        shot_radio ... device_radio: optional photography modifiers.

    Returns:
        A list of PIL images for the output gallery.

    Raises:
        gr.Error: if no prompt was given.
    """
    if not prompt_text:
        raise gr.Error("Please fill all details for guided image or atleast prompt for free image rendition !")

    prompt = add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
    print(prompt)

    # Use an explicit None check. Comparing a numpy image array with "" is
    # ill-defined, and an empty gr.Image arrives as None.
    if extract_text and input_file is not None:
        print(extract_text)
        return generate_image(Image.fromarray(input_file), extract_text, prompt)

    output = SDpipe(prompt, height=512, width=512, num_images_per_prompt=4)
    return output.images
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    with gr.Row():
        # NOTE(review): created with visible=False, so users cannot upload an
        # image from this UI and the value passed to the handler stays empty.
        input_file = gr.Image(interactive=True, label=image_label, visible=False, shape=(512, 512))
        extract_text = gr.Textbox(label=extract_label, lines=1, interactive=True, visible=True)
        prompt_text = gr.Textbox(label=prompt_label, lines=1, interactive=True, visible=True)

    with gr.Accordion("Advanced Options", open=False):
        # One radio group per photography modifier, created in display order
        # (which is also the order proceed_with_generation expects).
        (shot_radio, style_radio, lighting_radio,
         context_radio, lens_radio, device_radio) = (
            gr.Radio(choices, label=label)
            for choices, label in (
                (shot_services, shot_label),
                (style_services, style_label),
                (lighting_services, lighting_label),
                (context_services, context_label),
                (lens_services, lens_label),
                (device_services, device_label),
            )
        )

    button = gr.Button(value=button_label, visible=True)

    with gr.Row():
        output_gallery = gr.Gallery(label=output_label, visible=True)

    # Argument order must match proceed_with_generation's signature.
    generation_inputs = [
        input_file, extract_text, prompt_text,
        shot_radio, style_radio, lighting_radio,
        context_radio, lens_radio, device_radio,
    ]
    button.click(proceed_with_generation, generation_inputs, [output_gallery])

demo.launch(debug=True)