maze commited on
Commit
fcb49fd
β€’
1 Parent(s): 34bfe14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -2,6 +2,7 @@ from huggingface_hub import hf_hub_download
2
 
3
 
4
  Rain_Princess = hf_hub_download(repo_id="maze/FastStyleTransfer", filename="Rain_Princess_512.pth")
 
5
  #modelarcanev3 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.3", filename="ArcaneGANv0.3.jit")
6
  #modelarcanev2 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.2", filename="ArcaneGANv0.2.jit")
7
 
@@ -136,7 +137,12 @@ def process(image, model):
136
 
137
 
138
  def main(image, backbone, style):
139
- transformer.load_state_dict(torch.load(Rain_Princess, map_location=torch.device('cpu')))
 
 
 
 
 
140
  image = Image.fromarray(image)
141
  isize = image.size
142
  image = process(image, transformer)
 
2
 
3
 
4
  Rain_Princess = hf_hub_download(repo_id="maze/FastStyleTransfer", filename="Rain_Princess_512.pth")
5
+ The_Scream = hf_hub_download(repo_id="maze/FastStyleTransfer", filename="Scream_512.pth")
6
  #modelarcanev3 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.3", filename="ArcaneGANv0.3.jit")
7
  #modelarcanev2 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.2", filename="ArcaneGANv0.2.jit")
8
 
 
137
 
138
 
139
  def main(image, backbone, style):
140
+ if style == "The Scream":
141
+ transformer.load_state_dict(torch.load(The_Scream, map_location=torch.device('cpu')))
142
+ elif style == "Rain Princess":
143
+ transformer.load_state_dict(torch.load(Rain_Princess, map_location=torch.device('cpu')))
144
+ else:
145
+ transformer.load_state_dict(torch.load(Rain_Princess, map_location=torch.device('cpu')))
146
  image = Image.fromarray(image)
147
  isize = image.size
148
  image = process(image, transformer)