nupurkmr9 committed
Commit 8880ecb
Parent: 3b1d67d

Update inference.py

Files changed (1):
  inference.py  +10 -29
inference.py CHANGED
@@ -10,7 +10,7 @@ import torch
 from diffusers import StableDiffusionPipeline
 
 sys.path.insert(0, 'lora')
-from lora_diffusion import monkeypatch_lora, tune_lora_scale
+from src import sample_diffuser, diffuser_training
 
 
 class InferencePipeline:
@@ -28,24 +28,16 @@ class InferencePipeline:
         gc.collect()
 
     @staticmethod
-    def get_lora_weight_path(name: str) -> pathlib.Path:
+    def get_weight_path(name: str) -> pathlib.Path:
         curr_dir = pathlib.Path(__file__).parent
         return curr_dir / name
 
-    @staticmethod
-    def get_lora_text_encoder_weight_path(path: pathlib.Path) -> str:
-        parent_dir = path.parent
-        stem = path.stem
-        text_encoder_filename = f'{stem}.text_encoder.pt'
-        path = parent_dir / text_encoder_filename
-        return path.as_posix() if path.exists() else ''
-
-    def load_pipe(self, model_id: str, lora_filename: str) -> None:
-        weight_path = self.get_lora_weight_path(lora_filename)
+    def load_pipe(self, model_id: str, filename: str) -> None:
+        weight_path = self.get_weight_path(filename)
         if weight_path == self.weight_path:
             return
         self.weight_path = weight_path
-        lora_weight = torch.load(self.weight_path, map_location=self.device)
+        weight = torch.load(self.weight_path, map_location=self.device)
 
         if self.device.type == 'cpu':
             pipe = StableDiffusionPipeline.from_pretrained(model_id)
@@ -54,40 +46,29 @@ class InferencePipeline:
                 model_id, torch_dtype=torch.float16)
             pipe = pipe.to(self.device)
 
-        monkeypatch_lora(pipe.unet, lora_weight)
-
-        lora_text_encoder_weight_path = self.get_lora_text_encoder_weight_path(
-            weight_path)
-        if lora_text_encoder_weight_path:
-            lora_text_encoder_weight = torch.load(
-                lora_text_encoder_weight_path, map_location=self.device)
-            monkeypatch_lora(pipe.text_encoder,
-                             lora_text_encoder_weight,
-                             target_replace_module=['CLIPAttention'])
+        diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
 
         self.pipe = pipe
 
     def run(
         self,
         base_model: str,
-        lora_weight_name: str,
+        weight_name: str,
         prompt: str,
-        alpha: float,
-        alpha_for_text: float,
         seed: int,
         n_steps: int,
         guidance_scale: float,
+        eta: float,
     ) -> PIL.Image.Image:
         if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
 
-        self.load_pipe(base_model, lora_weight_name)
+        self.load_pipe(base_model, weight_name)
 
         generator = torch.Generator(device=self.device).manual_seed(seed)
-        tune_lora_scale(self.pipe.unet, alpha)  # type: ignore
-        tune_lora_scale(self.pipe.text_encoder, alpha_for_text)  # type: ignore
         out = self.pipe(prompt,
                         num_inference_steps=n_steps,
                         guidance_scale=guidance_scale,
+                        eta=eta,
                         generator=generator)  # type: ignore
         return out.images[0]
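For context, a minimal usage sketch of the updated pipeline follows. It is not part of the commit: it assumes InferencePipeline() can be constructed without arguments (only the attributes used in the diff are visible here), that a fine-tuned Custom Diffusion weight file, hypothetically named 'delta.bin', sits next to inference.py, and that the base checkpoint identifier and prompt are purely illustrative.

# Usage sketch only; constructor signature, file names and model id are assumptions.
from inference import InferencePipeline

pipe = InferencePipeline()
image = pipe.run(
    base_model='CompVis/stable-diffusion-v1-4',  # assumed base Stable Diffusion checkpoint
    weight_name='delta.bin',                     # hypothetical weight file resolved by get_weight_path()
    prompt='photo of a <new1> cat',              # <new1> matches the modifier token passed to load_model
    seed=42,
    n_steps=50,
    guidance_scale=7.5,
    eta=1.0,                                     # forwarded to the diffusers pipeline call
)
image.save('out.png')

Note that run() raises gr.Error when CUDA is unavailable, and eta is simply passed through to the underlying diffusers pipeline, where it only has an effect with DDIM-style schedulers.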