from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation from diffusers import StableDiffusionInpaintPipeline,StableDiffusionPipeline from PIL import Image import requests import cv2 import torch import matplotlib.pyplot as plt import io import requests import os import streamlit as st from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') IPmodel_path = "runwayml/stable-diffusion-inpainting" IPpipe = StableDiffusionInpaintPipeline.from_pretrained( IPmodel_path, use_auth_token= st.secrets["AUTH_TOKEN"] ).to(device) trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M") trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M") SDpipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=st.secrets["AUTH_TOKEN"]).to(device) def create_mask(image, prompt): inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt") # predict with torch.no_grad(): outputs = model(**inputs) preds = outputs.logits filename = f"mask.png" plt.imsave(filename,torch.sigmoid(preds)) gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY) (thresh, bw_image) = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY) # For debugging only: # cv2.imwrite(filename,bw_image) # fix color format cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB) mask = cv2.bitwise_not(bw_image) cv2.imwrite(filename, mask) return Image.open('mask.png') def generate_image(image, product_name, target_name): mask = create_mask(image, product_name) image = image.resize((512, 512)) mask = mask.resize((512,512)) guidance_scale=8 #guidance_scale=16 num_samples = 4 prompt = target_name generator = torch.Generator(device=device).manual_seed(22) # change the seed to get different results im = IPpipe( prompt=prompt, image=image, mask_image=mask, guidance_scale=guidance_scale, generator=generator, ).images return im def translate_sentence(article, source, target): if target == 'eng_Latn': return article translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=source, tgt_lang=target) output = translator(article, max_length=400) output = output[0]['translation_text'] return output codes_as_string = '''Modern Standard Arabic arb_Arab Danish dan_Latn German deu_Latn Greek ell_Grek English eng_Latn Estonian est_Latn Finnish fin_Latn French fra_Latn Hebrew heb_Hebr Hindi hin_Deva Croatian hrv_Latn Hungarian hun_Latn Indonesian ind_Latn Icelandic isl_Latn Italian ita_Latn Japanese jpn_Jpan Korean kor_Hang Luxembourgish ltz_Latn Macedonian mkd_Cyrl Maltese mlt_Latn Dutch nld_Latn Norwegian Bokmål nob_Latn Polish pol_Latn Portuguese por_Latn Russian rus_Cyrl Slovak slk_Latn Slovenian slv_Latn Spanish spa_Latn Serbian srp_Cyrl Swedish swe_Latn Thai tha_Thai Turkish tur_Latn Ukrainian ukr_Cyrl Vietnamese vie_Latn Chinese (Simplified) zho_Hans''' codes_as_string = codes_as_string.split('\n') flores_codes = {} for code in codes_as_string: lang, lang_code = code.split('\t') flores_codes[lang] = lang_code import gradio as gr import gc gc.collect() image_label = 'Please upload the image (optional)' extract_label = 'Specify what need to be extracted from the above image' prompt_label = 'Specify the description of image to be generated' button_label = "Proceed" output_label = "Generations" shot_services = ['close-up', 'extreme-closeup', 'POV','medium', 'long'] shot_label = 'Choose the shot type' style_services = ['polaroid', 'monochrome', 'long exposure','color splash', 'Tilt shift'] style_label = 'Choose the style type' lighting_services = ['soft', 'ambivalent', 'ring','sun', 'cinematic'] lighting_label = 'Choose the lighting type' context_services = ['indoor', 'outdoor', 'at night','in the park', 'in the beach','studio'] context_label = 'Choose the context' lens_services = ['wide angle', 'telephoto', '24 mm','EF 70mm', 'Bokeh'] lens_label = 'Choose the lens type' device_services = ['iphone', 'CCTV', 'Nikon ZFX','Canon', 'Gopro'] device_label = 'Choose the device type' def change_lang(choice): global lang_choice lang_choice = choice new_image_label = translate_sentence(image_label, "english", choice) return [gr.update(visible=True, label=translate_sentence(image_label, flores_codes["English"],flores_codes[choice])), gr.update(visible=True, label=translate_sentence(extract_label, flores_codes["English"],flores_codes[choice])), gr.update(visible=True, label=translate_sentence(prompt_label, flores_codes["English"],flores_codes[choice])), gr.update(visible=True, value=translate_sentence(button_label, flores_codes["English"],flores_codes[choice])), gr.update(visible=True, label=translate_sentence(button_label, flores_codes["English"],flores_codes[choice])), ] def add_to_prompt(prompt_text,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio ): if shot_radio != '': prompt_text += ","+shot_radio if style_radio != '': prompt_text += ","+style_radio if lighting_radio != '': prompt_text += ","+lighting_radio if context_radio != '': prompt_text += ","+ context_radio if lens_radio != '': prompt_text += ","+ lens_radio if device_radio != '': prompt_text += ","+ device_radio return prompt_text def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio): if extract_text == "" or input_file == "": translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"]) translated_prompt = add_to_prompt(translated_prompt,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio) print(translated_prompt) output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4) return output.images elif extract_text != "" and input_file != "" and prompt_text !='': translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"]) translated_prompt = add_to_prompt(translated_prompt,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio) print(translated_prompt) translated_extract = translate_sentence(extract_text, flores_codes[lang_choice], flores_codes["English"]) print(translated_extract) output = generate_image(Image.fromarray(input_file), translated_extract, translated_prompt) return output else: raise gr.Error("Please fill all details for guided image or atleast promt for free image rendition !") with gr.Blocks() as demo: lang_option = gr.Dropdown(list(flores_codes.keys()), default='English', label='Please Select your Language') with gr.Row(): input_file = gr.Image(interactive = True, label=image_label, visible=False, shape=(512,512)) extract_text = gr.Textbox(label= extract_label, lines=1, interactive = True, visible = True) prompt_text = gr.Textbox(label= prompt_label, lines=1, interactive = True, visible = True) with gr.Accordion("Advanced Options", open=False): shot_radio = gr.Radio(shot_services , label=shot_label, ) style_radio = gr.Radio(style_services , label=style_label) lighting_radio = gr.Radio(lighting_services , label=lighting_label) context_radio = gr.Radio(context_services , label=context_label) lens_radio = gr.Radio(lens_services , label=lens_label) device_radio = gr.Radio(device_services , label=device_label) button = gr.Button(value = button_label , visible = False) with gr.Row(): output_gallery = gr.Gallery(label = output_label, visible= False) lang_option.change(fn=change_lang, inputs=lang_option, outputs=[input_file, extract_text, prompt_text, button, output_gallery]) button.click( proceed_with_generation, [input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio], [output_gallery]) demo.launch(debug=True)