zswwsz commited on
Commit
7792673
1 Parent(s): 8e62f1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -27
app.py CHANGED
@@ -221,7 +221,7 @@ def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
221
 
222
  class mutil_model(nn.Module):
223
 
224
- def __init__(self, category_num=7):
225
  super(mutil_model, self).__init__()
226
  self.model1 = res2net50_v1b_26w_4s(pretrained=False)
227
  self.model1.fc = nn.Sequential(
@@ -241,39 +241,40 @@ class mutil_model(nn.Module):
241
  return x
242
 
243
 
244
- pth_path = './res2net_model_6_new.pt'
245
- category_num = 6
246
 
247
  # "cuda" only when GPUs are available.
248
- device = "cuda" if torch.cuda.is_available() else "cpu"
249
-
250
- # Initialize a model, and put it on the device specified.
251
  # 导入res2net预训练模型
252
- # pthfile = '/cbd_lixiaogang_lixianneng/morror_art/pre_train_model/res2net50_v1b.pth'
253
- # model = res2net50_v1b_26w_4s(pretrained=False)
254
  # 修改全连接层,输出维度为预测 分类
255
- # num_ftrs = model.fc.in_features
256
- # model.fc = nn.Sequential(
257
- # nn.Linear(in_features=2048, out_features=1000, bias=True),
258
- # nn.Dropout(0.5),
259
- # nn.Linear(1000, out_features=category_num)
260
- # )
261
- # model.fc = nn.Sequential(
262
- # nn.Linear(in_features=2048, out_features=category_num, bias=True),
263
- # )
264
- # model = model.to(device)
265
- # model.device = device
266
- # model.load_state_dict(torch.load(pth_path,torch.device('cpu')))
267
- # model.eval()
 
268
 
269
 
270
  # 增加人脸识别模型
271
- model = mutil_model(category_num=7)
272
- model_state = torch.load('./res2net_model_6_new.pt', map_location=torch.device('cpu')).state_dict()
273
- model.load_state_dict(model_state) # 加载模型参数
274
- model.eval()
275
 
276
- labels = ['伤感', '开心', '励志', '宣泄', '平静', '感人']
277
 
278
  import requests
279
  import torch
@@ -327,4 +328,5 @@ gr.Interface(
327
  # inputs='image',
328
  # outputs='label',
329
  # examples=[["images/cheetah1.jpg"], ["images/lion.jpg"]],
330
- ).launch(debug=True, share=True)
 
 
221
 
222
  class mutil_model(nn.Module):
223
 
224
+ def __init__(self, category_num=10):
225
  super(mutil_model, self).__init__()
226
  self.model1 = res2net50_v1b_26w_4s(pretrained=False)
227
  self.model1.fc = nn.Sequential(
 
241
  return x
242
 
243
 
244
+ pth_path = './res2net_pretrain_model_999.pt'
245
+ category_num = 9
246
 
247
  # "cuda" only when GPUs are available.
248
+ #device = "cuda" if torch.cuda.is_available() else "cpu"
249
+ device = "cpu"
250
+ #Initialize a model, and put it on the device specified.
251
  # 导入res2net预训练模型
252
+ # pthfile = './res2net50_v1b.pth'
253
+ model = res2net50_v1b_26w_4s(pretrained=False)
254
  # 修改全连接层,输出维度为预测 分类
255
+ num_ftrs = model.fc.in_features
256
+ model.fc = nn.Sequential(
257
+ nn.Linear(in_features=2048, out_features=1000, bias=True),
258
+ nn.Dropout(0.5),
259
+ nn.Linear(1000, out_features=category_num)
260
+ )
261
+ model.fc = nn.Sequential(
262
+ nn.Linear(in_features=2048, out_features=category_num, bias=True),
263
+ )
264
+
265
+ model = model.to(device)
266
+ model.device = device
267
+ model.load_state_dict(torch.load(pth_path,torch.device('cpu')))
268
+ model.eval()
269
 
270
 
271
  # 增加人脸识别模型
272
+ #model = mutil_model(category_num=7)
273
+ #model_state = torch.load('./add_face_emotion_model_7.pt', map_location=torch.device('cpu')).state_dict()
274
+ #model.load_state_dict(model_state) # 加载模型参数
275
+ #model.eval()
276
 
277
+ labels = ['中国风', '古典', '电子', '摇滚', '乡村', '说唱', '民谣', '动漫', '现代']
278
 
279
  import requests
280
  import torch
 
328
  # inputs='image',
329
  # outputs='label',
330
  # examples=[["images/cheetah1.jpg"], ["images/lion.jpg"]],
331
+ ).launch(share=True)
332
+ # share=True