gmastrapas
committed on
Commit
•
6d8e609
1
Parent(s):
942a5da
fix: kwargs in custom Sentence Transformer
Browse files — custom_st.py (+8, −2)
custom_st.py
CHANGED
@@ -22,6 +22,8 @@ class Transformer(nn.Module):
|
|
22 |
model_kwargs: Optional[Dict[str, Any]] = None,
|
23 |
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
24 |
image_processor_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
|
|
25 |
) -> None:
|
26 |
super(Transformer, self).__init__()
|
27 |
|
@@ -30,19 +32,23 @@ class Transformer(nn.Module):
|
|
30 |
tokenizer_kwargs = tokenizer_kwargs or {}
|
31 |
image_processor_kwargs = image_processor_kwargs or {}
|
32 |
|
33 |
-
config = AutoConfig.from_pretrained(
|
|
|
|
|
34 |
self.model = AutoModel.from_pretrained(
|
35 |
-
model_name_or_path, config=config, **model_kwargs
|
36 |
)
|
37 |
if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
|
38 |
tokenizer_kwargs['model_max_length'] = max_seq_length
|
39 |
|
40 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
41 |
tokenizer_name_or_path or model_name_or_path,
|
|
|
42 |
**tokenizer_kwargs,
|
43 |
)
|
44 |
self.image_processor = AutoImageProcessor.from_pretrained(
|
45 |
image_processor_name_or_path or model_name_or_path,
|
|
|
46 |
**image_processor_kwargs,
|
47 |
)
|
48 |
|
|
|
22 |
model_kwargs: Optional[Dict[str, Any]] = None,
|
23 |
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
24 |
image_processor_kwargs: Optional[Dict[str, Any]] = None,
|
25 |
+
cache_dir: str = None,
|
26 |
+
**_,
|
27 |
) -> None:
|
28 |
super(Transformer, self).__init__()
|
29 |
|
|
|
32 |
tokenizer_kwargs = tokenizer_kwargs or {}
|
33 |
image_processor_kwargs = image_processor_kwargs or {}
|
34 |
|
35 |
+
config = AutoConfig.from_pretrained(
|
36 |
+
model_name_or_path, cache_dir=cache_dir, **config_kwargs
|
37 |
+
)
|
38 |
self.model = AutoModel.from_pretrained(
|
39 |
+
model_name_or_path, config=config, cache_dir=cache_dir, **model_kwargs
|
40 |
)
|
41 |
if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
|
42 |
tokenizer_kwargs['model_max_length'] = max_seq_length
|
43 |
|
44 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
45 |
tokenizer_name_or_path or model_name_or_path,
|
46 |
+
cache_dir=cache_dir,
|
47 |
**tokenizer_kwargs,
|
48 |
)
|
49 |
self.image_processor = AutoImageProcessor.from_pretrained(
|
50 |
image_processor_name_or_path or model_name_or_path,
|
51 |
+
cache_dir=cache_dir,
|
52 |
**image_processor_kwargs,
|
53 |
)
|
54 |
|