Spaces:
Paused
Paused
import os | |
from collections import OrderedDict | |
import gradio as gr | |
import shutil | |
import uuid | |
import torch | |
from pathlib import Path | |
from lib.utils.iimage import IImage | |
from PIL import Image | |
from lib import models | |
from lib.methods import rasg, sd, sr | |
from lib.utils import poisson_blend, image_from_url_text | |
# Scratch directory for Gradio's temporary files; it is wiped and recreated on
# every start so no stale files from a previous run survive.
TMP_DIR = 'gradio_tmp'
_tmp_path = Path(TMP_DIR)
if _tmp_path.exists():
    shutil.rmtree(TMP_DIR)
_tmp_path.mkdir(parents=True, exist_ok=True)
# Tell Gradio to use the directory prepared above for its temp files.
os.environ['GRADIO_TEMP_DIR'] = TMP_DIR
# True only when this app runs inside the PAIR-owned Hugging Face Space.
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

# Default text for the "Negative prompt" box: artifacts to steer away from.
negative_prompt_str = (
    "text, bad anatomy, bad proportions, blurry, cropped, deformed, "
    "disfigured, duplicate, error, extra limbs, gross proportions, "
    "jpeg artifacts, long neck, low quality, lowres, malformed, morbid, "
    "mutated, mutilated, out of frame, ugly, worst quality"
)

# Default text for the "Positive prompt" box, appended to the user's prompt.
positive_prompt_str = "Full HD, 4K, high quality, high resolution"
# [image path, prompt] rows for gr.Examples. Each row must be a list because
# the UI appends a preview column to it with list concatenation.
example_inputs = [
    list(pair)
    for pair in (
        ('assets/examples/images/a40.jpg', 'medieval castle'),
        ('assets/examples/images/a4.jpg', 'parrot'),
        ('assets/examples/images/a65.jpg', 'hoodie'),
        ('assets/examples/images/a54.jpg', 'salad'),
        ('assets/examples/images/a51.jpg', 'space helmet'),
        ('assets/examples/images/a46.jpg', 'stack of books'),
        ('assets/examples/images/a19.jpg', 'antique greek vase'),
        ('assets/examples/images/a2.jpg', 'sunglasses'),
    )
]
# Remotely hosted preview images for the example gallery. thumbnails[i] is
# paired with example_inputs[i] through example_previews[i] in the gr.Examples rows.
thumbnails = [
    'https://lh3.googleusercontent.com/pw/ABLVV87iCMYxrA59EAiPc3eRTA59jGF2_sJSnwBSilXwx2wtVlEhCsla2CNhCte4oHuDc4Gix4r7QAlbyRrYkHHgT8v6SOhpElvzlS-Ve91lYUEsWeiOf5yix3b47zuW8gEq_n5qBAGocyMpnc0In52mvl3CYIYcXonJx3dXmFvhhSOegluIcnq33t70mkA75YNMaFz3ovDyq9AfCIFEXGqI6cPUSy9n-fvoAUwG5VZZONdWQPpAClk0Im981hvHuLHz3viAHTbrHL4kzEJNclgXFHGUfEB-IU-cUmCVWnXPEaBFnlu_YabcgFdDAfz-bOy0W64TuUAYDzY_X8mv2fKXMoO0KOnMGZL3j0x08rK8kFzs9opczNFwmRjDQJySNVC16upa2HqYBk3S_O-uhC6DPmfm-5RBKXaNrz_MDFKc7n4-DYEBGXz3drbC0U1AIhqvRf5i99QzQWaxccU-maHRJcKSXAndCW0tWxZpWiYx-PFlcUvKM8ls7ySk7JIFFz9gypjubtP2m205uoDwAd0qoXdxgjbFJB5S33PDXVA4qWh9VCs4QLM23H1hoVKzMxxsSXBRkI0-gQK1K6epmcNaEecS61MeUE-sSy24gAKDP00vIx7_RWCqDTMW9jJNm9cQrI7ZCZmSExczudzjb1zXu9r1B1NywGXfboI5cRw_YNZCjPxEHNTXkfV73bApnLnXBDLOSczfi7S4lurLBQh42gbdoMO9SfR0Ssaf5I15D8MGGvxAU9ol6BKxvxIc7PkCIBfs_-codk53u4Rnmt0KrOklYMDOpEBFC_yb_1tlmCkf8QC3gWgNTwIOHU9GDc3Zg6-_-Cu-1FpAeaZOUU4P-HJvC1xSmnMZ0jEc_ZDZZxeuRxk5BccWgt9A-qhvQMxGY3jqcyUdD3taaqzJ_JfsaiBHTk1mrW7IpVJsfVMeNcMZCoV2mI7t8AHh74L2WyTIJFRAOZFPdNLD=w3580-h1150-s-no-gm?authuser=2',
    'https://lh3.googleusercontent.com/pw/ABLVV84o8jPd39N9XwRLiEwh3G7G8EzTwkoC1dOlL_DapKteXwaX8TLGfkbrC9r4oawjuxf9NiPHnsgmJpWEIDxnQZVzniGw06ylaCDGjBiiNq0z3fEyWHC7jeugaWj98Aqv9FdRm6KFznDeKi0crCxh6QsPbDj8VAFMfoYvpqeVTogWKeA8anzzX5H9zqbvC1H5XPqQFq5DCIPgXN33eZuR7e6yuRqMEgRB9wcaH6Wuv1VUU3HVtsgWzaIC0JPsgTcA2gZA_tikypUWZc32Rtg9rJxmp_ONcOLW56C3AxnsYU_QVohX0Ojol_MsUWciG2jFpRLtVPWEmcwriMxMIsIWwaReY7jUUoMOaYK2AnIMrnNH8_RTvQvBtHfysXmV4BoyyOV4waZeK_gEYO32wlHPztAF_urY-Mhcw__SlVPxHZ32fZi8W8cc2fX6oTCOMSiZ_lTw1kiIRg-K7ZLYbsJlkuC7Al4irAta1bbEBNDVdW2RVRrPcn6PYhmx6LA8kU2NOwDy7j2_zk8_ZWB1VtCGdmU0dYjXmzvPXa5nblCAwogecMeIvWrQE5M6fvxpIWZUnB-77JCOTiAeYWdYrcS5I4SXdKo4Zh6ZnUAAUmoNSQI9XM94s0wcpGTFOfKD8vdOthTyXezrZ72yigVEjIwxkbtU-ktz6p3Tw5qq3bTQPdZchu1YjmrRBY3_TEeKdkjV3m-Kr8gSrL3eIhRtzGv1h955nFdUakHXJwE6KiRrFaI6vD-zdroFrc9O2ufaqo1cXFCueth2hJmqqeQE4mMLFbrPyMS4reYWJ0NZ0mRjn4ykdDVUOEPybDfG6H_j0pLhrTj6qdCCaUjfZ4YAz-THcjUcJUI64-RKO-virtNkk_xGr4_JDbgnnWDTw8t7HX-mcejT_nMR5mqdMm0b9yeSpPCmmKZLGa4cBTaG_S0MfzJYaM3dn-1EgwzC8bBI-wbM2NLMKef3V7Xi=w3580-h1152-s-no-gm?authuser=2',
    'https://lh3.googleusercontent.com/pw/ABLVV84hMfPeQn5fuVQ1sHUn2adaCduf8MqFPfUE33e7ZI68paIQOqun2_FcIaKAv6KzKsQDHd0gcPYLui7PwKUwaU1Myyfsk5zmhNLOlKu2pxNUdRI1_xWDXe0OkgIGgSVTnxJcAYBsRcPLQc6l-r6w1Tyh37-t5EDw6fYAbtN283ttKN6rPikByccpL6u14TuQLLnxS7KP1qX1BC1uKdnbSX2YMxnB2oPyDXzrSNyuwnqilCEMOdEOv1C4s8piaqLNpfTU-w4yz769Zmt0RYtcoU_RupWFd3Wv6dvrQqZ4pvNkhmKuKMv0vdhSxuCgQV_upFhV2XvNgyNmiXgu9TQZGabb83iRfUVAJ7TyxHic-noQbjfP7NqpEDyDRGDJWtaVOurARj67NuZN3up0MN1NByanM9xnXrnCH84ptfrz_oq9Z9EHAEHbssRkNRqOjMMWjrvECyKI4uoKLGMRdTTXz9znD24odxFOthA9Zk0kGGzv09_ghe8OU9cLjV7zyCkukQsVizbKy6rLfR8v-MfqiVywb_DSSh9ta0xwwgHNNCF9PplOxOpPsDqwIEMGXQPnMse7wWuW3h2m5oYKMwnv4cEwZsOc2-qSYMlscLtcuUlYEDiY0lBwpBigRHuAraYEXhBLaQ3RXGznc5loOfUvKeM2Cx0yDqz-A9vHXgcpLHOFB2duL7P9zH7WUFu3L6hnHONh1CttideaiFFJ3vguCFGj6RhUhaIBxJSJfuEqHMXaYOH2sgqk2GthQPnpCXirWAf6rOdvSgwyLEsJIw0XiM1EfcKAhPqAa32SxWq9V56OLWoRPZRbKWnZnGS3DhkKgeqRwF39FlDUscrBxKwBBTFC59sbIBoUGfUvsuxIwsfOjo8GRFQibEchI6tIZSPIw70ANNkIQBRrPHgVU18-fCR44Ls2FoogJKyQCz76-Fz0ox1i-I4irmZ1Lle0m5CTPmImFLXx-KyU=w2924-h1858-s-no-gm?authuser=2',
    'https://lh3.googleusercontent.com/pw/ABLVV85OaqBvRCxSxZSKl4eI87bBzXkM0_QgTJ1c4cwviYwxRL7WJ3a9q8yGxUCC3pKUFo13zA9oSEvAGQXPp-hKNfIAaNSWt3oBMrmfNb6RNDc9ER1vMttovSubw0rsr7YkSqztAfb_SB5aAM-VupYpNM2Zqc8sE2t5MREv-WDR9Z61DuIwVUSYbkaOldzobJ33eAqpsq6jRSNjcYbTY_lt7ngi2I9S8li3D8STvj8TE_NJ7ewuUwrWoNu9FVdiJ8zO_2faH9VFXFxMjnPmj8bc9g26HVYitlXPmso8am_lBJNLQsEfmcBIfzcavmOKsj7tKRtOxls_4x1ApQcE7RavHWITNj7eD4aSlHcoJgh-IwdBuocckOyqvrBVhlTuh7vX02j94VsAaVbNhZHsXBI-Lb9_JFnsPeTVyIM4Nc47LFFY3sLgXtzT6P1ydTsRjMw_XlkBhgPNMNhtt9PgJ_7sjxeZIeLGrp1ESZjAsyto5p8M0DkGX6aOxvAXvcK-q8MX3wuf0C-WpzDet0rUO0KhMHyVQpSzVsviYG-XV01kUabsAP6DCzb--brXxgj81lydeLq5HqMzaN13VZsKSIIxDxFwurNPBfesykul195qnRj58onbOUr4kgFPiodB_iX64DrBTR-tt9oXB3Fq6jVmn4FXhnH0g2bzChAADghDCGHcNHc2jUssVr-L59n9Bjz4StDwpL92PpXq-bmxlNvusr-VIs_197p1KqDjQSsBJa3IFSJ3wBv0HhbSZv292ok1Pu0ypv0HoD9FIQjnP6dsk5zzX0AhRfbD8y03lCkvH27Q55LT3dOClw1azAOsWefhuEd0O7Uad57TKA6T-UVaDfPvx4T5EXk9agtDakwPzBAd1xt7cldQjkUvKSLkahmFMWP9cMqidG6bI7v97Bl0y2kb6hKSKNNX7F3aKRF6XUppHbOcEMCO65FJOyuDEug5FUMI7CHh_H4j=w2622-h1858-s-no-gm?authuser=2',
    'https://lh3.googleusercontent.com/pw/ABLVV879DjJZRPXLj_pDl3X61Q2NC0UpNLnvCw7ME5ooTujvQG3PHlNijyLHk3lJ9Su4p0ejZ7Q7mLl_w7kil2cnUfy4iBCjZFLEcRjjIoZOdxFWayy2_MX9nN8frnSVVcLh_1GE-Rt98AAWaWiEZGx1cbUCAv4y3Y9SXxAIe_DVfMtB1sWzO_dwcMd5ybTaAZ0pXRnRtP9cIrZIvusYRFLWAX-WcwCyOtNAlYvh1X14Dyd3OIs0zut9Z4H38xETX-KdZhgQOvx4XBGN0n2WNlLIftoIm7-VzUgAMdi9UTbN7emZtImq6YKP0BIW2QFPeZqvMtIBVebHa7NAM6Z-V2plDlsj3FUAsdzbwjqzJ5vkO-cHVCLH-a_Vu2IKEt4zzMcIxOp9d7WkCIXyLq_e_aqvSDLWqov-w3F4-EfYadyOiBU0DVS55kz082ZYrs5tnxyc9pP9Skw0e4M82zj_eWOjWCx-Hg9mG8wIsQ_AnVYzPH9hkGSaTW7TNVuTz3gmjcmYZ4poRv7vsgMjfvt_pj6HOdCvW6LFYGLlEt6BaqyEr_Bh_5fUS9FvmFiAYfm4kEcJEFWfxgs8Kg3K5HGJIKx8z5JKYccjgIocpodoIaJLSUqN_Twb_ymt6_7craZLK4x_ISviYtinu0Q1XLdRHCL5LtTLahBSm_4a7RLvYp35cGBqACU-gW6itpLEnfVH749_jIrmH5RPkUiV12ScDbT6OCBo0o7k5Smu-gDY79x-3vFezYGhhjhFB85SHClbjN0ssmfphpwYf0nF_W2PZyns-yP32R5Qt_jNdi49tSPiji6ZWB4fSYe7ffRX-wVodgpEjE_fRR5Agk8-rLFNUGjjNT2givh94-OzQAYi9Dw5XPOFeaKIN-_k15-qc9aeu9oLXImb4_lPufrQN5TrRGnhFFEZAUJ4cGxdIZwl63Wk_wXX1ongG9UW-E5ptRGHQmijFdEuF26KSA4q=w2528-h1858-s-no-gm?authuser=2',
    'https://lh3.googleusercontent.com/pw/ABLVV84PM84Taj4ptJ0okpqAtv7v6Dcsi0v4tul0_iHAKwpNoKgQ9KulZOZfl34FUR_yNgcOv0SVliE1fB_rliNpMROY5LhZFO-mlIPC-ONaLPLBqhsK-f-rOcVWJLLoaAhrzSNpi4Q2MynkNo-iJvNeN-HzDx_oS4PYvsksqpzY3uqIU4JN50N8JN2k_YBbU4Ckyl4whazYuE06XOSl8qp5oRhWqEK5A1t_c-zF720MFemIi2EWUVPgU68PWqWzwBD2vCnUkN6lxw_qx9h5RnPjF4yHH2Gp3Ytk30psgdOELIxspBOE9egndaA4vnLwQ6-DrAwnyVLDmDSCXjlcj4sJEDKJvtZsJLmfLK2DOxgMvPI6y8Cz_lzL1xzBGgNQ9J9h83P7a4Ui-Uf7yIb-w09qR6kxzwIZpun1gxS2XOURDHkzRt-e5VZ4QAqKT9mXH8prqp9ZQszBTJAkhX7Q7XFlP612Wc0DP3Dzulyrpvc-b2V8gjPQiubGN9OtsI9GEcsJuzfrHnTNV8Mro30z5puvIlT-gCvqy9S95QtxmawK9W9OAdzkF2ZUmko4_mhwKgZCC6IcaSJA-hJtwB7zE6awttk10okMvcknoArJVLIA89rxmUwYVtNjMUOo8L0_gRzjnjwCMFvReeVEQUjbxlYkpzeL6B4ZuIv-xrdSjPp0aS-24L087pSt-IribOI79yibRYoutqSNbs9WeaatjCuciKzJ2NO50ET_zKUxsW0DkL0lNVxo11XyjRPxDl8rGVxyOLc9DmdZfMb_y4HtyFZ3lZdrNzzGRJHJ7Tyk6Jg4mDPmdqWMzCMbxAON-7G2g_ct3Mb2_rvItO4u9yG0JSFv49OckDccc3z3PsWZGRccDcmD0jzRKWM3mys-xBLOFIaC6j9ZzdJKmcjwdrTnWpEMJmDv04u1N1nuYGVZc8VOjzVGmygAjIbAk3wtU28Z2xXToyHycFMCIRj6=w3580-h1186-s-no-gm?authuser=2',
    'https://lh3.googleusercontent.com/pw/ABLVV85IDYIlqtfT35Vd49jfoyh4x_40iNwA-law-7cfQ1WFa2xmmJmYDlaG6upIn2Fmjw0tr7xcOdWSySJSh2evT9EfLL0aU1jhx9m2xz8SuxJUZZpxAS-FMtQcobIDQBvc7onMDipr_RjNoiwbiO6Smtdw7cpuiDlPnEC9EfwZH-ucEGLHBczojYX2AjjQiCgnNMuqqv-DMUDNYwY9CqjR-TIyE_0pSx-ciW65sHzOPgjtHMRT2Se30ebjoWsvHXCAYhBGMPD6w-UtUyMXL_Zu8lBec7AZ5mnY1emaIOpTbo9lzKbso4dapNQKQ82zTvk4m4DirP8qdzRc483BCCFP-weRaRn7eG-qI6jamb67CVgCFqXHvhf84lUX3jWhX0-aUiqW89aK-heViog9EhTL2o12o7w_65eIy4hkaFWWa4Ptdy22rb_u2EReW5AIboLE3QPUy9QpvMsTfEBzNMKPBCC8QuqxSFJSwv6JrBisFmCVzBZ-sDUKJAr5s8j5qaUf-Uh3M_xsn2ZhpzVubIJNsU7Ccwn-IIKMzMF4aR2LAa5E4YWsRK7rBab2Md8UQrUScbQf5xN8wCMIVRxgB6Z7JyrTje2wNHP5iYGQLksAmmp6o_kuDfbAkqYgR_mbPwG78bJXUJuBPTdcv7mm0iPsPX3Ufa6_RxAANtTMTfgLUZUtGarJ0CaP_NOlRM6NNZpe3k_-gnAXaWZ_klqdcfS3Bl-tY4DdC0dOwSxpUVD-Xeqt4u1JGRxCeZiDer21gpQVgNZOFjEjCg17YSRzN8BpdRP5XI2JfOfTElxmwk4jF-YtpArQkvGDCHEkpA9nuQvQzdA-jDJJaqUalmnacODylxns5WGEXdRzwqonq4ljkGregD3tCGD5KubOP2UdDm7NJ31rIZuB2Jod_J-3bCZowZKqCwMmYmGx26mbhSXsZtUcZEE9GDTNHRnix4E5a2PCf6yZ-7pX-vBi=w2846-h1858-s-no-gm?authuser=2',
    'https://lh3.googleusercontent.com/pw/ABLVV85R30x4UdJxYPlc6QpL4MFfEpp93B2ma-8GzkvGV9BwLWCcdmkb6fdI7DnvONanoi8qQSYH8p6Et4ZII_Q5tMSEd4luzNYN5xTMw8Tz9dUTh8FZewwInu9itgZOuHjSgIxA-9vGF27KZcVwTq2VyZ8_yQleLNe77drFj3GAcFwzBAGr23AU6J2btXs7yqK4nqii4RzWG2_Xc8xIHocKPD4G4ELTABiCiP8sGtDObhotVEUhOLGKIGwLDHGVYx70a2nV8JDdmopKW3f3K5CTmZwBHDgvHDoYYclakT8uqbVW8-LgHKsoy4LqCIGA725D0aaURINPI8GKwtMAvTO_DEcyVr9tg_47J7S9jWEALcosrggSJr9_6MSMd9JLW26UVIKg3-h7_SEel2maAFaFQxClNhCzpOFoesrfLC0-W7eRNT-nbUeP2IccFoEnd88ZjAeiWDcSuAydcEmlEQAfX_6MdKaDB3t8CVq_iDoI9ejx-9sS7LnXrQUqfypMdo8fsXnvAH6Cm4HsOPXcUf88z2I5uWs8WMyGyKwfGvrgOKUx1F3i43FHUIFlPfNCb4DcwJicCMzXsSttaFoViaKNObC8IjYNdElJ2EFeRQ6fwQ0QdvVjs-G6e5qojLczmQheiwnvwfwDYYxRvBMvIbKrrPtk7BbxmxWtZNEQTZ38FYVp0sJzZgSgdcscXuCGzxJOlHl5ezGuSoaveeNOfwUcnt1UdwbIY5k532csFKinxq3By4bdBqQv7Fi3VPj--UywcofES6oGKwqwcrO79xBg9FIhcboEQ588zQJjrvXrsbfPp6qv9m2HE9S3glUvjmq66F4z1SdJVtMj3OHsYdu4K3zdVKynrutGbrtuCaz5eVIgAkO0CTJ5O_Wkch3XmvQNlW6xf_cuU7kSPpPHDf1BdRvBPzpKNGb7d4wf2kEwUTRGLhQaj6qh-jImuQ3XeA2cfEapc9Hp93lr=w3580-h1382-s-no-gm?authuser=2'
]
# [image URL, caption] entries for the example preview gallery; entry i is
# shown alongside example_inputs[i] in the gr.Examples table.
# NOTE(review): caption 5 reads 'laptop' while example_inputs[5] prompts for
# 'stack of books' -- confirm which one thumbnails[5] actually depicts.
example_previews = [
    [url, caption]
    for url, caption in zip(
        thumbnails,
        (
            'Prompt: medieval castle',
            'Prompt: parrot',
            'Prompt: hoodie',
            'Prompt: salad',
            'Prompt: space helmet',
            'Prompt: laptop',
            'Prompt: antique greek vase',
            'Prompt: sunglasses',
        ),
    )
]
# Load models
# All inpainting backbones are loaded eagerly at start-up. The model dropdown
# lists them in insertion order and switching only rebinds `inp_model`.
inpainting_models = OrderedDict([
    ("Dreamshaper Inpainting V8", models.ds_inp.load_model()),
    ("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
    ("Stable-Inpainting 1.5", models.sd15_inp.load_model()),
])

# Put the super-resolution model on a second GPU when one exists. A hard-coded
# 'cuda:1' fails with "invalid device ordinal" on single-GPU hosts.
_sr_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
sr_model = models.sd2_sr.load_model(device=_sr_device)
sam_predictor = models.sam.load_model(device='cuda:0')

# The first registered inpainting model is active until the user picks another.
inp_model = next(iter(inpainting_models.values()))
def set_model_from_name(inp_model_name):
    """Make ``inp_model_name`` the active inpainting model.

    Rebinds the module-level ``inp_model`` global, which the inpainting
    runners pass to the sampler as ``ddim``.

    Args:
        inp_model_name: A key of ``inpainting_models`` (the dropdown value).

    Raises:
        KeyError: If ``inp_model_name`` is not one of the loaded models.
    """
    global inp_model
    print(f"Activating Inpainting Model: {inp_model_name}")
    inp_model = inpainting_models[inp_model_name]
def _inpaint_batch(run_fn, base_method, use_painta, prompt, image_and_mask, seed, eta,
                   negative_prompt, positive_prompt, ddim_steps, guidance_scale, batch_size):
    """Shared batch loop behind ``rasg_run`` and ``sd_run``.

    Args:
        run_fn: Sampler entry point (``rasg.run`` or ``sd.run``).
        base_method: First component of the '-'-joined method string
            ('rasg' or 'default'); '-painta' is appended when ``use_painta``.
        Remaining arguments: see ``rasg_run``.

    Returns:
        ``(blended_images, inpainted_images)``, one entry per batch item.
    """
    torch.cuda.empty_cache()
    seed = int(seed)
    batch_size = max(1, min(int(batch_size), 4))
    # Sliders may deliver floats. An int step count keeps `dt` integral, and
    # the lower bound of 1 guards the division below.
    ddim_steps = max(1, int(ddim_steps))
    # resize(512): presumably scales the long side to 512 -- TODO confirm in IImage.
    image = IImage(image_and_mask['image']).resize(512)
    mask = IImage(image_and_mask['mask']).rgb().resize(512)
    method = f'{base_method}-painta' if use_painta else base_method

    blended_images = []
    inpainted_images = []
    for i in range(batch_size):
        # Each batch item gets its own seed so the samples differ.
        # padx(64) is undone by crop(image.size) -- presumably padding to a
        # multiple of 64 for the diffusion model; verify in IImage.
        inpainted_image = run_fn(
            ddim=inp_model,
            method=method,
            prompt=prompt,
            image=image.padx(64),
            mask=mask.alpha().padx(64),
            seed=seed + i * 1000,
            eta=eta,
            prefix='{}',
            negative_prompt=negative_prompt,
            positive_prompt=f', {positive_prompt}',
            dt=1000 // ddim_steps,
            guidance_scale=guidance_scale,
        ).crop(image.size)
        # Blend the generated content back into the original image.
        blended_image = poisson_blend(
            orig_img=image.data[0],
            fake_img=inpainted_image.data[0],
            mask=mask.data[0],
            dilation=12,
        )
        blended_images.append(blended_image)
        inpainted_images.append(inpainted_image.numpy()[0])
    return blended_images, inpainted_images


def rasg_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
             guidance_scale=7.5, batch_size=4):
    """Inpaint with RASG guidance (optionally with PAIntA) via ``rasg.run``.

    Args:
        use_painta: Append 'painta' to the method string ('rasg-painta').
        prompt: Text prompt for the masked region.
        input: Mapping with 'image' and 'mask' entries (the ImageMask value).
        seed: Base seed; batch item ``i`` uses ``seed + 1000 * i``.
        eta: Forwarded to the sampler.
        negative_prompt: Forwarded to the sampler.
        positive_prompt: Forwarded as ``', ' + positive_prompt``.
        ddim_steps: Number of diffusion steps; the sampler gets ``dt = 1000 // ddim_steps``.
        guidance_scale: Forwarded to the sampler.
        batch_size: Number of samples, clamped to the range [1, 4].

    Returns:
        ``(blended_images, inpainted_images)`` lists of equal length.
    """
    return _inpaint_batch(rasg.run, 'rasg', use_painta, prompt, input, seed, eta,
                          negative_prompt, positive_prompt, ddim_steps,
                          guidance_scale, batch_size)


def sd_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
           guidance_scale=7.5, batch_size=4):
    """Inpaint with the plain sampler (optionally with PAIntA) via ``sd.run``.

    Takes the same arguments and returns the same structure as ``rasg_run``.
    The method string is 'default' or 'default-painta'.
    """
    return _inpaint_batch(sd.run, 'default', use_painta, prompt, input, seed, eta,
                          negative_prompt, positive_prompt, ddim_steps,
                          guidance_scale, batch_size)
def upscale_run(
        prompt, input, ddim_steps, seed, use_sam_mask, gallery, img_index,
        negative_prompt='', positive_prompt=', high resolution professional photo'):
    """Run inpainting-specialized x4 super-resolution on one inpainting result.

    Args:
        prompt: Inpainting prompt; ``positive_prompt`` is appended by plain
            string concatenation (hence the leading ', ' in its default).
        input: Mapping with 'image' and 'mask' entries (the ImageMask value);
            both are resized to 2048 (presumably the long side -- TODO confirm).
        ddim_steps: Unused here; kept so the Gradio input wiring is unchanged.
        seed: Seed for the SR sampler (cast to int).
        use_sam_mask: Forwarded to ``sr.run``.
        gallery: Items of the hidden gallery holding the raw inpainting outputs.
        img_index: Selected gallery index; clamped into range, ``None`` means 0.
        negative_prompt: Forwarded to ``sr.run``.
        positive_prompt: Suffix appended to ``prompt``.

    Returns:
        The super-resolved image array twice (visible and hidden outputs).

    Raises:
        gr.Error: If the gallery is empty, i.e. nothing has been inpainted yet.
    """
    torch.cuda.empty_cache()
    seed = int(seed)
    # Before any inpainting run the gallery is empty (or None). Indexing it
    # would crash with an opaque IndexError/TypeError, so report it instead.
    if not gallery:
        raise gr.Error("Please run inpainting before sending an image to super-resolution.")
    img_index = 0 if img_index is None else int(img_index)
    img_index = min(max(img_index, 0), len(gallery) - 1)

    inpainted_image = image_from_url_text(gallery[img_index])
    lr_image = IImage(inpainted_image)
    hr_image = IImage(input['image']).resize(2048)
    hr_mask = IImage(input['mask']).resize(2048)
    output_image = sr.run(sr_model, sam_predictor, lr_image, hr_image, hr_mask, prompt=prompt + positive_prompt,
                          noise_level=0, blend_trick=True, blend_output=True, negative_prompt=negative_prompt,
                          seed=seed, use_sam_mask=use_sam_mask)
    output_array = output_image.numpy()[0]
    return output_array, output_array
def switch_run(use_rasg, model_name, *args):
    """Activate ``model_name``, then dispatch to RASG or plain inpainting.

    ``args`` are forwarded unchanged to ``rasg_run`` when ``use_rasg`` is
    truthy, otherwise to ``sd_run``; that function's result is returned.
    """
    set_model_from_name(model_name)
    runner = rasg_run if use_rasg else sd_run
    return runner(*args)
# Build the Gradio UI and wire the inpainting / super-resolution callbacks.
with gr.Blocks(css='style.css') as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem; margin-bottom: 0.5rem">
            🧑\u200d🎨 HD-Painter Demo
        </h1>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
            Hayk Manukyan<sup>1*</sup>, Andranik Sargsyan<sup>1*</sup>, Barsegh Atanyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>
            and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a><sup>1,3</sup>
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
            <sup>1</sup>Picsart AI Research (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>Georgia Tech
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
            [<a href="https://arxiv.org/abs/2312.14091" style="color:blue;">arXiv</a>]
            [<a href="https://github.com/Picsart-AI-Research/HD-Painter" style="color:blue;">GitHub</a>]
        </h2>
        <h2 style="font-weight: 450; font-size: 1rem; margin: 0.7rem auto; max-width: 1000px">
            <b>HD-Painter</b> enables prompt-faithful and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method.
        </h2>
        </div>
        """)

    if on_huggingspace:
        gr.HTML("""
        <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to the suggested GPU in settings.
        <br/>
        <a href="https://huggingface.co/spaces/PAIR/HD-Painter?duplicate=true">
        <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
        </p>""")

    # Page-level JavaScript run on load; presumably script.js defines the
    # selected_gallery_index() helper used by the upscale callback below.
    with open('script.js', 'r', encoding='utf-8') as f:
        js_str = f.read()
    demo.load(_js=js_str)

    model_names = list(inpainting_models.keys())
    with gr.Row():
        with gr.Column():
            model_picker = gr.Dropdown(
                model_names,
                value=model_names[0],
                label="Please select a model!",
            )
        with gr.Column():
            use_painta = gr.Checkbox(value=True, label="Use PAIntA")
            use_rasg = gr.Checkbox(value=True, label="Use RASG")

    prompt = gr.Textbox(label="Inpainting Prompt")
    with gr.Row():
        with gr.Column():
            # Named input_image (not `input`) to avoid shadowing the builtin.
            input_image = gr.ImageMask(label="Input Image", brush_color='#ff0000', elem_id="inputmask", type="pil")

            with gr.Row():
                inpaint_btn = gr.Button("Inpaint", scale=0)

            with gr.Accordion('Advanced options', open=False):
                guidance_scale = gr.Slider(minimum=0, maximum=30, value=7.5, label="Guidance Scale")
                eta = gr.Slider(minimum=0, maximum=1, value=0.1, label="eta")
                ddim_steps = gr.Slider(minimum=10, maximum=100, value=50, step=1, label='Number of diffusion steps')
                with gr.Row():
                    seed = gr.Number(value=49123, label="Seed")
                    batch_size = gr.Number(value=1, label="Batch size", minimum=1, maximum=4)
                negative_prompt = gr.Textbox(value=negative_prompt_str, label="Negative prompt", lines=3)
                positive_prompt = gr.Textbox(value=positive_prompt_str, label="Positive prompt", lines=1)

        with gr.Column():
            with gr.Row():
                output_gallery = gr.Gallery(
                    [],
                    columns=4,
                    preview=True,
                    allow_preview=True,
                    object_fit='scale-down',
                    elem_id='outputgallery'
                )
            with gr.Row():
                upscale_btn = gr.Button("Send to Inpainting-Specialized Super-Resolution (x4)", scale=1)
            with gr.Row():
                use_sam_mask = gr.Checkbox(value=False, label="Use SAM mask for background preservation (for SR only, experimental feature)")
            with gr.Row():
                hires_image = gr.Image(label="Hi-res Image")

    gr.Markdown("## High-Resolution Generation Samples (2048px large side)")

    with gr.Column():
        example_container = gr.Gallery(
            example_previews,
            columns=4,
            preview=True,
            allow_preview=True,
            object_fit='scale-down'
        )

        # Each example row: [image path, prompt, gallery value]. The gallery
        # value is a one-item list holding that example's [url, caption].
        gr.Examples(
            [
                inputs_row + [[preview]]
                for inputs_row, preview in zip(example_inputs, example_previews)
            ],
            [input_image, prompt, example_container]
        )

    # Hidden components: mock_output_gallery receives the raw (unblended)
    # inpainting outputs that upscale_run reads. html_info is only a
    # placeholder input whose value the _js hook replaces with the selected index.
    mock_output_gallery = gr.Gallery([], columns=4, visible=False)
    mock_hires = gr.Image(label="__MHRO__", visible=False)
    html_info = gr.HTML(elem_id='html_info', elem_classes="infotext")

    # Input order must match switch_run(use_rasg, model_name, *rasg_run/sd_run args).
    inpaint_btn.click(
        fn=switch_run,
        inputs=[
            use_rasg,
            model_picker,
            use_painta,
            prompt,
            input_image,
            seed,
            eta,
            negative_prompt,
            positive_prompt,
            ddim_steps,
            guidance_scale,
            batch_size
        ],
        outputs=[output_gallery, mock_output_gallery],
        api_name="inpaint"
    )

    # The _js hook passes the first six inputs through unchanged and swaps the
    # seventh (html_info) for the gallery index selected client-side.
    upscale_btn.click(
        fn=upscale_run,
        inputs=[
            prompt,
            input_image,
            ddim_steps,
            seed,
            use_sam_mask,
            mock_output_gallery,
            html_info
        ],
        outputs=[hires_image, mock_hires],
        api_name="upscale",
        _js="function(a, b, c, d, e, f, g){ return [a, b, c, d, e, f, selected_gallery_index()] }",
    )
# Enable request queuing (at most 20 pending events), then start the server.
# TMP_DIR is passed as an allowed path so Gradio may serve files from it.
demo.queue(max_size=20).launch(share=True, allowed_paths=[TMP_DIR])