SerdarHelli commited on
Commit
a87d02f
1 Parent(s): d25fdb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -52,11 +52,7 @@ cars=hf_hub_download("SerdarHelli/SDF-StyleGAN-3D", filename="cars.ckpt",revisio
52
 
53
  #default model
54
  device='cuda' if torch.cuda.is_available() else 'cpu'
55
- if device=="cuda":
56
- model = StyleGAN2_3D.load_from_checkpoint(cars).cuda(0)
57
- else:
58
- model = StyleGAN2_3D_not_cuda.load_from_checkpoint(cars)
59
- model.eval()
60
 
61
 
62
  models={"Car":cars,
@@ -74,15 +70,9 @@ def seed_all(seed):
74
  random.seed(seed)
75
 
76
 
77
- def change_model(ckpt_path):
78
- if device=="cuda":
79
- model = StyleGAN2_3D.load_from_checkpoint(ckpt_path).cuda(0)
80
- else:
81
- model = StyleGAN2_3D_not_cuda.load_from_checkpoint(ckpt_path)
82
- model.eval()
83
 
84
 
85
- def predict(seed,trunc_psi):
86
  if seed==None:
87
  seed=777
88
  seed_all(seed)
@@ -107,9 +97,22 @@ def predict(seed,trunc_psi):
107
  return x,y,z,i,j,k
108
 
109
  def generate(seed,model_name,trunc_psi):
110
- if model_name:
111
- change_model(models[model_name])
112
- x,y,z,i,j,k=predict(seed,trunc_psi)
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
 
115
  fig = go.Figure(go.Mesh3d(x=x, y=y, z=z,
@@ -131,7 +134,6 @@ def generate(seed,model_name,trunc_psi):
131
 
132
  markdown=f'''
133
  # SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation
134
-
135
 
136
  [The space demo for the SGP 2022 paper "SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation".](https://arxiv.org/abs/2206.12055)
137
 
@@ -150,7 +152,7 @@ with gr.Blocks() as demo:
150
  gr.Markdown(markdown)
151
  with gr.Row():
152
  seed = gr.Slider( minimum=0, maximum=2**16,label='Seed')
153
- model_name=gr.Dropdown(choices=["Car","Airplane","Chair","Rifle","Table"],label="Choose Model Type")
154
  trunc_psi = gr.Slider( minimum=0, maximum=2,label='Truncate PSI')
155
 
156
  btn = gr.Button(value="Generate")
 
52
 
53
  #default model
54
  device='cuda' if torch.cuda.is_available() else 'cpu'
55
+
 
 
 
 
56
 
57
 
58
  models={"Car":cars,
 
70
  random.seed(seed)
71
 
72
 
 
 
 
 
 
 
73
 
74
 
75
+ def predict(seed,model,trunc_psi):
76
  if seed==None:
77
  seed=777
78
  seed_all(seed)
 
97
  return x,y,z,i,j,k
98
 
99
  def generate(seed,model_name,trunc_psi):
100
+ print(model_name)
101
+ try :
102
+ ckpt=models[model_name]
103
+ except KeyError:
104
+ ckpt=cars
105
+
106
+
107
+ if device=="cuda":
108
+ model = StyleGAN2_3D.load_from_checkpoint(ckpt).cuda(0)
109
+ else:
110
+ model = StyleGAN2_3D_not_cuda.load_from_checkpoint(ckpt)
111
+ model.eval()
112
+
113
+
114
+
115
+ x,y,z,i,j,k=predict(seed,model,trunc_psi)
116
 
117
 
118
  fig = go.Figure(go.Mesh3d(x=x, y=y, z=z,
 
134
 
135
  markdown=f'''
136
  # SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation
 
137
 
138
  [The space demo for the SGP 2022 paper "SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation".](https://arxiv.org/abs/2206.12055)
139
 
 
152
  gr.Markdown(markdown)
153
  with gr.Row():
154
  seed = gr.Slider( minimum=0, maximum=2**16,label='Seed')
155
+ model_name=gr.Dropdown(choices=["Car","Airplane"],label="Choose Model Type")
156
  trunc_psi = gr.Slider( minimum=0, maximum=2,label='Truncate PSI')
157
 
158
  btn = gr.Button(value="Generate")