Sophie98 commited on
Commit
acda9fe
1 Parent(s): baaaa83
Files changed (1) hide show
  1. styleTransfer.py +18 -17
styleTransfer.py CHANGED
@@ -44,29 +44,30 @@ decoder = StyTR.decoder
44
  Trans = transformer.Transformer()
45
  embedding = StyTR.PatchEmbed()
46
 
47
- decoder.eval()
48
- Trans.eval()
49
- vgg.eval()
50
 
51
- new_state_dict = OrderedDict()
52
- state_dict = torch.load(decoder_path)
53
- decoder.load_state_dict(state_dict)
54
 
55
- new_state_dict = OrderedDict()
56
- state_dict = torch.load(Trans_path)
57
- Trans.load_state_dict(state_dict)
58
 
59
- new_state_dict = OrderedDict()
60
- state_dict = torch.load(embedding_path)
61
- embedding.load_state_dict(state_dict)
62
 
63
- network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
64
- network.eval()
 
65
 
66
- content_tf = content_transform()
67
- style_tf = style_transform(style_size,style_size)
 
 
 
68
 
69
- def StyleTransformer(content_img: Image, style_img: Image):
70
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
  network.to(device)
72
  content = content_tf(content_img.convert("RGB"))
44
  Trans = transformer.Transformer()
45
  embedding = StyTR.PatchEmbed()
46
 
47
+ def StyleTransformer(content_img: Image, style_img: Image):
 
 
48
 
49
+ decoder.eval()
50
+ Trans.eval()
51
+ vgg.eval()
52
 
53
+ new_state_dict = OrderedDict()
54
+ state_dict = torch.load(decoder_path)
55
+ decoder.load_state_dict(state_dict)
56
 
57
+ new_state_dict = OrderedDict()
58
+ state_dict = torch.load(Trans_path)
59
+ Trans.load_state_dict(state_dict)
60
 
61
+ new_state_dict = OrderedDict()
62
+ state_dict = torch.load(embedding_path)
63
+ embedding.load_state_dict(state_dict)
64
 
65
+ network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
66
+ network.eval()
67
+
68
+ content_tf = content_transform()
69
+ style_tf = style_transform(style_size,style_size)
70
 
 
71
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
  network.to(device)
73
  content = content_tf(content_img.convert("RGB"))