Files changed (1) hide show
  1. app.py +27 -4
app.py CHANGED
@@ -14,6 +14,27 @@ import re
14
  import matplotlib.pyplot as plt
15
  from transformers import AutoConfig
16
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  # Set title
@@ -31,15 +52,17 @@ btn = st.button("Predict")
31
  # Create a function to load model
32
  @st.cache(allow_output_mutation=True)
33
  def load_model():
34
- model_path = './model_1_updated.pth' # Update with the correct file name
35
 
 
 
36
  # Load the model configuration
37
- config = AutoConfig.from_pretrained(model_path)
38
 
39
  # Load the model and tokenizer from the specified file path
40
- model = AutoModelForSequenceClassification.from_pretrained(model_path, config = config)
41
  tokenizer = AutoTokenizer.from_pretrained(model_path)
42
- return model, tokenizer
43
 
44
  # Create a function to predict the input image file using the loaded model
45
  def predict(image_file):
 
14
  import matplotlib.pyplot as plt
15
  from transformers import AutoConfig
16
  from PIL import Image
17
+ from facenet_pytorch import InceptionResnetV1
18
+ vggface = InceptionResnetV1(pretrained='vggface2')
19
+
20
+
21
+ class classifier_vggface(nn.Module):
22
+ def __init__(self):
23
+ super(classifier_vggface, self).__init__()
24
+ self.encoder= vggface
25
+ self.classifier= nn.Sequential(
26
+ nn.Linear(512, 512),
27
+ nn.BatchNorm1d(512),
28
+ nn.ReLU(inplace=True),
29
+ nn.Linear(512, 128),
30
+ nn.BatchNorm1d(128),
31
+ nn.ReLU(inplace=True),
32
+ nn.Linear(128, 2),
33
+ )
34
+ def forward(self, x):
35
+ x= self.encoder(x)
36
+ x= self.classifier(x)
37
+ return x
38
 
39
 
40
  # Set title
 
52
  # Create a function to load model
53
  @st.cache(allow_output_mutation=True)
54
  def load_model():
55
+ model1= classifier_vggface()
56
 
57
+ model_path = './model1.pt' # Update with the correct file name
58
+ model1.load_state_dict(torch.load(model_path))
59
  # Load the model configuration
60
+ # config = AutoConfig.from_pretrained(model_path)
61
 
62
  # Load the model and tokenizer from the specified file path
63
+ # model = AutoModelForSequenceClassification.from_pretrained(model_path, config = config)
64
  tokenizer = AutoTokenizer.from_pretrained(model_path)
65
+ return model1, tokenizer
66
 
67
  # Create a function to predict the input image file using the loaded model
68
  def predict(image_file):