damerajee commited on
Commit
0d8b7c5
·
verified ·
1 Parent(s): 0dd73f3

Create text_model.py

Browse files
Files changed (1) hide show
  1. text_model.py +19 -0
text_model.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import transformers
3
+ from .modeling_gpt2 import GPT2LMHeadModel
4
+ from .configuration_gptvision import GPT2Config
5
+
6
+ transformers.logging.set_verbosity_error()
7
+
8
+
9
+ class TextModel(nn.Module):
10
+ def __init__(self, config) -> None:
11
+ super().__init__()
12
+
13
+ if type(config.gpt2_config) == dict:
14
+ gpt2_config = GPT2Config(**config.gpt2_config)
15
+ else:
16
+ gpt2_config = config.gpt2_config
17
+
18
+ self.model = GPT2LMHeadModel(gpt2_config)
19
+ self.text_emb = self.model.get_input_embeddings()