yanchaocc committed on
Commit
1029ecd
1 Parent(s): 601776e

Create ModelInfo.ts

Browse files
Files changed (1) hide show
  1. ModelInfo.ts +234 -0
ModelInfo.ts ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/**
 * Metadata describing one model, built from a partial record and exposing
 * derived values (URLs, README key, detected frameworks, tags, pipeline tag).
 */
export class ModelInfo {
  /**
   * Ordered `[architecture suffix, Auto* class name, tag]` rules used by
   * `autoArchType`. The first entry whose suffix matches wins, so order matters.
   */
  private static readonly ARCH_SUFFIX_RULES: ReadonlyArray<readonly [string, string, string]> = [
    ["ForQuestionAnswering", "AutoModelForQuestionAnswering", "question-answering"],
    ["ForTokenClassification", "AutoModelForTokenClassification", "token-classification"],
    ["ForSequenceClassification", "AutoModelForSequenceClassification", "text-classification"],
    ["ForMultipleChoice", "AutoModelForMultipleChoice", "multiple-choice"],
    ["ForPreTraining", "AutoModelForPreTraining", "pretraining"],
    ["ForMaskedLM", "AutoModelForMaskedLM", "masked-lm"],
    ["ForCausalLM", "AutoModelForCausalLM", "causal-lm"],
    ["ForConditionalGeneration", "AutoModelForSeq2SeqLM", "seq2seq"],
    ["MTModel", "AutoModelForSeq2SeqLM", "seq2seq"],
  ];

  /** `model_type` values whose "lm-head" architectures are tagged "seq2seq". */
  private static readonly SEQ2SEQ_MODEL_TYPES: readonly string[] = [
    "t5",
    "bart",
    "marian",
  ];

  /** `model_type` values whose "lm-head" architectures are tagged "causal-lm". */
  private static readonly CAUSAL_LM_MODEL_TYPES: readonly string[] = [
    "gpt2",
    "ctrl",
    "openai-gpt",
    "xlnet",
    "transfo-xl",
    "reformer",
  ];

  /**
   * Key to config.json file.
   */
  key: string;
  etag: string;
  lastModified: Date;
  size: number;
  modelId: ModelId;
  author?: AuthorId;
  siblings: IS3ObjectWRelativeFilename[];
  config: Obj;
  configTxt?: string; /// if flag is set when fetching.
  downloads?: number; /// if flag is set when fetching.
  naturalIdx: number;
  cardSource?: Source;
  cardData?: Obj;

  /**
   * Copies every own enumerable property of `o` onto the new instance.
   *
   * @param o Partial set of fields; fields not supplied stay unset.
   */
  constructor(o: Partial<ModelInfo>) {
    Object.assign(this, o);
  }

  /** URL of the config file, as resolved by `Bucket.R.models.urlForKey`. */
  get jsonUrl(): string {
    return Bucket.R.models.urlForKey(this.key);
  }

  /** CDN URL of the config file, as resolved by `Bucket.R.models.cdnUrlForKey`. */
  get cdnJsonUrl(): string {
    return Bucket.R.models.cdnUrlForKey(this.key);
  }

  /**
   * Validates `config` against the JSON schema stored at `CONFIG_JSON_SCHEMA`.
   *
   * @returns The Ajv error list, or `undefined` when Ajv reports no errors.
   * @throws If the schema file cannot be read or is not valid JSON.
   */
  async validate(): Promise<Ajv.ErrorObject[] | undefined> {
    const schemaText = await fs.promises.readFile(CONFIG_JSON_SCHEMA, "utf8");
    const jsonSchema = JSON.parse(schemaText);
    const ajv = new Ajv();
    ajv.validate(jsonSchema, this.config);
    return ajv.errors ?? undefined;
  }

  /**
   * Readme key, w. and w/o S3 prefix.
   *
   * Only the trailing `config.json` filename is swapped for `README.md`, so a
   * directory component that happens to contain "config.json" is left intact.
   */
  get readmeKey(): string {
    const configName = "config.json";
    const readmeName = "README.md";
    if (this.key.endsWith(configName)) {
      return this.key.slice(0, -configName.length) + readmeName;
    }
    /// Key does not end with config.json: keep the historical first-match replacement.
    return this.key.replace(configName, readmeName);
  }

  /** `readmeKey` with `S3_MODELS_PREFIX` trimmed via `Utils.trimPrefix`. */
  get readmeTrimmedKey(): string {
    return Utils.trimPrefix(this.readmeKey, S3_MODELS_PREFIX);
  }

  /**
   * ["pytorch", "tf", ...]
   *
   * Keys of `FileType` whose associated filename is present among `siblings`.
   * A `FileType` value starting with "." is matched as a filename suffix;
   * any other value must equal the sibling's `rfilename` exactly.
   */
  get mlFrameworks(): string[] {
    return Object.keys(FileType).filter(k => {
      const filename = FileType[k];
      const isExtension = filename.startsWith(".");
      return isExtension
        ? this.siblings.some(sibling => sibling.rfilename.endsWith(filename))
        : this.siblings.some(sibling => sibling.rfilename === filename);
    });
  }

  /**
   * What to display in the code sample.
   *
   * The Auto* class name is prefixed with "TF" when "tf" is among the detected
   * frameworks and "pytorch" is not.
   */
  get autoArchitecture(): string {
    /// Read the getter once: it rescans `siblings` on every access.
    const frameworks = this.mlFrameworks;
    const useTF = frameworks.includes("tf") && !frameworks.includes("pytorch");
    const [arch] = this.autoArchType;
    return useTF ? `TF${arch}` : arch;
  }

  /**
   * Maps the first entry of `config.architectures` to an
   * `[Auto* class name, tag | undefined]` pair.
   *
   * Falls back to `["AutoModel", undefined]` when `architectures` is missing,
   * not an array, empty, or its first entry is null/undefined. An architecture
   * matching no rule is returned as-is with an undefined tag.
   */
  get autoArchType(): [string, string | undefined] {
    const architectures: unknown = this.config.architectures;
    if (
      !Array.isArray(architectures)
      || architectures.length === 0
      || architectures[0] == null
    ) {
      return ["AutoModel", undefined];
    }
    const architecture = String(architectures[0]);

    for (const [suffix, autoClass, tag] of ModelInfo.ARCH_SUFFIX_RULES) {
      if (architecture.endsWith(suffix)) {
        return [autoClass, tag];
      }
    }
    /// Exact match only: names merely ending in "EncoderDecoderModel" fall through.
    if (architecture === "EncoderDecoderModel") {
      return ["AutoModelForSeq2SeqLM", "seq2seq"];
    }
    if (architecture.includes("LMHead")) {
      return ["AutoModelWithLMHead", "lm-head"];
    }
    if (architecture.endsWith("Model")) {
      return ["AutoModel", undefined];
    }
    return [architecture, undefined];
  }

  /**
   * All tags
   *
   * Collected in order, then de-duplicated: detected frameworks,
   * `config.model_type`, the architecture tag (with an "lm-head" refinement
   * based on `model_type`), languages, `dataset:*`, the `tags` and `license`
   * entries of `cardData`, `pipeline:*` for each key of
   * `config.task_specific_params`, and `pipeline_tag:*` from `cardData`.
   */
  get tags(): string[] {
    const collected = [
      ...this.mlFrameworks,
    ];

    const modelType = this.config.model_type;
    if (modelType) {
      collected.push(modelType);
    }

    const archTag = this.autoArchType[1];
    if (archTag) {
      collected.push(archTag);
    }
    if (archTag === "lm-head" && modelType) {
      if (ModelInfo.SEQ2SEQ_MODEL_TYPES.includes(modelType)) {
        collected.push("seq2seq");
      }
      else if (ModelInfo.CAUSAL_LM_MODEL_TYPES.includes(modelType)) {
        collected.push("causal-lm");
      }
      else {
        collected.push("masked-lm");
      }
    }

    /// NOTE(review): `languages()` and `datasets()` are not defined in the visible
    /// class body — presumably declared elsewhere; confirm before relying on this.
    collected.push(
      ...this.languages() ?? []
    );
    collected.push(
      ...this.datasets().map(k => `dataset:${k}`)
    );

    for (const [cardKey, rawValue] of Object.entries(this.cardData ?? {})) {
      if (!["tags", "license"].includes(cardKey)) {
        /// ^^ whitelist of other accepted keys
        continue;
      }
      let values: string[];
      if (typeof rawValue === "string") {
        values = [rawValue];
      } else if (Utils.isStrArray(rawValue)) {
        values = rawValue;
      } else {
        c.error(`Invalid ${cardKey} tag type`, rawValue);
        c.debug(this.modelId);
        continue;
      }
      if (cardKey === "license") {
        collected.push(...values.map(license => `license:${license.toLowerCase()}`));
      } else {
        collected.push(...values);
      }
    }

    if (this.config.task_specific_params) {
      for (const task of Object.keys(this.config.task_specific_params)) {
        collected.push(`pipeline:${task}`);
      }
    }

    const explicitPipelineTag = this.cardData?.pipeline_tag;
    if (explicitPipelineTag) {
      if (typeof explicitPipelineTag === "string") {
        collected.push(`pipeline_tag:${explicitPipelineTag}`);
      } else {
        collected.push(`pipeline_tag:invalid`);
      }
    }

    return [...new Set(collected)];
  }

  /**
   * Resolves the pipeline type for this model.
   *
   * Resolution order:
   *  1. `undefined` if `isBlacklisted(modelId)` or `cardData.inference === false`.
   *  2. `cardData.pipeline_tag` when it is a string (a truthy non-string value
   *     is logged and yields `undefined`).
   *  3. "translation" if any tag starts with "pipeline:translation".
   *  4. The first entry of `ALL_PIPELINE_TYPES` present in `tags`.
   *  5. The extra mapping: seq2seq / causal-lm -> "text-generation",
   *     masked-lm -> "fill-mask".
   *
   * @returns The resolved pipeline type, or `undefined` when nothing matches.
   */
  get pipeline_tag(): (keyof typeof PipelineType) | undefined {
    if (isBlacklisted(this.modelId) || this.cardData?.inference === false) {
      return undefined;
    }

    const explicitPipelineTag = this.cardData?.pipeline_tag;
    if (explicitPipelineTag) {
      if (typeof explicitPipelineTag === "string") {
        return explicitPipelineTag as keyof typeof PipelineType;
      }
      c.error(`Invalid explicit pipeline_tag`, explicitPipelineTag);
      return undefined;
    }

    const tags = this.tags;

    /// Special case for translation: any explicit "pipeline:translation*" tag wins.
    const EXPLICIT_PREFIX = "pipeline:";
    if (tags.some(tag => tag.startsWith(EXPLICIT_PREFIX + `translation`))) {
      return "translation";
    }

    /// Otherwise, get the first (most specific) match **from the mapping**.
    for (const ptag of ALL_PIPELINE_TYPES) {
      if (tags.includes(ptag)) {
        return ptag;
      }
    }

    /// Extra mapping, checked in this order.
    const extraMapping: ReadonlyArray<readonly [string, keyof typeof PipelineType]> = [
      ["seq2seq", "text-generation"],
      ["causal-lm", "text-generation"],
      ["masked-lm", "fill-mask"],
    ];
    for (const [tag, ptag] of extraMapping) {
      if (tags.includes(tag)) {
        return ptag;
      }
    }

    return undefined;
  }
}