Edit model card

StyleBooth: Image Style Editing with Multimodal Instruction

Run StyleBooth

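# Install with `pip install "scepter>0.0.4"` (quote the specifier so the shell does not treat `>` as a redirect), or
# clone the latest SCEPTER and run `PYTHONPATH=./ python <this_script>` from the repository root on the main branch.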
import os
import unittest

from PIL import Image
from torchvision.utils import save_image

from scepter.modules.inference.stylebooth_inference import StyleboothInference
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS
from scepter.modules.utils.logger import get_logger


class DiffusionInferenceTest(unittest.TestCase):
    """Example test case that runs StyleBooth inference once and saves the output.

    The relative paths used below (config files, the sample image and the
    cache directory) are resolved against the current working directory.
    """

    def setUp(self):
        """Prepare the logger, file-system clients and output directory."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        self.logger = get_logger(name='scepter')
        config_file = 'scepter/methods/studio/scepter_ui.yaml'
        cfg = Config(cfg_file=config_file)
        # Register every file-system entry declared in the UI config, if any.
        if 'FILE_SYSTEM' in cfg:
            for fs_info in cfg['FILE_SYSTEM']:
                FS.init_fs_client(fs_info)
        self.tmp_dir = './cache/save_data/diffusion_inference'
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test and is a no-op when the directory exists.
        os.makedirs(self.tmp_dir, exist_ok=True)

    def tearDown(self):
        """Delegate to the base-class teardown; nothing extra to clean up."""
        super().tearDown()

    # Uncomment the decorator below to skip this test.
    # @unittest.skip('')
    def test_stylebooth(self):
        """Run one text-guided style edit and write the result to ``self.tmp_dir``."""
        config_file = 'scepter/methods/studio/inference/edit/stylebooth_tb_pro.yaml'
        cfg = Config(cfg_file=config_file)
        diff_infer = StyleboothInference(logger=self.logger)
        diff_infer.init_from_cfg(cfg)

        # Open the source image in a context manager so its file handle is
        # released once inference has consumed it.
        with Image.open('asset/images/inpainting_text_ref/ex4_scene_im.jpg') as source_image:
            output = diff_infer(
                {'prompt': 'Let this image be in the style of sai-lowpoly'},
                style_edit_image=source_image,
                style_guide_scale_text=7.5,
                style_guide_scale_image=1.5,
                stylebooth_state=True)

        save_path = os.path.join(self.tmp_dir,
                                 'stylebooth_test_lowpoly_cute_dog.png')
        # NOTE(review): output['images'] is presumably an image tensor (or list
        # of tensors) accepted by torchvision's save_image; confirm against
        # StyleboothInference.
        save_image(output['images'], save_path)


# When executed as a script, discover and run the test cases in this module.
if __name__ == '__main__':
    unittest.main()
Downloads last month

-

Downloads are not tracked for this model. How to track
Unable to determine this model's library. Check the docs.