munjed commited on
Commit
613a292
·
1 Parent(s): fbe90a0
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -4,17 +4,17 @@ import numpy as np
4
  import gradio as gr
5
  import os
6
 
7
- #from model import ChaoticCoherentGenerator
8
- from model_256 import EfficientChaoticGenerator
9
  from feature_extractor import CodeFeatureExtractor
10
 
11
  # ------------------- Device -------------------
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  # ------------------- Load Model -------------------
15
- model = EfficientChaoticGenerator().to(device)
16
- checkpoint = torch.load("models/jid_model_256_e10.pth", map_location=device)
17
- model.load_state_dict(checkpoint["model_state_dict"])
18
  model.eval()
19
 
20
  extractor = CodeFeatureExtractor()
 
4
  import gradio as gr
5
  import os
6
 
7
+ from model import ChaoticCoherentGenerator
8
+ # from model_256 import EfficientChaoticGenerator
9
  from feature_extractor import CodeFeatureExtractor
10
 
11
  # ------------------- Device -------------------
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  # ------------------- Load Model -------------------
15
+ model = ChaoticCoherentGenerator().to(device)
16
+ checkpoint = torch.load("models/chotic.pth", map_location=device)
17
+ model.load_state_dict(checkpoint["state_dict"])
18
  model.eval()
19
 
20
  extractor = CodeFeatureExtractor()