wjs0725 committed
Commit 4de0080 · verified · 1 Parent(s): 6fb545f

Update app.py

Files changed (1)
  1. app.py +10 -30
app.py CHANGED

@@ -26,6 +26,7 @@ login(token=os.getenv('Token'))
 import torch
 
 device = torch.cuda.current_device()
+print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
 total_memory = torch.cuda.get_device_properties(device).total_memory
 allocated_memory = torch.cuda.memory_allocated(device)
 reserved_memory = torch.cuda.memory_reserved(device)
@@ -34,6 +35,14 @@ print(f"Total memory: {total_memory / 1024**2:.2f} MB")
 print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
 print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
 
+ae = load_ae(name, device)
+t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
+clip = load_clip(device)
+model = load_flow_model(name, device=device)
+print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
+print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
+print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
+print("!!!!!!!!self.model!!!!!!",next(model.parameters()).device)
 
 @dataclass
 class SamplingOptions:
@@ -54,37 +63,8 @@ is_schnell = False
 feature_path = 'feature'
 output_dir = 'result'
 add_sampling_metadata = True
-# class FluxEditor:
-#     def __init__(self, args):
-#         self.args = args
-#         self.device = torch.device(args.device)
-#         self.offload = args.offload
-#         self.name = args.name
-#         self.is_schnell = args.name == "flux-schnell"
-
-#         self.feature_path = 'feature'
-#         self.output_dir = 'result'
-#         self.add_sampling_metadata = True
-
-#         if self.name not in configs:
-#             available = ", ".join(configs.keys())
-#             raise ValueError(f"Got unknown model name: {name}, chose from {available}")
-
-#         # init all components
-
 
-#         if self.offload:
-#             self.model.cpu()
-#             torch.cuda.empty_cache()
-#             self.ae.encoder.to(self.device)
-ae = load_ae(name, device="cpu" if offload else device)
-t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
-clip = load_clip(device)
-model = load_flow_model(name, device="cpu" if offload else device)
-print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
-print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
-print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
-print("!!!!!!!!self.model!!!!!!",next(model.parameters()).device)
+
 
 @torch.inference_mode()
 def encode(init_image, torch_device, ae):
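For reference, the net effect of this commit is that app.py drops the offload-to-CPU branch: every component is loaded straight onto the active CUDA device, and the device/memory prints are kept as startup diagnostics. Below is a minimal, self-contained sketch of that startup path. It assumes load_ae, load_t5, load_clip, and load_flow_model come from flux.util (the imports are not shown in this diff) and that name holds a valid model id such as "flux-dev"; the report_memory helper is illustrative and not part of the commit.

import torch
from flux.util import load_ae, load_clip, load_flow_model, load_t5  # assumed import path

name = "flux-dev"                     # assumed model id (is_schnell = False elsewhere in app.py)
device = torch.cuda.current_device()  # index of the active CUDA device

def report_memory(dev):
    # Hypothetical helper mirroring the diagnostic prints kept by this commit.
    props = torch.cuda.get_device_properties(dev)
    print(f"Total memory: {props.total_memory / 1024**2:.2f} MB")
    print(f"Allocated memory: {torch.cuda.memory_allocated(dev) / 1024**2:.2f} MB")
    print(f"Reserved memory: {torch.cuda.memory_reserved(dev) / 1024**2:.2f} MB")

report_memory(device)

# All components now go directly to the GPU; the old '"cpu" if offload else device' branch is gone.
ae = load_ae(name, device)
t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
clip = load_clip(device)
model = load_flow_model(name, device=device)

# Sanity check that each module's parameters actually ended up on the CUDA device.
for label, module in (("t5", t5), ("clip", clip), ("model", model)):
    print(label, next(module.parameters()).device)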