anas-awadalla committed
Commit 0a405ca
1 Parent(s): b630945
app.py CHANGED
@@ -54,13 +54,14 @@ with open("bad_words.txt", "r") as f:
 model, image_processor, tokenizer = create_model_and_transforms(
     clip_vision_encoder_pretrained="openai",
     clip_vision_encoder_path="ViT-L-14",
-    lang_encoder_path="anas-awadalla/mpt-7b",
-    tokenizer_path="anas-awadalla/mpt-7b",
-    cross_attn_every_n_layers=4,
+    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
+    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
+    cross_attn_every_n_layers=1,
 )
 
-checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
+checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct", "checkpoint.pt")
 model.load_state_dict(torch.load(checkpoint_path), strict=False)
+
 model.eval()
 
 def generate(
@@ -153,13 +154,13 @@ def generate(
     # with torch.cuda.amp.autocast(dtype=torch.bfloat16):
     output = model.generate(
         vision_x=vision_x,
-        lang_x=input_ids.to("cuda"),
-        attention_mask=attention_mask.to("cuda"),
+        lang_x=input_ids,
+        attention_mask=attention_mask,
         max_new_tokens=30,
         num_beams=3,
-        do_sample=True,
-        temperature=0.3,
-        top_k=0,
+        # do_sample=True,
+        # temperature=0.3,
+        # top_k=0,
     )
 
     gen_text = tokenizer.decode(
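Net effect of the app.py changes: the Space now loads the 3B instruction-tuned checkpoint instead of the 9B MPT-7B one, drops the hard-coded .to("cuda") calls so inputs stay on the same (CPU) device as the model, and generates with plain beam search since the sampling arguments are commented out. Below is a minimal sketch of that load path, assuming the top-level open_flamingo/huggingface_hub imports (the vendored module path in this repo may differ); map_location="cpu" is an added assumption to make the CPU load explicit, not part of the diff.

import torch
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms  # import path may differ in this vendored repo

# Build the 3B model: ViT-L/14 vision tower + MPT-1B (dolly-instructed) language model,
# with gated cross-attention inserted after every decoder layer (cross_attn_every_n_layers=1).
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b-dolly",
    cross_attn_every_n_layers=1,
)

# Fetch the OpenFlamingo-3B weights from the Hub and load them on CPU.
checkpoint_path = hf_hub_download(
    "openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct", "checkpoint.pt"
)
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)
model.eval()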
open_flamingo/open_flamingo/src/factory.py CHANGED
@@ -1,6 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import open_clip
-import torch
 
 from .flamingo import Flamingo
 from .flamingo_lm import FlamingoLMMixin
@@ -58,7 +57,8 @@ def create_model_and_transforms(
     lang_encoder = AutoModelForCausalLM.from_pretrained(
         lang_encoder_path,
         local_files_only=use_local_files,
-        trust_remote_code=True)
+        trust_remote_code=True,
+    )
 
     # hacks for MPT-1B, which doesn't have a get_input_embeddings method
     if "mpt-1b-redpajama-200b" in lang_encoder_path:
@@ -79,7 +79,6 @@ def create_model_and_transforms(
     decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
     lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
     lang_encoder.resize_token_embeddings(len(text_tokenizer))
-    lang_encoder.to(0)
 
     model = Flamingo(
         vision_encoder,
@@ -90,7 +89,8 @@ def create_model_and_transforms(
             "width"
         ],
         cross_attn_every_n_layers=cross_attn_every_n_layers,
-        **flamingo_kwargs)
+        **flamingo_kwargs,
+    )
 
     # Freeze all parameters
     model.requires_grad_(False)
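The factory.py hunks are mostly cosmetic (trailing commas and closing parentheses on their own lines), plus two functional changes: the unused import torch and the hard-coded lang_encoder.to(0) are dropped, so the factory no longer forces the language model onto GPU 0. For context, the "hacks for MPT-1B" branch kept above patches in the standard embedding accessors that the remote-code MPT-1B model lacks; the sketch below is a rough illustration of that idea, and the transformer.wte attribute name is an assumption about the MPT-1B remote code rather than something shown in this diff.

# Rough sketch of the kind of patch the "hacks for MPT-1B" comment refers to:
# mixing get/set_input_embeddings into an already-constructed model instance.
class EmbeddingFnMixin:
    def get_input_embeddings(self):
        # Assumed attribute of the MPT-1B remote-code model.
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings


def extend_instance(obj, mixin):
    """Rewrite obj's class so the mixin's methods take precedence."""
    base = obj.__class__
    obj.__class__ = type(base.__name__, (mixin, base), {})

# Usage (inside create_model_and_transforms):
#     if "mpt-1b-redpajama-200b" in lang_encoder_path:
#         extend_instance(lang_encoder, EmbeddingFnMixin)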
open_flamingo/open_flamingo/src/flamingo.py CHANGED
@@ -212,7 +212,7 @@ class Flamingo(nn.Module):
         with torch.no_grad():
             vision_x = self.vision_encoder(vision_x)[1]
         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
-        vision_x = self.perceiver(vision_x).to(0)
+        vision_x = self.perceiver(vision_x)
 
         for layer in self.lang_encoder._get_decoder_layers():
             layer.condition_vis_x(vision_x)
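The flamingo.py change mirrors the app-side edits: the Perceiver output is no longer pinned to GPU 0, so the conditioning tensors stay on whatever device the model already lives on. If GPU support is wanted again later, a device-agnostic pattern such as the hypothetical helper below (not part of OpenFlamingo) avoids reintroducing hard-coded .to(0) calls.

import torch
import torch.nn as nn

def to_model_device(model: nn.Module, *tensors: torch.Tensor) -> tuple:
    """Move tensors onto whatever device the model's parameters live on,
    rather than hard-coding .to(0) / .to("cuda")."""
    device = next(model.parameters()).device
    return tuple(t.to(device) for t in tensors)

# e.g. lang_x, attention_mask = to_model_device(model, input_ids, attention_mask)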