broadfield-dev committed
Commit fb8354d · verified · 1 Parent(s): 07baf53

Create app.py

Files changed (1)
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
+ from flask import Flask, render_template, request
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import numpy as np
+ import requests
+ import json
+ from huggingface_hub import hf_hub_download
+
+ app = Flask(__name__)
+ _cache = {}
+
+
+ def get_sigma(hidden_size: int, seed: int):
+     rng = np.random.default_rng(seed)
+     sigma = rng.permutation(hidden_size)
+     sigma_inv = np.argsort(sigma)
+     return torch.tensor(sigma, dtype=torch.long), torch.tensor(sigma_inv, dtype=torch.long)
+
+
+ def load_client_components(ee_model_name: str):
+     if ee_model_name in _cache:
+         return _cache[ee_model_name]
+
+     config_path = hf_hub_download(ee_model_name, "ee_config.json")
+     with open(config_path) as f:
+         ee_config = json.load(f)
+
+     hidden_size = ee_config["hidden_size"]
+     original_model_name = ee_config["original_model"]
+
+     tokenizer = AutoTokenizer.from_pretrained(original_model_name, trust_remote_code=True)
+
+     original_model = AutoModelForCausalLM.from_pretrained(
+         original_model_name,
+         torch_dtype=torch.float32,
+         device_map="cpu",
+         trust_remote_code=True,
+     )
+     embed_layer = original_model.model.embed_tokens
+     lm_head = original_model.lm_head
+     final_norm = original_model.model.norm
+     embed_layer.eval()
+     lm_head.eval()
+     final_norm.eval()
+     del original_model
+
+     _cache[ee_model_name] = (tokenizer, embed_layer, lm_head, final_norm, hidden_size)
+     return tokenizer, embed_layer, lm_head, final_norm, hidden_size
+
+ load_client_components("broadfield-dev/Qwen3-0.6B-dp-ee")  # pre-warm the cache for the default model used below
+
+ def generate_tokens(server_url, tokenizer, embed_layer, lm_head, final_norm,
+                     sigma_t, sigma_inv_t, formatted_prompt, max_new_tokens):
+     """
+     Token-by-token generation. No KV cache — client accumulates all embeddings
+     and sends the full growing sequence each step.
+
+     Each step:
+     1. Encrypt all token embeddings so far with sigma
+     2. Send to server → get back last hidden state (sigma-space)
+     3. Decrypt last position: apply sigma_inv
+     4. Run final_norm + lm_head locally → next token
+     """
+     inputs = tokenizer(formatted_prompt, return_tensors="pt")
+     input_ids = inputs.input_ids  # (1, seq_len)
+
+     # Build initial encrypted embeddings for full prompt
+     with torch.no_grad():
+         all_plain_embeds = embed_layer(input_ids)  # (1, seq_len, hidden)
+
+     generated_ids = []
+
+     for step in range(max_new_tokens):
+         # Encrypt the full sequence so far
+         all_encrypted = all_plain_embeds[..., sigma_t].to(torch.float16)  # (1, seq, hidden)
+         seq_len = all_encrypted.shape[1]
+         attention_mask = torch.ones(1, seq_len, dtype=torch.long)
+
+         payload = {
+             "inputs_embeds": all_encrypted.tolist(),
+             "attention_mask": attention_mask.tolist(),
+         }
+
+         resp = requests.post(f"{server_url}/generate", json=payload, timeout=120)
+         if not resp.ok:
+             raise RuntimeError(f"Server {resp.status_code}: {resp.text[:400]}")
+
+         body = resp.json()
+         if "error" in body:
+             raise RuntimeError(f"Server error: {body['error']}")
+
+         # Decrypt last position only
+         last_hidden = torch.tensor(body["last_hidden"], dtype=torch.float32)  # (1, seq, hidden)
+         last_pos_sigma = last_hidden[:, -1:, :]  # (1, 1, hidden) sigma-space
+         last_pos_plain = last_pos_sigma[..., sigma_inv_t]  # (1, 1, hidden) plain-space
+
+         # Client-side: final norm + lm_head → next token
+         with torch.no_grad():
+             normed = final_norm(last_pos_plain)
+             logits = lm_head(normed)  # (1, 1, vocab)
+
+         next_token_id = logits[0, -1, :].argmax().item()
+         generated_ids.append(next_token_id)
+
+         if next_token_id == tokenizer.eos_token_id:
+             break
+
+         # Append new token's plain embedding to the growing sequence
+         next_id_tensor = torch.tensor([[next_token_id]])
+         with torch.no_grad():
+             next_embed = embed_layer(next_id_tensor)  # (1, 1, hidden)
+         all_plain_embeds = torch.cat([all_plain_embeds, next_embed], dim=1)
+
+     return generated_ids
+
+
+ @app.route("/", methods=["GET", "POST"])
+ def index():
+     result = None
+     error = None
+     form_data = {}
+     ee_model_name = 'broadfield-dev/Qwen3-0.6B-dp-ee'
+     tokenizer, embed_layer, lm_head, final_norm, hidden_size = \
+         load_client_components(ee_model_name)
+     if request.method == "POST":
+         form_data = request.form.to_dict()
+         server_url = request.form["server_url"].rstrip("/")
+         # ee_model_name = request.form["ee_model_name"].strip()
+         ee_seed = int(request.form["ee_seed"])
+         prompt = request.form["prompt"].strip()
+         max_tokens = int(request.form.get("max_tokens", 256))
+
+         try:
+             '''tokenizer, embed_layer, lm_head, final_norm, hidden_size = \
+             load_client_components(ee_model_name)'''
+
+             sigma_t, sigma_inv_t = get_sigma(hidden_size, ee_seed)
+
+             messages = [{"role": "user", "content": prompt}]
+             formatted = tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True,
+                 enable_thinking=False,  # disable Qwen3 thinking mode
+             )
+
+             gen_ids = generate_tokens(
+                 server_url, tokenizer, embed_layer, lm_head, final_norm,
+                 sigma_t, sigma_inv_t, formatted, max_tokens
+             )
+
+             result = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+
+         except RuntimeError as e:
+             error = str(e)
+         except requests.exceptions.ConnectionError:
+             error = f"Could not connect to {server_url} — is the server Space running?"
+         except Exception as e:
+             error = f"{type(e).__name__}: {e}"
+
+     return render_template("client.html", result=result, error=error, form=form_data)
+
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=7860)
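
As a reading aid for the protocol described in the `generate_tokens` docstring, below is a minimal, self-contained sketch of the sigma-permutation round trip: `get_sigma` turns a shared seed into a permutation of the hidden dimension plus its inverse, the client applies `sigma` to the token embeddings before they leave the machine, and applying `sigma_inv` to a sigma-space tensor restores the original feature order. The hidden size and seed values below are illustrative only; the app itself reads `hidden_size` from `ee_config.json` and takes the seed from the web form.

import numpy as np
import torch

def get_sigma(hidden_size: int, seed: int):
    # Same helper as in app.py: a seeded permutation of the hidden dimension and its inverse.
    rng = np.random.default_rng(seed)
    sigma = rng.permutation(hidden_size)
    sigma_inv = np.argsort(sigma)
    return torch.tensor(sigma, dtype=torch.long), torch.tensor(sigma_inv, dtype=torch.long)

hidden_size, seed = 8, 42                  # illustrative values, not taken from the commit
sigma_t, sigma_inv_t = get_sigma(hidden_size, seed)

plain = torch.randn(1, 3, hidden_size)     # stand-in for embed_layer(input_ids): (batch, seq, hidden)
encrypted = plain[..., sigma_t]            # what the client POSTs to {server_url}/generate
recovered = encrypted[..., sigma_inv_t]    # undoing the permutation on a sigma-space tensor

assert torch.equal(recovered, plain)       # the permutation is exactly invertible

In the app the same idea is applied each step: `all_plain_embeds[..., sigma_t]` before the POST, and `[..., sigma_inv_t]` on the returned `last_hidden` before the local `final_norm` and `lm_head`.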