zeroMN commited on
Commit
88cfbb9
·
verified ·
1 Parent(s): f3eac3b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -70
README.md CHANGED
@@ -118,74 +118,13 @@ Users (both direct and downstream) should be made aware of the following risks,
118
 
119
  ## How to Get Started with the Model
120
  ```python
121
- import os
122
- import torch
123
- import torch.nn as nn
124
- import numpy as np
125
- import random
126
- from transformers import (
127
- BartForConditionalGeneration,
128
- AutoModelForCausalLM,
129
- BertModel,
130
- Wav2Vec2Model,
131
- CLIPModel,
132
- AutoTokenizer
133
- )
134
-
135
- class MultiModalModel(nn.Module):
136
- def __init__(self):
137
- super(MultiModalModel, self).__init__()
138
- # 初始化子模型
139
- self.text_generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
140
- self.code_generator = AutoModelForCausalLM.from_pretrained('gpt2')
141
- self.nlp_encoder = BertModel.from_pretrained('bert-base-uncased')
142
- self.speech_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
143
- self.vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
144
-
145
- # 初始化分词器和处理器
146
- self.text_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
147
- self.code_tokenizer = AutoTokenizer.from_pretrained('gpt2')
148
- self.nlp_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
149
- self.speech_processor = AutoTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
150
- self.vision_processor = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
151
-
152
- def forward(self, task, inputs):
153
- if task == 'text_generation':
154
- attention_mask = inputs.get('attention_mask')
155
- outputs = self.text_generator.generate(
156
- inputs['input_ids'],
157
- max_new_tokens=100,
158
- pad_token_id=self.text_tokenizer.eos_token_id,
159
- attention_mask=attention_mask,
160
- top_p=0.9,
161
- top_k=50,
162
- temperature=0.8,
163
- do_sample=True
164
- )
165
- return self.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
166
- elif task == 'code_generation':
167
- attention_mask = inputs.get('attention_mask')
168
- outputs = self.code_generator.generate(
169
- inputs['input_ids'],
170
- max_new_tokens=50,
171
- pad_token_id=self.code_tokenizer.eos_token_id,
172
- attention_mask=attention_mask,
173
- top_p=0.95,
174
- top_k=50,
175
- temperature=1.2,
176
- do_sample=True
177
- )
178
- return self.code_tokenizer.decode(outputs[0], skip_special_tokens=True)
179
- # 添加其他任务的逻辑...
180
-
181
- # 计算模型参数数量的函数
182
- def count_parameters(model):
183
- return sum(p.numel() for p in model.parameters() if p.requires_grad)
184
-
185
- # 初始化模型
186
- model = MultiModalModel()
187
-
188
- # 计算并打印模型参数数量
189
- total_params = count_parameters(model)
190
- print(f"模型总参数数量: {total_params}")
191
  ```
 
118
 
119
  ## How to Get Started with the Model
120
  ```python
121
+ # Use a pipeline as a high-level helper
122
+ from transformers import pipeline
123
+
124
+ pipe = pipeline("text-generation", model="zeroMN/SHMT")
125
+ ```
126
+ ```python
127
+ # Load model directly
128
+ from transformers import AutoModel
129
+ model = AutoModel.from_pretrained("zeroMN/SHMT")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  ```