loocorez committed (verified)
Commit 9ef35af · Parent: e9ab863

Upload app.py with huggingface_hub
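
For reference, a minimal sketch of how such an upload can be made with the huggingface_hub client. The repo_id below is a placeholder, not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="app.py",
    path_in_repo="app.py",
    repo_id="<user>/<space>",   # placeholder; set to the target Space repo
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)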

Files changed (1):
  app.py  +9 -7
app.py CHANGED
@@ -54,19 +54,21 @@ weights_path = os.path.join(local_dir, "pytorch_model.bin")
 state = torch.load(weights_path, map_location=device)
 state = {k.lstrip("_orig_mod."): v for k, v in state.items()}
 model.load_state_dict(state, strict=True, assign=True)
-model = model.to(dtype=torch.float32)
+# Ensure rotary buffers and weights are bf16 as expected by model
+model = model.to(device).to(dtype=torch.bfloat16)
 model.eval()
 
 def complete(prompt, max_new_tokens=64):
     input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
     ids = torch.tensor([input_ids], dtype=torch.long, device=device)
-    generated = []
     with torch.inference_mode():
-        for _ in range(max_new_tokens):
-            logits = model.forward(ids)
-            logits = logits[:, -1, :]
-            next_token = torch.argmax(logits, dim=-1, keepdim=True)
-            ids = torch.cat([ids, next_token], dim=1)
+        # autocast so activations match model bf16 dtype
+        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
+            for _ in range(max_new_tokens):
+                logits = model.forward(ids)
+                logits = logits[:, -1, :]
+                next_token = torch.argmax(logits, dim=-1, keepdim=True)
+                ids = torch.cat([ids, next_token], dim=1)
     return tokenizer.decode(ids[0].tolist())
 
 with gr.Blocks() as demo:
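
Net effect of the hunk: the model weights are kept in bfloat16 and the greedy decoding loop runs under autocast so activations match that dtype. A self-contained sketch of the same pattern using a toy stand-in model (not the Space's actual model or tokenizer):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-in language model: embed token ids, project back to vocab logits.
vocab_size = 256
toy_model = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))
toy_model = toy_model.to(device).to(dtype=torch.bfloat16)  # weights in bf16
toy_model.eval()

def greedy_complete(ids, max_new_tokens=8):
    # ids: LongTensor of shape (1, seq_len) holding prompt token ids
    with torch.inference_mode():
        # autocast keeps activations in bf16, matching the bf16 parameters
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            for _ in range(max_new_tokens):
                logits = toy_model(ids)                     # (1, seq_len, vocab)
                next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_token], dim=1)   # append greedy pick
    return ids

print(greedy_complete(torch.tensor([[1, 2, 3]], device=device)).tolist())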