Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -7,17 +7,12 @@ from PIL import Image | |
| 7 | 
             
            from omegaconf import OmegaConf
         | 
| 8 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 9 |  | 
| 10 | 
            -
            # ---  | 
| 11 | 
            -
             | 
| 12 | 
            -
            # Define o diretório e o caminho para os pesos do modelo
         | 
| 13 | 
             
            WEIGHTS_DIR = "./pretrained_weights/ByteMorpher"
         | 
| 14 | 
             
            MODEL_FILENAME = "dit.safetensors"
         | 
| 15 | 
             
            MODEL_PATH = os.path.join(WEIGHTS_DIR, MODEL_FILENAME)
         | 
| 16 | 
            -
             | 
| 17 | 
            -
            # Cria o diretório se ele não existir
         | 
| 18 | 
             
            os.makedirs(WEIGHTS_DIR, exist_ok=True)
         | 
| 19 |  | 
| 20 | 
            -
            # Verifica se o modelo já existe antes de fazer o download
         | 
| 21 | 
             
            if not os.path.exists(MODEL_PATH):
         | 
| 22 | 
             
                print(f"Modelo não encontrado em {MODEL_PATH}. Baixando do Hugging Face Hub...")
         | 
| 23 | 
             
                try:
         | 
| @@ -25,17 +20,14 @@ if not os.path.exists(MODEL_PATH): | |
| 25 | 
             
                        repo_id="ByteDance-Seed/BM-Model",
         | 
| 26 | 
             
                        filename=MODEL_FILENAME,
         | 
| 27 | 
             
                        local_dir=WEIGHTS_DIR,
         | 
| 28 | 
            -
                        local_dir_use_symlinks=False | 
| 29 | 
             
                    )
         | 
| 30 | 
             
                    print("Download do modelo concluído com sucesso.")
         | 
| 31 | 
             
                except Exception as e:
         | 
| 32 | 
             
                    print(f"Ocorreu um erro durante o download do modelo: {e}")
         | 
| 33 | 
            -
                    # Se o download falhar, o aplicativo não poderá funcionar.
         | 
| 34 | 
            -
                    # Você pode adicionar um tratamento de erro mais robusto aqui se desejar.
         | 
| 35 | 
             
            else:
         | 
| 36 | 
             
                print(f"Modelo já existe em {MODEL_PATH}. Pulando o download.")
         | 
| 37 | 
            -
             | 
| 38 | 
            -
            # --- Fim: Bloco de Download Automático do Modelo ---
         | 
| 39 |  | 
| 40 |  | 
| 41 | 
             
            from image_datasets.dataset import image_resize
         | 
| @@ -50,16 +42,15 @@ def generate(image: Image.Image, edit_prompt: str): | |
| 50 | 
             
                from src.flux.xflux_pipeline import XFluxSampler
         | 
| 51 |  | 
| 52 | 
             
                global sampler
         | 
| 53 | 
            -
                if sampler  | 
| 54 | 
            -
                    #  | 
|  | |
| 55 | 
             
                    sampler = XFluxSampler(
         | 
| 56 | 
            -
                        device | 
| 57 | 
            -
                        ip_loaded= | 
| 58 | 
            -
                        spatial_condition= | 
| 59 | 
            -
                         | 
| 60 | 
            -
                         | 
| 61 | 
            -
                        improj=None,
         | 
| 62 | 
            -
                        share_position_embedding = True,
         | 
| 63 | 
             
                    )
         | 
| 64 |  | 
| 65 | 
             
                img = image_resize(image, 544)
         | 
| @@ -68,6 +59,9 @@ def generate(image: Image.Image, edit_prompt: str): | |
| 68 | 
             
                img = torch.from_numpy((np.array(img) / 127.5) - 1)
         | 
| 69 | 
             
                img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
         | 
| 70 |  | 
|  | |
|  | |
|  | |
| 71 | 
             
                result = sampler(
         | 
| 72 | 
             
                    prompt=edit_prompt,
         | 
| 73 | 
             
                    width=args.sample_width,
         | 
| @@ -75,9 +69,9 @@ def generate(image: Image.Image, edit_prompt: str): | |
| 75 | 
             
                    num_steps=args.sample_steps,
         | 
| 76 | 
             
                    image_prompt=None,
         | 
| 77 | 
             
                    true_gs=args.cfg_scale,
         | 
| 78 | 
            -
                    seed=args.seed,
         | 
| 79 | 
             
                    ip_scale=args.ip_scale if args.use_ip else 1.0,
         | 
| 80 | 
            -
                    source_image=img if  | 
| 81 | 
             
                )
         | 
| 82 | 
             
                return result
         | 
| 83 |  | 
| @@ -201,7 +195,6 @@ def create_app(): | |
| 201 | 
             
                        </div>
         | 
| 202 | 
             
                        """
         | 
| 203 | 
             
                    )
         | 
| 204 | 
            -
                    # gr.Markdown(header, elem_id="header")
         | 
| 205 | 
             
                    with gr.Row(equal_height=False):
         | 
| 206 | 
             
                        with gr.Column(variant="panel", elem_classes="inputPanel"):
         | 
| 207 | 
             
                            original_image = gr.Image(
         | 
|  | |
| 7 | 
             
            from omegaconf import OmegaConf
         | 
| 8 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 9 |  | 
| 10 | 
            +
            # --- Bloco de Download Automático do Modelo ---
         | 
|  | |
|  | |
| 11 | 
             
            WEIGHTS_DIR = "./pretrained_weights/ByteMorpher"
         | 
| 12 | 
             
            MODEL_FILENAME = "dit.safetensors"
         | 
| 13 | 
             
            MODEL_PATH = os.path.join(WEIGHTS_DIR, MODEL_FILENAME)
         | 
|  | |
|  | |
| 14 | 
             
            os.makedirs(WEIGHTS_DIR, exist_ok=True)
         | 
| 15 |  | 
|  | |
| 16 | 
             
            if not os.path.exists(MODEL_PATH):
         | 
| 17 | 
             
                print(f"Modelo não encontrado em {MODEL_PATH}. Baixando do Hugging Face Hub...")
         | 
| 18 | 
             
                try:
         | 
|  | |
| 20 | 
             
                        repo_id="ByteDance-Seed/BM-Model",
         | 
| 21 | 
             
                        filename=MODEL_FILENAME,
         | 
| 22 | 
             
                        local_dir=WEIGHTS_DIR,
         | 
| 23 | 
            +
                        local_dir_use_symlinks=False
         | 
| 24 | 
             
                    )
         | 
| 25 | 
             
                    print("Download do modelo concluído com sucesso.")
         | 
| 26 | 
             
                except Exception as e:
         | 
| 27 | 
             
                    print(f"Ocorreu um erro durante o download do modelo: {e}")
         | 
|  | |
|  | |
| 28 | 
             
            else:
         | 
| 29 | 
             
                print(f"Modelo já existe em {MODEL_PATH}. Pulando o download.")
         | 
| 30 | 
            +
            # --- Fim do Bloco de Download ---
         | 
|  | |
| 31 |  | 
| 32 |  | 
| 33 | 
             
            from image_datasets.dataset import image_resize
         | 
|  | |
| 42 | 
             
                from src.flux.xflux_pipeline import XFluxSampler
         | 
| 43 |  | 
| 44 | 
             
                global sampler
         | 
| 45 | 
            +
                if sampler is None:
         | 
| 46 | 
            +
                    # CORREÇÃO: Inicializa o sampler usando os argumentos do arquivo .yaml
         | 
| 47 | 
            +
                    print("Inicializando o XFluxSampler com a configuração...")
         | 
| 48 | 
             
                    sampler = XFluxSampler(
         | 
| 49 | 
            +
                        device=device,
         | 
| 50 | 
            +
                        ip_loaded=args.use_ip,
         | 
| 51 | 
            +
                        spatial_condition=args.use_spatial_condition,
         | 
| 52 | 
            +
                        share_position_embedding=args.share_position_embedding,
         | 
| 53 | 
            +
                        use_share_weight_referencenet=args.use_share_weight_referencenet
         | 
|  | |
|  | |
| 54 | 
             
                    )
         | 
| 55 |  | 
| 56 | 
             
                img = image_resize(image, 544)
         | 
|  | |
| 59 | 
             
                img = torch.from_numpy((np.array(img) / 127.5) - 1)
         | 
| 60 | 
             
                img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
         | 
| 61 |  | 
| 62 | 
            +
                # CORREÇÃO: Passa a imagem de origem se qualquer modo de condicionamento estiver ativo
         | 
| 63 | 
            +
                use_image_conditioning = args.use_spatial_condition or args.use_share_weight_referencenet
         | 
| 64 | 
            +
                
         | 
| 65 | 
             
                result = sampler(
         | 
| 66 | 
             
                    prompt=edit_prompt,
         | 
| 67 | 
             
                    width=args.sample_width,
         | 
|  | |
| 69 | 
             
                    num_steps=args.sample_steps,
         | 
| 70 | 
             
                    image_prompt=None,
         | 
| 71 | 
             
                    true_gs=args.cfg_scale,
         | 
| 72 | 
            +
                    seed=args.seed if args.seed != -1 else np.random.randint(0, 2**32 - 1),
         | 
| 73 | 
             
                    ip_scale=args.ip_scale if args.use_ip else 1.0,
         | 
| 74 | 
            +
                    source_image=img if use_image_conditioning else None,
         | 
| 75 | 
             
                )
         | 
| 76 | 
             
                return result
         | 
| 77 |  | 
|  | |
| 195 | 
             
                        </div>
         | 
| 196 | 
             
                        """
         | 
| 197 | 
             
                    )
         | 
|  | |
| 198 | 
             
                    with gr.Row(equal_height=False):
         | 
| 199 | 
             
                        with gr.Column(variant="panel", elem_classes="inputPanel"):
         | 
| 200 | 
             
                            original_image = gr.Image(
         |