Erwann Millon commited on
Commit
81f589b
1 Parent(s): 518b6f2
Files changed (6) hide show
  1. .gitmodules +0 -3
  2. app.py +0 -53
  3. attn.py +0 -1
  4. aws +0 -1
  5. color +0 -1
  6. requirements.txt +0 -1
.gitmodules CHANGED
@@ -1,3 +0,0 @@
1
- [submodule "cicdiff"]
2
- path = color
3
- url = https://github.com/ErwannMillon/Color-diffusion.git
 
 
 
 
app.py DELETED
@@ -1,53 +0,0 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
- from PLModel import PLColorDiff
4
- from dataset import ColorizationDataset
5
- from utils import lab_to_rgb, split_lab
6
- import torch
7
- import default_configs
8
- from icecream import ic
9
-
10
- from unet import SimpleUnet
11
- def get_image(image):
12
- print(image)
13
- dataset = ColorizationDataset([image], split="val", config=conf, size=128)
14
- lab_img = dataset.get_tensor_from_path(image)
15
- batch = lab_img.unsqueeze(0)
16
- # x_l, _ = split_lab(batch)
17
- # bw = torch.cat((x_l, *[torch.zeros_like(x_l)] * 2), dim=1)
18
- model.eval()
19
- img = model.sample_plot_image(batch)
20
- rgb_img = lab_to_rgb(*split_lab(img))
21
- # model.test_step(batch)
22
- return(rgb_img[0])
23
- conf = SimpleUnetConfig = dict (
24
- # device = get_device(),
25
- device = "mps",
26
- pin_memory = torch.cuda.is_available(),
27
- T=300,
28
- lr=5e-4,
29
- batch_size=64,
30
- img_size = 128,
31
- sample=False,
32
- log=False,
33
- should_log=False,
34
- sample_fn = None,
35
- val_every=20,
36
- epochs=100,
37
- using_cond=False
38
- )
39
- ckpt_path = "checkpoints/epoch=1-step=706.ckpt"
40
- ckpt = torch.load(ckpt_path, map_location=torch.device("mps"))
41
- unet = SimpleUnet()
42
- model = PLColorDiff(unet, None, None)
43
- ic.disable()
44
- model.load_state_dict(ckpt["state_dict"])
45
-
46
- demo = gr.Interface(
47
- get_image,
48
- inputs=gr.inputs.Image(label="Upload a black and white face", type="filepath"),
49
- outputs="image",
50
- title="Upload a black and white face and get a colorized image!",
51
- )
52
-
53
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attn.py DELETED
@@ -1 +0,0 @@
1
- color/attn.py
 
 
aws DELETED
@@ -1 +0,0 @@
1
- color/aws
 
 
color DELETED
@@ -1 +0,0 @@
1
- Subproject commit fd4b3285315dd8e0e612ffca7d7ecd02f6569f2d
 
 
requirements.txt DELETED
@@ -1 +0,0 @@
1
- color/requirements.txt