bwang0911 commited on
Commit
53fce83
1 Parent(s): 57d7f74

fix model loading in custom_st.py (#5)

Browse files

- fix model loading in custom_st.py (a2fdae736ab4270d2d67c71a5abd936ac32f5cb7)

Files changed (1) hide show
  1. custom_st.py +1 -2
custom_st.py CHANGED
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
6
 
7
  import requests
8
  import torch
9
- from PIL import Image
10
  from torch import nn
11
  from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer
12
 
@@ -55,7 +54,7 @@ class Transformer(nn.Module):
55
  config_args = {}
56
 
57
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
58
- self._load_model(model_name_or_path, config, cache_dir, **model_args)
59
 
60
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
61
  tokenizer_args["model_max_length"] = max_seq_length
 
6
 
7
  import requests
8
  import torch
 
9
  from torch import nn
10
  from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer
11
 
 
54
  config_args = {}
55
 
56
  config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
57
+ self.auto_model = AutoModel.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir, **model_args)
58
 
59
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
60
  tokenizer_args["model_max_length"] = max_seq_length