Yardenfren commited on
Commit
078397a
1 Parent(s): 2ad93a6

Update inf.py

Browse files
Files changed (1) hide show
  1. inf.py +25 -21
inf.py CHANGED
@@ -18,15 +18,6 @@ class InferencePipeline:
18
  self.hf_token = hf_token
19
  self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
20
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
- # if self.device.type == 'cpu':
22
- # self.pipe = StableDiffusionXLPipeline.from_pretrained(
23
- # self.base_model_id, use_auth_token=self.hf_token)
24
- # else:
25
- # self.pipe = StableDiffusionXLPipeline.from_pretrained(
26
- # self.base_model_id,
27
- # torch_dtype=torch.float16,
28
- # use_auth_token=self.hf_token)
29
- # self.pipe = self.pipe.to(self.device)
30
  self.pipe = StableDiffusionXLPipeline.from_pretrained(
31
  self.base_model_id,
32
  torch_dtype=torch.float16,
@@ -98,15 +89,10 @@ class InferencePipeline:
98
 
99
  self.content_lora_model_id = content_lora_model_id
100
  self.style_lora_model_id = style_lora_model_id
101
-
102
  @spaces.GPU
103
- def run(
104
- self,
105
- content_lora_model_id: str,
106
- style_lora_model_id: str,
107
  prompt: str,
108
- content_alpha: float,
109
- style_alpha: float,
110
  seed: int,
111
  n_steps: int,
112
  guidance_scale: float,
@@ -114,13 +100,8 @@ class InferencePipeline:
114
  ) -> PIL.Image.Image:
115
  if not torch.cuda.is_available():
116
  raise gr.Error('CUDA is not available.')
117
-
118
- self.load_pipe(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
119
-
120
  self.pipe.to("cuda")
121
-
122
  generator = torch.Generator(device="cuda").manual_seed(seed)
123
- print(self.pipe.device)
124
  out = self.pipe(
125
  prompt,
126
  num_inference_steps=n_steps,
@@ -129,3 +110,26 @@ class InferencePipeline:
129
  num_images_per_prompt=num_images_per_prompt,
130
  ) # type: ignore
131
  return out.images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  self.hf_token = hf_token
19
  self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
20
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
 
 
 
 
21
  self.pipe = StableDiffusionXLPipeline.from_pretrained(
22
  self.base_model_id,
23
  torch_dtype=torch.float16,
 
89
 
90
  self.content_lora_model_id = content_lora_model_id
91
  self.style_lora_model_id = style_lora_model_id
92
+
93
  @spaces.GPU
94
+ def inference(self,
 
 
 
95
  prompt: str,
 
 
96
  seed: int,
97
  n_steps: int,
98
  guidance_scale: float,
 
100
  ) -> PIL.Image.Image:
101
  if not torch.cuda.is_available():
102
  raise gr.Error('CUDA is not available.')
 
 
 
103
  self.pipe.to("cuda")
 
104
  generator = torch.Generator(device="cuda").manual_seed(seed)
 
105
  out = self.pipe(
106
  prompt,
107
  num_inference_steps=n_steps,
 
110
  num_images_per_prompt=num_images_per_prompt,
111
  ) # type: ignore
112
  return out.images
113
+
114
+
115
+ def run(
116
+ self,
117
+ content_lora_model_id: str,
118
+ style_lora_model_id: str,
119
+ prompt: str,
120
+ content_alpha: float,
121
+ style_alpha: float,
122
+ seed: int,
123
+ n_steps: int,
124
+ guidance_scale: float,
125
+ num_images_per_prompt: int = 1
126
+ ) -> PIL.Image.Image:
127
+
128
+ self.load_pipe(content_lora_model_id, style_lora_model_id, content_alpha, style_alpha)
129
+
130
+ return self.inference(
131
+ prompt=prompt,
132
+ n_steps=n_steps,
133
+ guidance_scale=guidance_scale,
134
+ num_images_per_prompt=num_images_per_prompt,
135
+ )