File size: 4,764 Bytes
6124669
 
 
 
 
c9174f6
 
6124669
 
c9174f6
6124669
 
 
 
 
 
 
 
 
 
 
c9174f6
6124669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89dbdbc
6124669
 
 
 
 
 
 
 
 
 
 
 
 
 
8a59bca
6124669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a4b39d
6124669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2fd6d8
 
6124669
 
89dbdbc
 
d2fd6d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import sys
import torch.nn.functional as F
import torch

# Make the parent package and the sibling "wise" directory importable. Both
# are resolved relative to this script's real location and appended (in this
# order) to the end of sys.path.
PACKAGE_PARENT = '..'
WISE_DIR = '../wise/'
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
)
for _extra_dir in (PACKAGE_PARENT, WISE_DIR):
    sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, _extra_dir)))
del _extra_dir


import numpy as np
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas

from effects.minimal_pipeline import MinimalPipelineEffect
from helpers.visual_parameter_def import minimal_pipeline_presets, minimal_pipeline_bump_mapping_preset, minimal_pipeline_xdog_preset
from helpers import torch_to_np, np_to_torch
from effects import get_default_settings
from demo_config import HUGGING_FACE

# Browser-tab title and full-width layout for this page.
st.set_page_config(layout="wide", page_title="Preset Edit Demo")


@st.cache(hash_funcs={MinimalPipelineEffect: id})
def local_edits_create_effect():
    """Build the "minimal_pipeline" effect once and reuse it across reruns.

    The result is memoized with ``st.cache``. ``hash_funcs`` maps
    ``MinimalPipelineEffect`` to ``id``, so such objects are hashed by
    identity rather than by their contents.

    Returns:
        (effect, param_set): the effect (with checkpoints enabled, moved to
        CUDA) and the parameter set from ``get_default_settings``.
    """
    # The default preset is not needed on this page; only the effect and its
    # parameter set are returned.
    effect, _preset, param_set = get_default_settings("minimal_pipeline")
    effect.enable_checkpoints()
    effect.cuda()
    return effect, param_set


effect, param_set = local_edits_create_effect()

# Sidebar preset name -> visual-parameter preset definition.
presets = {
    "original": minimal_pipeline_presets,
    "bump mapped": minimal_pipeline_bump_mapping_preset,
    "contoured": minimal_pipeline_xdog_preset,
}

st.session_state["action"] = "switch_page_from_presets"  # on switchback, remember effect input

active_preset = st.sidebar.selectbox("apply preset: ", ["bump mapped", "contoured", "original"])
blend_strength = st.sidebar.slider("Parameter blending strength (non-hue) : ", 0.0, 1.0, 1.0, 0.05)
hue_blend_strength = st.sidebar.slider("Hue-shift blending strength : ", 0.0, 1.0, 1.0, 0.05)

st.sidebar.text("Drawing options:")
stroke_width = st.sidebar.slider("Stroke width: ", 1, 80, 40)
drawing_mode = st.sidebar.selectbox(
    "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform")
)

st.session_state["preset_canvas_key"] = "preset_canvas"

# This page only edits state it does not create. Without these keys the reads
# below (and later on this page) would crash with a KeyError, so show a
# message and halt the script run instead.
# NOTE(review): assumes another page of this app populates these keys.
_missing_state = [
    key for key in ("result_vp", "effect_input", "Content_id", "user")
    if key not in st.session_state
]
if _missing_state:
    st.warning(
        "Missing session state (" + ", ".join(_missing_state) + "): "
        "load an image in the main editor first, then return to this page."
    )
    st.stop()

# Edit a copy: the stored parameters are only replaced when "Apply" is pressed.
vp = torch.clone(st.session_state["result_vp"])
org_cuda = st.session_state["effect_input"]

# NOTE: memoization (formerly ``@st.experimental_memo``) is disabled, so
# ``content_id`` is currently unused. It is kept as the intended cache key, and
# ``_org_cuda`` keeps its leading underscore so that such a cache would skip
# hashing the tensor.
def greyscale_original(_org_cuda, content_id, *, hf_width=450, longest_edge=670):
    """Return a resized, dimmed greyscale copy of the input to draw a mask on.

    Args:
        _org_cuda: image tensor. The last two dims are read as (height, width)
            and dim 1 is averaged away; presumably shaped (1, 3, H, W) with
            values in [0, 1] — TODO confirm.
        content_id: identifier of the current content; unused while
            memoization is disabled.
        hf_width: output width when ``HUGGING_FACE`` is set. The height
            follows the aspect ratio.
        longest_edge: length of the longer output side otherwise.

    Returns:
        (image, hsize, wsize): an RGB ``PIL.Image`` whose three channels are
        identical, plus its height and width in pixels.
    """
    img_h, img_w = _org_cuda.shape[-2:]
    if HUGGING_FACE:
        # Fixed output width; height keeps the aspect ratio.
        wsize = hf_width
        hsize = int(img_h * (wsize / img_w))
    else:
        # Scale so that the longer side becomes ``longest_edge``.
        max_edge = max(img_w, img_h)
        hsize = int(longest_edge * (img_h / max_edge))
        wsize = int(longest_edge * (img_w / max_edge))
    # Truncation can yield 0 for very elongated images, which F.interpolate rejects.
    hsize, wsize = max(hsize, 1), max(wsize, 1)

    org_img = F.interpolate(_org_cuda, (hsize, wsize), mode="bilinear")
    # Average over channels to get grey, then halve the intensity (dims the image).
    org_img = torch.mean(org_img, dim=1, keepdim=True) / 2.0
    # Presumably torch_to_np yields an (H, W) array here; replicate it to 3 channels.
    org_img = torch_to_np(org_img, multiply_by_255=True)[..., np.newaxis].repeat(3, axis=2)
    org_img = Image.fromarray(org_img.astype(np.uint8))
    return org_img, hsize, wsize

greyscale_img, hsize, wsize = greyscale_original(org_cuda, st.session_state["Content_id"])

# Two side-by-side panes: mask canvas on the left, live preview on the right.
coll1, coll2 = st.columns(2)
for _pane, _title in ((coll1, "Draw Mask"), (coll2, "Live Result")):
    _pane.header(_title)
del _pane, _title

with coll1:
    # Drawing canvas over the dimmed greyscale input. Its size equals the
    # background image, so strokes line up with background pixels.
    # The fill colour is opaque black (alpha 1).
    canvas_result = st_canvas(
        key=st.session_state["preset_canvas_key"],
        background_image=greyscale_img,
        width=greyscale_img.width,
        height=greyscale_img.height,
        drawing_mode=drawing_mode,
        stroke_width=stroke_width,
        fill_color="rgba(0, 0, 0, 1)",
    )
    

res_data = None
if canvas_result.image_data is not None:
    # Collapse the canvas drawing to one channel by summing its channels.
    canvas_mask = np_to_torch(canvas_result.image_data.astype(np.float32)).sum(dim=1, keepdim=True).cuda()

    # Resize the mask from canvas resolution to the full input resolution.
    img_org_height, img_org_width = org_cuda.shape[-2:]
    res_data = F.interpolate(canvas_mask, (img_org_height, img_org_width)).squeeze(1)
    # A channel sum can exceed 1 (e.g. coloured strokes, or if np_to_torch does
    # not rescale to [0, 1] — TODO confirm). That would turn the linear blend
    # below into extrapolation, so keep the weights within [0, 1].
    res_data = res_data.clamp(0.0, 1.0)

    preset_tensor = effect.vpd.preset_tensor(presets[active_preset], org_cuda, add_local_dims=True)
    hue_idx = effect.vpd.name2idx["hueShift"]
    # Keep the pre-blend hue so it can be blended with its own strength below.
    hue = torch.clone(vp[:, hue_idx])

    # Inside the mask, move every parameter towards the preset.
    weight = res_data * blend_strength
    vp[:] = preset_tensor * weight + vp[:] * (1 - weight)
    # Recompute the hue-shift channel from the original hue using the separate hue strength.
    hue_weight = res_data * hue_blend_strength
    vp[:, hue_idx] = preset_tensor[:, hue_idx] * hue_weight + hue * (1 - hue_weight)

# Render the effect with the current (possibly blended) parameters; no autograd needed.
with torch.no_grad():
    result_cuda = effect(org_cuda, vp)

# Clip before the uint8 cast. NumPy does not saturate out-of-range floats when
# casting (the result is platform-dependent, typically wrap-around), so effect
# overshoot outside [0, 1] would otherwise show up as corrupted pixels.
img_res = Image.fromarray(np.clip(torch_to_np(result_cuda) * 255.0, 0, 255).astype(np.uint8))
coll2.image(img_res)

print(st.session_state["user"], " edited preset")

# Persist the edited parameters only when the user explicitly applies them.
apply_btn = st.sidebar.button("Apply")
if apply_btn:
    st.session_state["result_vp"] = vp

st.info("Note: Press apply to make changes permanent")