Commit cefa68e (unverified) · committed by nftblackmagic
Parents: 43dcb58, e7138b1

Merge pull request #28 from asutermo/user/asutermo/replicate
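This merge packages the repository for Replicate: it adds a Cog build configuration (cog.yaml) and a predictor entry point (predict.py) wrapping the catvton-flux try-on/try-off pipelines, and extends .gitignore to cover the new artifacts.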

Files changed (3):
  1. .gitignore  +3 -1
  2. cog.yaml    (new file)
  3. predict.py  (new file)
.gitignore CHANGED
@@ -55,4 +55,6 @@ Thumbs.db
 .gradio/example/github.mp4

 aws/
-checkpoints/
+checkpoints/
+
+.cog/
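The new entries keep model checkpoints (checkpoints/) and Cog's local build cache (.cog/) out of version control.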
cog.yaml ADDED
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  gpu: true

  # a list of ubuntu apt packages to install
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - torch==2.4.0
    - transformers==4.43.3
    - datasets==2.20.0
    - accelerate==1.3.0
    - jupyter==1.0.0
    - numpy==1.26.4
    - pillow==10.2.0
    - peft==0.13.2
    - diffusers>=0.32.0
    - timm==0.9.16
    - torchvision==0.19.0
    - tqdm==4.66.5
    - sentencepiece
    - protobuf

  # commands run after the environment is set up
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
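A note on usage: with this configuration, `cog build` should assemble a CUDA container with the pinned Python stack, and `cog predict` invokes the Predictor class defined in predict.py below. The pget binary installed in the run step is Replicate's parallel file downloader, handy for fetching large weights at startup, though predict.py as merged here does not invoke it yet.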
predict.py ADDED
from typing import List

import torch
from cog import BasePredictor, Input, Path, Secret
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from torchvision import transforms


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load both transformers into memory to make running multiple predictions efficient."""
        self.try_on_transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/catvton-flux-beta", torch_dtype=torch.bfloat16
        )
        self.try_off_transformer = FluxTransformer2DModel.from_pretrained(
            "xiaozaa/cat-tryoff-flux", torch_dtype=torch.bfloat16
        )

    def predict(
        self,
        hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to accept the FLUX.1-dev terms."),
        image: Path = Input(description="Person image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
        mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
        try_on: bool = Input(True, description="True for try-on, False for try-off"),
        garment: Path = Input(description="Garment image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg"),
        num_steps: int = Input(50, description="Number of inference steps"),
        guidance_scale: float = Input(30, description="Guidance scale for the model"),
        seed: int = Input(0, description="Random seed"),
        width: int = Input(576, description="Width of each output image"),
        height: int = Input(768, description="Height of the output image"),
    ) -> List[Path]:
        size = (width, height)
        i = load_image(str(image)).convert("RGB").resize(size)
        m = load_image(str(mask)).convert("RGB").resize(size)
        g = load_image(str(garment)).convert("RGB").resize(size)

        # Pick the transformer for the requested direction
        if try_on:
            self.transformer = self.try_on_transformer
        else:
            self.transformer = self.try_off_transformer

        self.pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=self.transformer,
            torch_dtype=torch.bfloat16,
            token=hf_token.get_secret_value(),
        ).to("cuda")

        self.pipe.transformer.to(torch.bfloat16)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # for RGB images
        ])
        mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Transform images using the new preprocessing
        image_tensor = transform(i)
        mask_tensor = mask_transform(m)[:1]  # take only the first channel
        garment_tensor = transform(g)

        # Concatenate garment and person side by side along the width
        inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
        garment_mask = torch.zeros_like(mask_tensor)

        # Inpaint only the person half for try-on; for try-off, also open
        # the garment half so the garment can be reconstructed there
        if try_on:
            extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
        else:
            extended_mask = torch.cat([1 - garment_mask, mask_tensor], dim=2)
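        # Shapes at this point with the default 576x768 inputs (C, H, W):
        #   inpaint_image: (3, 768, 1152) - garment | person, side by side
        #   extended_mask: (1, 768, 1152) - same double-width layout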

        prompt = (
            "The pair of images highlights a clothing and its styling on a model, "
            "high resolution, 4K, 8K; "
            "[IMAGE1] Detailed product shot of a clothing; "
            "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
        )

        generator = torch.Generator(device="cuda").manual_seed(seed)
        result = self.pipe(
            height=size[1],
            width=size[0] * 2,  # double width: garment half + person half
            image=inpaint_image,
            mask_image=extended_mask,
            num_inference_steps=num_steps,
            generator=generator,
            max_sequence_length=512,
            guidance_scale=guidance_scale,
            prompt=prompt,
        ).images[0]

        # Split the double-width result back into its two halves and save them
        width = size[0]
        garment_result = result.crop((0, 0, width, size[1]))
        try_result = result.crop((width, 0, width * 2, size[1]))
        out_path = "/tmp/try.png"
        try_result.save(out_path)
        garm_out_path = "/tmp/garment.png"
        garment_result.save(garm_out_path)
        return [Path(out_path), Path(garm_out_path)]
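For a quick local smoke test outside the Cog runtime, the predictor can be driven directly. A minimal sketch, assuming a CUDA GPU and a Hugging Face token with FLUX.1-dev access; the hf_xxx value is a placeholder:

# Local smoke test, bypassing the Cog HTTP server.
# URL strings are accepted for the Path inputs because predict()
# only ever calls load_image(str(...)) on them.
from cog import Secret

from predict import Predictor

predictor = Predictor()
predictor.setup()

outputs = predictor.predict(
    hf_token=Secret("hf_xxx"),  # placeholder; use your own token
    image="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg",
    mask="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true",
    try_on=True,
    garment="https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg",
    num_steps=50,
    guidance_scale=30.0,
    seed=0,
    width=576,
    height=768,
)
print(outputs)  # [Path('/tmp/try.png'), Path('/tmp/garment.png')]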