geekyrakshit commited on
Commit
c493374
1 Parent(s): 4823cf7

update: app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -18
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import shutil
3
 
4
  import gradio as gr
@@ -7,7 +6,7 @@ import torch
7
  from diffusers import DiffusionPipeline
8
 
9
  import wandb
10
- from wandb_addons.diffusers import get_wandb_callback
11
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -36,21 +35,9 @@ def generate_image(
36
  guidance_rescale,
37
  ):
38
  if not (wandb_api_key is None or wandb_api_key == ""):
39
- os.environ["WANDB_API_KEY"] = wandb_api_key
40
  generator = torch.Generator(device="cuda").manual_seed(seed)
41
- configs = {
42
- "guidance_scale": guidance_scale,
43
- "guidance_rescale": guidance_rescale,
44
- "seed": seed,
45
- }
46
- callback = get_wandb_callback(
47
- pipeline,
48
- prompt=prompt,
49
- negative_prompt=negative_prompt,
50
- wandb_project=wandb_project,
51
- num_inference_steps=num_inference_steps,
52
- configs=configs,
53
- )
54
  run_url = wandb.run.get_url()
55
  image = pipeline(
56
  prompt,
@@ -59,9 +46,10 @@ def generate_image(
59
  height=height,
60
  width=width,
61
  num_inference_steps=num_inference_steps,
62
- callback=callback,
63
- guidance_rescale=configs["guidance_rescale"],
64
  ).images[0]
 
65
  if torch.cuda.is_available():
66
  torch.cuda.empty_cache()
67
  shutil.rmtree("wandb")
 
 
1
  import shutil
2
 
3
  import gradio as gr
 
6
  from diffusers import DiffusionPipeline
7
 
8
  import wandb
9
+ from wandb.integration.diffusers import autolog
10
 
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
35
  guidance_rescale,
36
  ):
37
  if not (wandb_api_key is None or wandb_api_key == ""):
38
+ wandb.login(key=wandb_api_key, relogin=True)
39
  generator = torch.Generator(device="cuda").manual_seed(seed)
40
+ autolog(init={"project": wandb_project})
 
 
 
 
 
 
 
 
 
 
 
 
41
  run_url = wandb.run.get_url()
42
  image = pipeline(
43
  prompt,
 
46
  height=height,
47
  width=width,
48
  num_inference_steps=num_inference_steps,
49
+ guidance_scale=guidance_scale,
50
+ guidance_rescale=guidance_rescale,
51
  ).images[0]
52
+ wandb.finish()
53
  if torch.cuda.is_available():
54
  torch.cuda.empty_cache()
55
  shutil.rmtree("wandb")