chumengl committed on
Commit
8b5917a
·
1 Parent(s): d96c7c3

[update] delete flash-attn

Browse files
Files changed (2) hide show
  1. inference.py +0 -60
  2. model.py +16 -7
inference.py DELETED
@@ -1,60 +0,0 @@
1
- """Simple inference script to test the HuggingFace LangFlow model."""
2
-
3
- import argparse
4
- import torch
5
- from transformers import AutoModelForMaskedLM, AutoTokenizer
6
-
7
-
8
- def main():
9
- parser = argparse.ArgumentParser(description="Generate samples with LangFlow")
10
- parser.add_argument(
11
- "--model_path", type=str, default="hf_release/model_weights",
12
- help="Path to the HuggingFace model directory")
13
- parser.add_argument(
14
- "--num_samples", type=int, default=5,
15
- help="Number of samples to generate")
16
- parser.add_argument(
17
- "--num_steps", type=int, default=128,
18
- help="Number of denoising steps")
19
- parser.add_argument(
20
- "--seq_length", type=int, default=1024,
21
- help="Sequence length")
22
- parser.add_argument(
23
- "--seed", type=int, default=42,
24
- help="Random seed")
25
- args = parser.parse_args()
26
-
27
- # Set seed for reproducibility
28
- torch.manual_seed(args.seed)
29
- if torch.cuda.is_available():
30
- torch.cuda.manual_seed_all(args.seed)
31
-
32
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
- print(f"Using device: {device}")
34
-
35
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
36
- model = AutoModelForMaskedLM.from_pretrained(
37
- args.model_path,
38
- trust_remote_code=True
39
- )
40
- model = model.to(device)
41
- model.eval()
42
-
43
- print(f"\nGenerating {args.num_samples} samples with {args.num_steps} steps...")
44
- with torch.no_grad():
45
- samples = model.generate_samples(
46
- num_samples=args.num_samples,
47
- seq_length=args.seq_length,
48
- num_steps=args.num_steps,
49
- device=device
50
- )
51
-
52
- texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
53
- for i, text in enumerate(texts):
54
- print(f"\n--- Sample {i+1} ---")
55
- # Print first 500 characters to keep output manageable
56
- print(text[:500] + ("..." if len(text) > 500 else ""))
57
-
58
-
59
- if __name__ == "__main__":
60
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py CHANGED
@@ -7,8 +7,6 @@ import math
7
  import typing
8
 
9
  import einops
10
- import flash_attn
11
- import flash_attn.layers.rotary
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
@@ -91,6 +89,19 @@ class Rotary(nn.Module):
91
  return self.cos_cached, self.sin_cached
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
95
  with torch.autocast(device_type='cuda', enabled=False):
96
  cos, sin = rotary_cos_sin
@@ -99,10 +110,8 @@ def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
99
  cos = cos[0, :, 0, 0, :cos.shape[-1]//2]
100
  sin = sin[0, :, 0, 0, :sin.shape[-1]//2]
101
  q, k, v = qkv.chunk(3, dim=2)
102
- q = flash_attn.layers.rotary.apply_rotary_emb_torch(
103
- q.squeeze(dim=2), cos, sin)
104
- k = flash_attn.layers.rotary.apply_rotary_emb_torch(
105
- k.squeeze(dim=2), cos, sin)
106
  v = v.squeeze(dim=2)
107
  return q, k, v
108
 
@@ -548,4 +557,4 @@ class LangFlow(transformers.PreTrainedModel):
548
  return_dict=False)
549
  samples = logits.argmax(dim=-1)
550
 
551
- return samples
 
7
  import typing
8
 
9
  import einops
 
 
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
 
89
  return self.cos_cached, self.sin_cached
90
 
91
 
92
+ def _apply_rotary_emb(x, cos, sin):
93
+ # x: [batch, seqlen, nheads, headdim]
94
+ # cos, sin: [seqlen, headdim//2]
95
+ ro_dim = cos.shape[-1] * 2
96
+ # Expand to [1, seqlen, 1, ro_dim] for broadcasting
97
+ cos = torch.cat([cos, cos], dim=-1)[None, :, None, :]
98
+ sin = torch.cat([sin, sin], dim=-1)[None, :, None, :]
99
+ x_rot = x[..., :ro_dim]
100
+ x1, x2 = x_rot.chunk(2, dim=-1)
101
+ x_rotated = torch.cat([-x2, x1], dim=-1)
102
+ return torch.cat([x_rot * cos + x_rotated * sin, x[..., ro_dim:]], dim=-1)
103
+
104
+
105
  def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
106
  with torch.autocast(device_type='cuda', enabled=False):
107
  cos, sin = rotary_cos_sin
 
110
  cos = cos[0, :, 0, 0, :cos.shape[-1]//2]
111
  sin = sin[0, :, 0, 0, :sin.shape[-1]//2]
112
  q, k, v = qkv.chunk(3, dim=2)
113
+ q = _apply_rotary_emb(q.squeeze(dim=2), cos, sin)
114
+ k = _apply_rotary_emb(k.squeeze(dim=2), cos, sin)
 
 
115
  v = v.squeeze(dim=2)
116
  return q, k, v
117
 
 
557
  return_dict=False)
558
  samples = logits.argmax(dim=-1)
559
 
560
+ return samples