|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
|
from diffusers import StableDiffusionInpaintPipeline,StableDiffusionPipeline |
|
from PIL import Image |
|
import requests |
|
|
|
import cv2 |
|
import torch |
|
import matplotlib.pyplot as plt |
|
|
|
import io |
|
import requests |
|
from huggingface_hub import login |
|
|
|
import os |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
|
|
|
|
|
|
# Text-prompted segmentation model (CLIPSeg). create_mask() runs these on an
# image plus a text prompt to locate the object to keep.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
|
|
|
|
|
# Run the diffusion pipelines on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Half precision is only usable on CUDA; many float16 kernels are not
# implemented on CPU, so use float32 there instead of failing at inference time.
_torch_dtype = torch.float16 if device.type == 'cuda' else torch.float32

# Hugging Face access token, read once from the Streamlit secrets store and
# shared by both Stable Diffusion pipelines.
_auth_token = st.secrets["AUTH_TOKEN"]

# Inpainting pipeline used by generate_image() for guided generations.
IPmodel_path = "runwayml/stable-diffusion-inpainting"
IPpipe = StableDiffusionInpaintPipeline.from_pretrained(
    IPmodel_path,
    revision="fp16",
    torch_dtype=_torch_dtype,
    use_auth_token=_auth_token,
).to(device)

# NLLB-200 translation model used by translate_sentence().
trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# Text-to-image pipeline for "free" generations without a reference image.
SDmodel_path = "runwayml/stable-diffusion-v1-5"
SDpipe = StableDiffusionPipeline.from_pretrained(
    SDmodel_path,
    revision="fp16",
    torch_dtype=_torch_dtype,
    use_auth_token=_auth_token,
).to(device)
|
|
|
|
|
def create_mask(image, prompt):
    """Build an inverted black/white inpainting mask for the object named by ``prompt``.

    CLIPSeg scores ``image`` against ``prompt``. The sigmoid of those scores is
    rendered to ``mask.png`` with matplotlib's default colormap. The image is
    then converted to grayscale, binarised at 100 and inverted, so pixels that
    score high against the prompt end up black and the rest end up white.

    Args:
        image: picture accepted by the CLIPSeg processor (a PIL image in this app).
        prompt: text describing the object to segment.

    Returns:
        A single-channel ``PIL.Image`` with the inverted binary mask. The same
        mask is also written to ``mask.png`` in the working directory.
    """
    inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # NOTE(review): assumes one prompt/one image yields a 2-D logits map,
    # which plt.imsave requires — confirm for the installed transformers version.
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()

    filename = "mask.png"
    plt.imsave(filename, probabilities)

    gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)
    _, bw_image = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
    # bw_image is single-channel; BGR<->RGB conversions require 3/4 channels
    # and would raise here, so it is used as-is.

    # Invert: the segmented object becomes black (kept by the inpainting
    # pipeline) and everything else white (repainted).
    mask = cv2.bitwise_not(bw_image)
    cv2.imwrite(filename, mask)

    # Build the result from the in-memory array. Image.open() is lazy, and the
    # shared file could be overwritten by another request before it is read.
    return Image.fromarray(mask)
|
|
|
|
|
|
|
|
|
def generate_image(image, product_name, target_name, guidance_scale=8, num_samples=4, seed=22):
    """Keep ``product_name`` from ``image`` and repaint the rest from ``target_name``.

    Args:
        image: source ``PIL.Image``.
        product_name: text naming the object to keep; it is segmented by create_mask().
        target_name: prompt describing what to paint around the kept object.
        guidance_scale: guidance strength passed to the inpainting pipeline.
        num_samples: number of images to generate for the prompt.
        seed: seed for the torch generator, so results are reproducible.

    Returns:
        The list of images produced by the inpainting pipeline.
    """
    mask = create_mask(image, product_name)

    # Resize image and mask to the same 512x512 size so they line up pixel-for-pixel.
    image = image.resize((512, 512))
    mask = mask.resize((512, 512))

    generator = torch.Generator(device=device).manual_seed(seed)

    return IPpipe(
        prompt=target_name,
        image=image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_samples,
        generator=generator,
    ).images
|
|
|
|
|
|
|
def translate_sentence(article, source, target):
    """Translate ``article`` from ``source`` to ``target`` with the NLLB model.

    ``source`` and ``target`` are FLORES-200 codes such as ``'eng_Latn'``
    (values of ``flores_codes``).

    Returns:
        The translated text. ``article`` is returned unchanged when it is empty
        or when source and target are the same language.
    """
    # Skip the model only when no translation is needed. Every other pair,
    # including <user language> -> English for prompts, must go through NLLB.
    if not article or source == target:
        return article

    translator = pipeline(
        'translation',
        model=trans_model,
        tokenizer=trans_tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    return translator(article, max_length=400)[0]['translation_text']
|
|
|
|
|
# Languages offered in the UI: one "<display name> <FLORES-200 code>" entry per
# line. The code is always the last whitespace-separated token, so display
# names may contain spaces ("Chinese (Simplified)").
codes_as_string = '''Modern Standard Arabic arb_Arab
Danish dan_Latn
German deu_Latn
Greek ell_Grek
English eng_Latn
Estonian est_Latn
Finnish fin_Latn
French fra_Latn
Hebrew heb_Hebr
Hindi hin_Deva
Croatian hrv_Latn
Hungarian hun_Latn
Indonesian ind_Latn
Icelandic isl_Latn
Italian ita_Latn
Japanese jpn_Jpan
Korean kor_Hang
Luxembourgish ltz_Latn
Macedonian mkd_Cyrl
Maltese mlt_Latn
Dutch nld_Latn
Norwegian Bokmål nob_Latn
Polish pol_Latn
Portuguese por_Latn
Russian rus_Cyrl
Slovak slk_Latn
Slovenian slv_Latn
Spanish spa_Latn
Serbian srp_Cyrl
Swedish swe_Latn
Thai tha_Thai
Turkish tur_Latn
Ukrainian ukr_Cyrl
Vietnamese vie_Latn
Chinese (Simplified) zho_Hans'''

codes_as_string = codes_as_string.splitlines()

# Map display name -> FLORES-200 code (insertion order drives the dropdown).
# Blank lines are skipped, and the separator may be a tab or spaces, so
# reformatting the table above cannot break module import.
flores_codes = {}
for entry in codes_as_string:
    entry = entry.strip()
    if not entry:
        continue
    lang, lang_code = entry.rsplit(maxsplit=1)
    flores_codes[lang] = lang_code
|
|
|
|
|
|
|
import gradio as gr |
|
import gc |
|
# Run a garbage-collection pass before building the UI.
gc.collect()

# Default (English) UI strings; change_lang() translates them into the
# language picked in the dropdown.
image_label = 'Please upload the image (optional)'
extract_label = 'Specify what needs to be extracted from the above image'
prompt_label = 'Specify the description of image to be generated'
button_label = "Proceed"
output_label = "Generations"

# Optional prompt modifiers shown under "Advanced Options". The selected value
# of each group is appended verbatim to the diffusion prompt by add_to_prompt(),
# so the entries must be meaningful prompt words.
shot_services = ['close-up', 'extreme-closeup', 'POV', 'medium', 'long']
shot_label = 'Choose the shot type'

style_services = ['polaroid', 'monochrome', 'long exposure', 'color splash', 'Tilt shift']
style_label = 'Choose the style type'

lighting_services = ['soft', 'ambient', 'ring', 'sun', 'cinematic']
lighting_label = 'Choose the lighting type'

context_services = ['indoor', 'outdoor', 'at night', 'in the park', 'in the beach', 'studio']
context_label = 'Choose the context'

lens_services = ['wide angle', 'telephoto', '24 mm', 'EF 70mm', 'Bokeh']
lens_label = 'Choose the lens type'

device_services = ['iphone', 'CCTV', 'Nikon ZFX', 'Canon', 'Gopro']
device_label = 'Choose the device type'
|
|
|
|
|
def change_lang(choice):
    """Switch the UI to language ``choice`` and reveal the input/output components.

    Stores the selection in the module-level ``lang_choice``, which
    proceed_with_generation() reads.

    Args:
        choice: display name of the language (a key of ``flores_codes``).

    Returns:
        One ``gr.update`` per component, in this order: image input, extract
        textbox, prompt textbox, button, gallery.
    """
    global lang_choice
    lang_choice = choice

    source = flores_codes["English"]
    target = flores_codes[choice]

    def _tr(text):
        # Translate one English UI string into the selected language.
        return translate_sentence(text, source, target)

    return [
        gr.update(visible=True, label=_tr(image_label)),
        gr.update(visible=True, label=_tr(extract_label)),
        gr.update(visible=True, label=_tr(prompt_label)),
        gr.update(visible=True, value=_tr(button_label)),
        # The gallery carries the "Generations" label, not the button text.
        gr.update(visible=True, label=_tr(output_label)),
    ]
|
|
|
def add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Append the selected advanced-option modifiers to ``prompt_text``.

    Modifiers are appended in the order shot, style, lighting, context, lens,
    device, each preceded by a comma. Both ``None`` and ``''`` are treated as
    "not selected": an unselected ``gr.Radio`` can deliver ``None``, and
    concatenating that onto the prompt would raise ``TypeError``.

    Returns:
        The extended prompt, or ``prompt_text`` unchanged if nothing is selected.
    """
    modifiers = [
        option
        for option in (shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
        if option
    ]
    if not modifiers:
        return prompt_text
    return prompt_text + "," + ",".join(modifiers)
|
|
|
def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Handle the "Proceed" button and return generated images for the gallery.

    Two modes:
      * guided: an image *and* an extract text are given. The extracted object
        is kept and the rest is repainted from the prompt via generate_image().
      * free: otherwise. Plain text-to-image with SDpipe (4 images, 512x512).

    The prompt (and the extract text, in guided mode) is translated from the
    selected UI language (``lang_choice``) to English. The chosen advanced
    options are then appended to the prompt.

    Raises:
        gr.Error: if the prompt is empty; it is required in both modes.
    """
    # gr.Image delivers None when nothing was uploaded. An ndarray must not be
    # compared with "" (ambiguous truth value / elementwise-comparison warning);
    # a "" sentinel is still treated as "no image".
    has_image = input_file is not None and not isinstance(input_file, str)

    if not prompt_text:
        raise gr.Error("Please fill all details for guided image or at least a prompt for free image rendition!")

    source = flores_codes[lang_choice]
    english = flores_codes["English"]

    translated_prompt = translate_sentence(prompt_text, source, english)
    translated_prompt = add_to_prompt(translated_prompt, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
    print(translated_prompt)

    if not (has_image and extract_text):
        output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4)
        return output.images

    translated_extract = translate_sentence(extract_text, source, english)
    print(translated_extract)
    return generate_image(Image.fromarray(input_file), translated_extract, translated_prompt)
|
|
|
|
|
|
|
# Gradio UI. The image input, Proceed button and gallery start hidden;
# change_lang() reveals them (with translated labels) once a language is picked.
with gr.Blocks() as demo:
    # NOTE(review): `default=` is the pre-3.0 Gradio keyword (newer versions use
    # `value=`) — confirm whether a preselected language is intended.
    lang_option = gr.Dropdown(list(flores_codes), default='English', label='Please Select your Language')

    with gr.Row():
        input_file = gr.Image(interactive=True, label=image_label, visible=False, shape=(512, 512))
        extract_text = gr.Textbox(label=extract_label, lines=1, interactive=True, visible=True)
        prompt_text = gr.Textbox(label=prompt_label, lines=1, interactive=True, visible=True)

    with gr.Accordion("Advanced Options", open=False):
        # One radio group per prompt modifier, created in display order.
        (shot_radio, style_radio, lighting_radio,
         context_radio, lens_radio, device_radio) = [
            gr.Radio(services, label=label)
            for services, label in (
                (shot_services, shot_label),
                (style_services, style_label),
                (lighting_services, lighting_label),
                (context_services, context_label),
                (lens_services, lens_label),
                (device_services, device_label),
            )
        ]

    button = gr.Button(value=button_label, visible=False)

    with gr.Row():
        output_gallery = gr.Gallery(label=output_label, visible=False)

    # Order must match the list of updates returned by change_lang().
    language_outputs = [input_file, extract_text, prompt_text, button, output_gallery]
    lang_option.change(fn=change_lang, inputs=lang_option, outputs=language_outputs)

    # Order must match proceed_with_generation()'s positional parameters.
    generation_inputs = [
        input_file, extract_text, prompt_text,
        shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio,
    ]
    button.click(proceed_with_generation, generation_inputs, [output_gallery])

demo.launch(debug=True)