# NOTE: this file was copied from a Hugging Face Space page ("Spaces:") whose
# status was shown as "Runtime error".
import io
import os
import tempfile

import cv2
import matplotlib.pyplot as plt
import requests
import requests
import streamlit as st
import torch
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from huggingface_hub import login
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# CLIPSeg: text-prompted segmentation, used by create_mask to build masks.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Half precision is only used on GPU: many float16 kernels are not
# implemented for CPU tensors, so the CPU fallback runs in float32.
model_dtype = torch.float16 if device.type == 'cuda' else torch.float32

# Stable Diffusion inpainting pipeline, used by generate_image.
IPmodel_path = "runwayml/stable-diffusion-inpainting"
IPpipe = StableDiffusionInpaintPipeline.from_pretrained(
    IPmodel_path,
    revision="fp16",
    torch_dtype=model_dtype,
    use_auth_token=st.secrets["AUTH_TOKEN"],
).to(device)

# NLLB translation model, used by translate_sentence.
trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# Plain text-to-image pipeline, used when no guiding image is given.
SDpipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=model_dtype,
    use_auth_token=st.secrets["AUTH_TOKEN"],
).to(device)
def create_mask(image, prompt, threshold=100):
    """Build a black/white inpainting mask for *image* with CLIPSeg.

    Pixels matching *prompt* become black (kept by the inpainting pipeline)
    and everything else becomes white (regenerated).

    Args:
        image: PIL image to segment.
        prompt: text describing the object to keep.
        threshold: gray level (0-255) of the rendered heat map above which a
            pixel is treated as part of the object. Defaults to 100, the
            previously hard-coded value.

    Returns:
        A single-channel PIL image at the resolution of the CLIPSeg logits;
        callers resize it as needed.
    """
    inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze() drops any singleton batch dimension so imsave gets a 2-D map.
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

    # Render into a private temporary directory rather than a fixed
    # "mask.png" in the working directory, so concurrent requests cannot
    # overwrite each other's mask.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "mask.png")
        # imsave colour-maps the 2-D array (scaled to its own min/max); the
        # grayscale of that rendering is what gets thresholded below.
        plt.imsave(filename, probs)
        gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)

    _, bw_image = cv2.threshold(gray_image, threshold, 255, cv2.THRESH_BINARY)
    # Invert: the detected object becomes black (preserved), the rest white.
    mask = cv2.bitwise_not(bw_image)
    return Image.fromarray(mask)
def generate_image(image, product_name, target_name, num_samples=4, guidance_scale=8, seed=22):
    """Keep *product_name* from *image* and regenerate the rest from *target_name*.

    Args:
        image: PIL image containing the product.
        product_name: text prompt for CLIPSeg describing what to keep.
        target_name: text prompt for the inpainting pipeline.
        num_samples: number of images to generate (default 4, the value the
            original code set but never passed to the pipeline).
        guidance_scale: classifier-free guidance strength (default 8).
        seed: RNG seed; a fixed seed makes results reproducible, so change it
            to get different results.

    Returns:
        The list of generated PIL images from IPpipe.
    """
    # Normalize to 3-channel RGB (e.g. drop an alpha channel) before
    # segmentation and inpainting.
    image = image.convert("RGB")
    mask = create_mask(image, product_name)
    # Image and mask are both brought to 512x512 so they line up.
    image = image.resize((512, 512))
    mask = mask.resize((512, 512))
    generator = torch.Generator(device=device).manual_seed(seed)
    return IPpipe(
        prompt=target_name,
        image=image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
    ).images
def translate_sentence(article, source, target):
    """Translate *article* from FLORES-200 code *source* to *target* with NLLB.

    Args:
        article: text to translate.
        source: FLORES-200 code of the input language (e.g. "fra_Latn").
        target: FLORES-200 code of the output language (e.g. "eng_Latn").

    Returns:
        The translated text, or *article* unchanged when it is empty or the
        source and target languages are the same.
    """
    # Skip only when there is nothing to translate. Checking the target alone
    # is not enough: user prompts arrive in the UI language and must still be
    # translated *into* English.
    if not article or source == target:
        return article
    translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(article, max_length=400)
    return output[0]['translation_text']
# Supported UI languages: display name followed by its FLORES-200 code (the
# language codes accepted by the NLLB translation model).
codes_as_string = '''Modern Standard Arabic arb_Arab
Danish dan_Latn
German deu_Latn
Greek ell_Grek
English eng_Latn
Estonian est_Latn
Finnish fin_Latn
French fra_Latn
Hebrew heb_Hebr
Hindi hin_Deva
Croatian hrv_Latn
Hungarian hun_Latn
Indonesian ind_Latn
Icelandic isl_Latn
Italian ita_Latn
Japanese jpn_Jpan
Korean kor_Hang
Luxembourgish ltz_Latn
Macedonian mkd_Cyrl
Maltese mlt_Latn
Dutch nld_Latn
Norwegian Bokmål nob_Latn
Polish pol_Latn
Portuguese por_Latn
Russian rus_Cyrl
Slovak slk_Latn
Slovenian slv_Latn
Spanish spa_Latn
Serbian srp_Cyrl
Swedish swe_Latn
Thai tha_Thai
Turkish tur_Latn
Ukrainian ukr_Cyrl
Vietnamese vie_Latn
Chinese (Simplified) zho_Hans'''
codes_as_string = codes_as_string.splitlines()
# Split on the LAST whitespace run only: display names may contain spaces
# ("Norwegian Bokmål", "Chinese (Simplified)") and the separator may be a tab
# or spaces. Blank lines are ignored.
flores_codes = {
    name: code
    for name, code in (line.rsplit(maxsplit=1) for line in codes_as_string if line.strip())
}
import gradio as gr | |
import gc | |
# Reclaim unreferenced objects (e.g. temporaries from model loading) before
# the UI is built.
gc.collect()

# English UI texts; change_lang translates them into the selected language.
image_label = 'Please upload the image (optional)'
extract_label = 'Specify what needs to be extracted from the above image'
prompt_label = 'Specify the description of image to be generated'
button_label = "Proceed"
output_label = "Generations"

# Advanced options: each selected value is appended to the prompt by
# add_to_prompt.
shot_services = ['close-up', 'extreme-closeup', 'POV', 'medium', 'long']
shot_label = 'Choose the shot type'
style_services = ['polaroid', 'monochrome', 'long exposure', 'color splash', 'Tilt shift']
style_label = 'Choose the style type'
# "ambient" (not "ambivalent") is the lighting term meant here.
lighting_services = ['soft', 'ambient', 'ring', 'sun', 'cinematic']
lighting_label = 'Choose the lighting type'
context_services = ['indoor', 'outdoor', 'at night', 'in the park', 'in the beach', 'studio']
context_label = 'Choose the context'
lens_services = ['wide angle', 'telephoto', '24 mm', 'EF 70mm', 'Bokeh']
lens_label = 'Choose the lens type'
device_services = ['iphone', 'CCTV', 'Nikon ZFX', 'Canon', 'Gopro']
device_label = 'Choose the device type'
def change_lang(choice):
    """Reveal the main widgets and relabel them in the selected language.

    Args:
        choice: display name of the selected language (a key of flores_codes).

    Returns:
        gr.update objects for, in order: the input image, extract textbox,
        prompt textbox, proceed button (its value is the button text) and
        output gallery.
    """
    global lang_choice
    # Remembered module-wide so proceed_with_generation knows which language
    # the user types in.
    # NOTE(review): this global is shared by all sessions of the app.
    lang_choice = choice
    source = flores_codes["English"]
    target = flores_codes[choice]

    def _localize(text):
        # UI texts are authored in English.
        return translate_sentence(text, source, target)

    return [
        gr.update(visible=True, label=_localize(image_label)),
        gr.update(visible=True, label=_localize(extract_label)),
        gr.update(visible=True, label=_localize(prompt_label)),
        gr.update(visible=True, value=_localize(button_label)),
        gr.update(visible=True, label=_localize(output_label)),
    ]
def add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Append the selected advanced options to *prompt_text*, comma-separated.

    Options are appended in parameter order. An option counts as unselected
    when it is falsy: "" or None (e.g. a gr.Radio nothing was picked in), so
    it is skipped instead of being concatenated.

    Returns:
        The extended prompt, e.g. "a cat,close-up,indoor".
    """
    options = (shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
    return ",".join([prompt_text, *(option for option in options if option)])
def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Button handler: run guided inpainting or free text-to-image generation.

    * Image, extract text and prompt all given: keep the extracted object from
      the uploaded image and regenerate the rest (generate_image).
    * Prompt given but image or extract text missing: plain text-to-image
      with SDpipe (4 images, 512x512).

    Text inputs are translated from the selected UI language into English,
    and the selected advanced options are appended to the prompt.

    Args:
        input_file: uploaded image as an array accepted by Image.fromarray,
            or None when nothing was uploaded.
        extract_text: what to keep from the image (may be empty).
        prompt_text: description of the image to generate.
        shot_radio ... device_radio: advanced options (see add_to_prompt).

    Returns:
        A list of generated PIL images.

    Raises:
        gr.Error: if the prompt is empty.
    """
    has_image = input_file is not None
    has_extract = bool(extract_text and extract_text.strip())
    has_prompt = bool(prompt_text and prompt_text.strip())
    if not has_prompt:
        raise gr.Error("Please fill all details for guided image or at least a prompt for free image rendition!")

    # lang_choice is set by change_lang; fall back to English if it never ran.
    source = flores_codes[globals().get("lang_choice", "English")]
    english = flores_codes["English"]

    translated_prompt = translate_sentence(prompt_text, source, english)
    translated_prompt = add_to_prompt(translated_prompt, shot_radio, style_radio, lighting_radio,
                                      context_radio, lens_radio, device_radio)
    print(translated_prompt)

    if has_image and has_extract:
        translated_extract = translate_sentence(extract_text, source, english)
        print(translated_extract)
        return generate_image(Image.fromarray(input_file), translated_extract, translated_prompt)

    output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4)
    return output.images
# Gradio UI. The image input, button and gallery start hidden; change_lang
# reveals them and translates their labels.
with gr.Blocks() as demo:
    # `value` sets the initially selected language.
    lang_option = gr.Dropdown(list(flores_codes.keys()), value='English', label='Please Select your Language')
    with gr.Row():
        input_file = gr.Image(interactive=True, label=image_label, visible=False, shape=(512, 512))
        extract_text = gr.Textbox(label=extract_label, lines=1, interactive=True, visible=True)
        prompt_text = gr.Textbox(label=prompt_label, lines=1, interactive=True, visible=True)
    with gr.Accordion("Advanced Options", open=False):
        shot_radio = gr.Radio(shot_services, label=shot_label)
        style_radio = gr.Radio(style_services, label=style_label)
        lighting_radio = gr.Radio(lighting_services, label=lighting_label)
        context_radio = gr.Radio(context_services, label=context_label)
        lens_radio = gr.Radio(lens_services, label=lens_label)
        device_radio = gr.Radio(device_services, label=device_label)
    button = gr.Button(value=button_label, visible=False)
    with gr.Row():
        output_gallery = gr.Gallery(label=output_label, visible=False)

    localized_outputs = [input_file, extract_text, prompt_text, button, output_gallery]
    lang_option.change(fn=change_lang, inputs=lang_option, outputs=localized_outputs)
    # .change does not fire for the preselected value, so run change_lang once
    # on page load; otherwise the hidden widgets never appear for users who
    # keep the default language.
    demo.load(fn=change_lang, inputs=lang_option, outputs=localized_outputs)
    button.click(
        proceed_with_generation,
        [input_file, extract_text, prompt_text, shot_radio, style_radio,
         lighting_radio, context_radio, lens_radio, device_radio],
        [output_gallery],
    )
demo.launch(debug=True)