from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from PIL import Image
import requests
import cv2
import torch
import matplotlib.pyplot as plt
import io
import os
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr
import gc

# CLIPSeg: text-prompted segmentation, used to build the inpainting mask.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

IPmodel_path = "runwayml/stable-diffusion-inpainting"
IPpipe = StableDiffusionInpaintPipeline.from_pretrained(
    IPmodel_path,
    use_auth_token=st.secrets["AUTH_TOKEN"],
).to(device)
# Attention slicing is a pipeline setting (not a method of the output images);
# enable it once here to lower peak memory during inpainting.
IPpipe.enable_attention_slicing()

# NOTE(review): the translation model is loaded but not used anywhere in this file.
trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

SDpipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    use_auth_token=st.secrets["AUTH_TOKEN"],
).to(device)


def create_mask(image, prompt):
    """Build an inpainting mask for the region of ``image`` described by ``prompt``.

    CLIPSeg logits are passed through a sigmoid and saved as a colour-mapped
    PNG ("mask.png"). The PNG is re-read as grayscale, binarised at 100 and
    inverted. Pixels matching ``prompt`` therefore end up black and everything
    else white. The inpainting pipeline preserves black pixels and repaints
    white ones.

    Args:
        image: PIL image to segment.
        prompt: text describing the object to keep.

    Returns:
        A single-channel PIL image with the mask (also written to "mask.png").
    """
    inputs = processor(
        text=[prompt], images=[image], padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits

    filename = "mask.png"
    # assumes logits are 2-D for a single image/prompt pair — TODO confirm.
    # plt.imsave applies matplotlib's default colormap, which the grayscale
    # re-read + threshold below is tuned against.
    plt.imsave(filename, torch.sigmoid(preds).cpu().numpy())
    gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)
    _, bw_image = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
    # (The former BGR2RGB conversion on this single-channel image raised an
    # OpenCV error and its result was discarded, so it has been removed.)
    mask = cv2.bitwise_not(bw_image)
    cv2.imwrite(filename, mask)
    # Build the PIL image from memory rather than lazily re-opening the file,
    # which a concurrent request could overwrite.
    return Image.fromarray(mask)


def generate_image(image, product_name, target_name):
    """Keep ``product_name`` from ``image`` and repaint the rest from ``target_name``.

    Args:
        image: PIL source image.
        product_name: text prompt for the object to keep (segmented by CLIPSeg).
        target_name: text prompt describing the new surroundings.

    Returns:
        A list of ``num_samples`` generated PIL images.
    """
    mask = create_mask(image, product_name)
    image = image.resize((512, 512))
    mask = mask.resize((512, 512))

    guidance_scale = 8
    # guidance_scale = 16
    num_samples = 4
    generator = torch.Generator(device=device).manual_seed(22)  # change the seed to get different results

    try:
        images = IPpipe(
            prompt=target_name,
            image=image,
            mask_image=mask,
            guidance_scale=guidance_scale,
            generator=generator,
            num_images_per_prompt=num_samples,
        ).images
    finally:
        # Release memory after every generation, including failed ones.
        gc.collect()
        torch.cuda.empty_cache()
    return images


image_label = 'Please upload the image (optional)'
extract_label = 'Specify what needs to be extracted from the above image'
prompt_label = 'Specify the description of image to be generated'
button_label = "Proceed"
output_label = "Results"

shot_services = ['close-up', 'extreme-closeup', 'POV', 'medium', 'long']
shot_label = 'Choose the shot type'

style_services = ['polaroid', 'monochrome', 'long exposure', 'color splash', 'Tilt shift']
style_label = 'Choose the style type'

lighting_services = ['soft', 'ambivalent', 'ring', 'sun', 'cinematic']
lighting_label = 'Choose the lighting type'

context_services = ['indoor', 'outdoor', 'at night', 'in the park', 'in the beach', 'studio']
context_label = 'Choose the context'

lens_services = ['wide angle', 'telephoto', '24 mm', 'EF 70mm', 'Bokeh']
lens_label = 'Choose the lens type'

device_services = ['iphone', 'CCTV', 'Nikon ZFX', 'Canon', 'Gopro']
device_label = 'Choose the device type'


def add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio,
                  context_radio, lens_radio, device_radio):
    """Append each selected modifier to ``prompt_text``, separated by ", ".

    Unselected modifiers are skipped. Gradio passes ``None`` for an unselected
    Radio, and ``''`` is also treated as unselected. Concatenating ``None``
    previously raised ``TypeError``.
    """
    modifiers = (shot_radio, style_radio, lighting_radio,
                 context_radio, lens_radio, device_radio)
    return ", ".join([prompt_text, *(m for m in modifiers if m)])


def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio,
                            style_radio, lighting_radio, context_radio,
                            lens_radio, device_radio):
    """Gradio callback: run guided inpainting or free text-to-image generation.

    If an image is uploaded and ``extract_text`` is given, the described object
    is kept and the rest repainted from the prompt. Otherwise a fresh image is
    generated from the prompt alone.

    Args:
        input_file: uploaded image as a numpy array, or ``None`` when empty.
        extract_text: description of the object to keep.
        prompt_text: description of the image to generate (required).
        *_radio: optional prompt modifiers; ``None`` when unselected.

    Returns:
        A list of generated PIL images for the gallery.

    Raises:
        gr.Error: if no prompt was supplied.
    """
    if not (prompt_text and prompt_text.strip()):
        raise gr.Error("Please fill all details for guided image or atleast prompt for free image rendition !")

    prompt = add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio,
                           context_radio, lens_radio, device_radio)
    print(prompt)

    # An empty gr.Image yields None, not ""; comparing a numpy array to ""
    # does not test emptiness.
    if input_file is not None and extract_text and extract_text.strip():
        print(extract_text)
        return generate_image(Image.fromarray(input_file), extract_text, prompt)

    output = SDpipe(prompt, height=512, width=512, num_images_per_prompt=4)
    return output.images


with gr.Blocks() as demo:
    with gr.Row():
        # Visible so the guided (inpainting) path is reachable; the label
        # already invites an optional upload.
        input_file = gr.Image(interactive=True, label=image_label, visible=True, shape=(512, 512))
        extract_text = gr.Textbox(label=extract_label, lines=1, interactive=True, visible=True)
        prompt_text = gr.Textbox(label=prompt_label, lines=1, interactive=True, visible=True)
    with gr.Accordion("Advanced Options", open=False):
        shot_radio = gr.Radio(shot_services, label=shot_label)
        style_radio = gr.Radio(style_services, label=style_label)
        lighting_radio = gr.Radio(lighting_services, label=lighting_label)
        context_radio = gr.Radio(context_services, label=context_label)
        lens_radio = gr.Radio(lens_services, label=lens_label)
        device_radio = gr.Radio(device_services, label=device_label)

    button = gr.Button(value=button_label, visible=True)
    with gr.Row():
        output_gallery = gr.Gallery(label=output_label, visible=True)

    button.click(
        proceed_with_generation,
        [input_file, extract_text, prompt_text, shot_radio, style_radio,
         lighting_radio, context_radio, lens_radio, device_radio],
        [output_gallery],
    )

demo.launch(debug=True)