radames commited on
Commit
d8457bc
1 Parent(s): 5f1aa51
server/pipelines/pix2pix/pix2pix_turbo.py CHANGED
@@ -153,6 +153,7 @@ class Pix2Pix_Turbo(torch.nn.Module):
153
  self.caption_enc = None
154
  self.device = "cuda"
155
 
 
156
  def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=1.0):
157
  # encode the text prompt
158
  if prompt != self.last_prompt:
 
153
  self.caption_enc = None
154
  self.device = "cuda"
155
 
156
+ @torch.no_grad()
157
  def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=1.0):
158
  # encode the text prompt
159
  if prompt != self.last_prompt:
server/pipelines/pix2pixTurbo.py CHANGED
@@ -5,7 +5,7 @@ from config import Args
5
  from pydantic import BaseModel, Field
6
  from PIL import Image
7
  from pipelines.pix2pix.pix2pix_turbo import Pix2Pix_Turbo
8
- from pipelines.utils.canny_gpu import SobelOperator
9
 
10
  default_prompt = "close-up photo of the joker"
11
  page_content = """
@@ -19,6 +19,11 @@ page_content = """
19
  class="text-blue-500 underline hover:no-underline">One-Step Image Translation with Text-to-Image Models
20
  </a>
21
  </p>
 
 
 
 
 
22
  """
23
 
24
 
@@ -62,7 +67,7 @@ class Pipeline:
62
  id="deterministic",
63
  )
64
  canny_low_threshold: float = Field(
65
- 0.31,
66
  min=0,
67
  max=1.0,
68
  step=0.001,
@@ -72,7 +77,7 @@ class Pipeline:
72
  id="canny_low_threshold",
73
  )
74
  canny_high_threshold: float = Field(
75
- 0.125,
76
  min=0,
77
  max=1.0,
78
  step=0.001,
@@ -91,30 +96,25 @@ class Pipeline:
91
 
92
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
93
  self.model = Pix2Pix_Turbo("edge_to_image")
94
- self.canny_torch = SobelOperator(device=device)
95
  self.device = device
96
  self.last_time = 0.0
97
 
98
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
99
- # generator = torch.manual_seed(params.seed)
100
- # pipe = self.pipes[params.base_model_id]
101
-
102
  canny_pil, canny_tensor = self.canny_torch(
103
  params.image,
104
  params.canny_low_threshold,
105
  params.canny_high_threshold,
106
  output_type="pil,tensor",
107
  )
108
-
109
- with torch.no_grad():
110
- canny_tensor = torch.cat((canny_tensor, canny_tensor, canny_tensor), dim=1)
111
- output_image = self.model(
112
- canny_tensor,
113
- params.prompt,
114
- params.deterministic,
115
- params.strength,
116
- )
117
- output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
118
 
119
  result_image = output_pil
120
  if params.debug_canny:
 
5
  from pydantic import BaseModel, Field
6
  from PIL import Image
7
  from pipelines.pix2pix.pix2pix_turbo import Pix2Pix_Turbo
8
+ from pipelines.utils.canny_gpu import ScharrOperator
9
 
10
  default_prompt = "close-up photo of the joker"
11
  page_content = """
 
19
  class="text-blue-500 underline hover:no-underline">One-Step Image Translation with Text-to-Image Models
20
  </a>
21
  </p>
22
+ <p class="text-sm text-gray-500">
23
+ Web app <a href="https://github.com/radames/Real-Time-Latent-Consistency-Model" target="_blank" class="text-blue-500 underline hover:no-underline">
24
+ Real-Time Latent Consistency Models
25
+ </a>
26
+ </p>
27
  """
28
 
29
 
 
67
  id="deterministic",
68
  )
69
  canny_low_threshold: float = Field(
70
+ 0.0,
71
  min=0,
72
  max=1.0,
73
  step=0.001,
 
77
  id="canny_low_threshold",
78
  )
79
  canny_high_threshold: float = Field(
80
+ 1.0,
81
  min=0,
82
  max=1.0,
83
  step=0.001,
 
96
 
97
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
98
  self.model = Pix2Pix_Turbo("edge_to_image")
99
+ self.canny_torch = ScharrOperator(device=device)
100
  self.device = device
101
  self.last_time = 0.0
102
 
103
  def predict(self, params: "Pipeline.InputParams") -> Image.Image:
 
 
 
104
  canny_pil, canny_tensor = self.canny_torch(
105
  params.image,
106
  params.canny_low_threshold,
107
  params.canny_high_threshold,
108
  output_type="pil,tensor",
109
  )
110
+ canny_tensor = torch.cat((canny_tensor, canny_tensor, canny_tensor), dim=1)
111
+ output_image = self.model(
112
+ canny_tensor,
113
+ params.prompt,
114
+ params.deterministic,
115
+ params.strength,
116
+ )
117
+ output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
 
 
118
 
119
  result_image = output_pil
120
  if params.debug_canny: