SamSak09 commited on
Commit
c9443b0
·
verified ·
1 Parent(s): e4676ce

Update app2.py

Browse files
Files changed (1) hide show
  1. app2.py +92 -2
app2.py CHANGED
@@ -1,85 +1,175 @@
 
1
  import os
 
2
  import shutil
3
 
 
 
4
  # --- SEARCH AND DESTROY POISONED CACHE ---
 
5
  corrupted_dir = "/root/.cache/huggingface/hub/models--google--umt5-base"
 
6
  if os.path.exists(corrupted_dir):
 
7
  print("[SYSTEM] Found corrupted UMT5 cache. Deleting...")
 
8
  shutil.rmtree(corrupted_dir, ignore_errors=True)
 
9
  else:
 
10
  print("[SYSTEM] Cache is clean.")
11
 
 
 
12
  # --- YOUR ORIGINAL CODE STARTS HERE ---
 
13
  from flask import Flask, request, jsonify, send_from_directory # Added send_from_directory
 
14
  from flask_sock import Sock
 
15
  from transformers import AutoModel
 
16
  import torch
 
17
  import time
 
18
  import json
 
19
  from flask_cors import CORS
20
 
 
 
21
  app = Flask(__name__)
 
22
  CORS(app)
 
23
  sock = Sock(app) # Initialize WebSocket support
24
 
 
 
25
  print("[SYSTEM] Booting up Network Server...")
 
26
  print("[SYSTEM] Loading FloodDiffusionTiny model from Hugging Face...")
27
 
 
 
28
  # 1. Load the model
 
29
  model = AutoModel.from_pretrained(
 
30
  "ShandaAI/FloodDiffusionTiny",
 
31
  trust_remote_code=True
 
32
  )
33
 
 
 
34
  # 2. Cloud Architecture Override
 
35
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
36
  model = model.to(device)
 
37
  print(f"[SYSTEM] Model loaded successfully onto device: {device}")
38
 
 
 
39
  @app.route('/')
 
40
  def serve_ui():
 
41
  # This tells Flask to send the index.html file to the user's browser
 
42
  return send_from_directory('.', 'index.html')
43
 
 
 
44
  # --- THE NEW WEBSOCKET PIPELINE ---
 
45
  @sock.route('/api/generate_stream')
 
46
  def stream_motion(ws):
 
47
  print("\n[NETWORK] 🟢 WebSocket Connection Opened! Client connected.")
 
48
 
 
49
  # Keep the connection open forever
 
50
  while True:
 
51
  try:
 
52
  # 1. Wait for the live prompt from the client's text box
 
53
  raw_data = ws.receive()
 
54
  if raw_data is None:
 
55
  continue
 
56
 
 
57
  data = json.loads(raw_data)
 
58
  text_prompt = data.get('prompt', '')
 
 
 
59
  print(f"[NETWORK] Live Prompt Received: '{text_prompt}'")
 
60
 
 
61
  start_time = time.time()
 
62
 
 
63
  # 2. Server Processing (Inference)
64
- motion_joints = model(text_prompt, length=15, output_joints=True)
 
 
65
  processing_time = (time.time() - start_time) * 1000
 
66
 
 
67
  # 3. Format Network Payload
 
68
  payload = {
 
69
  "status": "success",
 
 
 
70
  "latency_ms": round(processing_time, 2),
 
71
  "tensor_shape": list(motion_joints.shape),
 
72
  "data": motion_joints.tolist()
 
73
  }
 
74
 
 
75
  # 4. Push data back through the pipe instantly!
 
76
  ws.send(json.dumps(payload))
 
77
  print(f"[NETWORK] ⚡ Streamed 30 frames to client in {processing_time:.2f}ms")
 
78
 
 
79
  except Exception as e:
 
80
  print(f"[NETWORK] 🔴 WebSocket Error or Disconnect: {e}")
 
81
  break
82
 
 
 
83
  if __name__ == '__main__':
 
84
  # --- PORT 7860 FOR HUGGING FACE ---
85
- app.run(host='0.0.0.0', port=7860, debug=False)
 
 
 
1
+
2
import os
import shutil

# --- SEARCH AND DESTROY POISONED CACHE ---
# A half-downloaded/corrupted umt5-base snapshot left in the HF hub cache
# breaks later model loads; wipe it so transformers re-fetches a clean copy.
corrupted_dir = "/root/.cache/huggingface/hub/models--google--umt5-base"
cache_is_poisoned = os.path.exists(corrupted_dir)
if cache_is_poisoned:
    print("[SYSTEM] Found corrupted UMT5 cache. Deleting...")
    # ignore_errors: best-effort delete — a partial removal is still better
    # than aborting startup here.
    shutil.rmtree(corrupted_dir, ignore_errors=True)
else:
    print("[SYSTEM] Cache is clean.")
21
 
22
+
23
+
24
  # --- YOUR ORIGINAL CODE STARTS HERE ---
25
+
26
  from flask import Flask, request, jsonify, send_from_directory # Added send_from_directory
27
+
28
  from flask_sock import Sock
29
+
30
  from transformers import AutoModel
31
+
32
  import torch
33
+
34
  import time
35
+
36
  import json
37
+
38
  from flask_cors import CORS
39
 
40
+
41
+
42
# --- Flask application, CORS, and WebSocket transport setup ---
app = Flask(__name__)
CORS(app)  # allow the browser UI (possibly served from another origin) to call the API
sock = Sock(app) # Initialize WebSocket support

print("[SYSTEM] Booting up Network Server...")
print("[SYSTEM] Loading FloodDiffusionTiny model from Hugging Face...")

# 1. Load the model
# NOTE(review): trust_remote_code=True executes Python shipped inside the
# model repo — acceptable only because this specific repo is trusted.
model = AutoModel.from_pretrained(
    "ShandaAI/FloodDiffusionTiny",
    trust_remote_code=True
)

# 2. Cloud Architecture Override
# Prefer GPU when the host exposes one; otherwise fall back to CPU inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"[SYSTEM] Model loaded successfully onto device: {device}")
75
 
76
+
77
+
78
@app.route('/')
def serve_ui():
    """Hand the single-page client (index.html) to the visiting browser."""
    ui_directory = '.'
    ui_page = 'index.html'
    return send_from_directory(ui_directory, ui_page)
85
 
86
+
87
+
88
  # --- THE NEW WEBSOCKET PIPELINE ---
89
+
90
@sock.route('/api/generate_stream')
def stream_motion(ws):
    """Serve live text-to-motion requests over a persistent WebSocket.

    Protocol: the client sends JSON ``{"prompt": str, "ticket": int}``;
    the server replies with JSON carrying the generated joint tensor
    (as nested lists), its shape, the round-trip latency in ms, and the
    echoed ticket number so the client can match replies to requests.
    The loop runs until the client disconnects or the transport fails.
    """
    print("\n[NETWORK] 🟢 WebSocket Connection Opened! Client connected.")

    # Keep the connection open forever
    while True:
        try:
            # 1. Wait for the live prompt from the client's text box
            raw_data = ws.receive()
            if raw_data is None:
                continue

            # FIX: a malformed message used to raise into the outer handler
            # and tear down the whole connection. Report the parse error to
            # the client and keep listening instead.
            try:
                data = json.loads(raw_data)
            except (json.JSONDecodeError, TypeError) as parse_err:
                ws.send(json.dumps({
                    "status": "error",
                    "message": f"Invalid JSON payload: {parse_err}"
                }))
                continue

            text_prompt = data.get('prompt', '')
            ticket_number = data.get('ticket', 0)
            print(f"[NETWORK] Live Prompt Received: '{text_prompt}'")

            start_time = time.time()

            # 2. Server Processing (Inference)
            motion_joints = model(text_prompt, length=150, output_joints=True)
            processing_time = (time.time() - start_time) * 1000

            # 3. Format Network Payload
            payload = {
                "status": "success",
                "ticket": ticket_number,  # echoed back for request/reply matching
                "latency_ms": round(processing_time, 2),
                "tensor_shape": list(motion_joints.shape),
                "data": motion_joints.tolist()
            }

            # 4. Push data back through the pipe instantly!
            ws.send(json.dumps(payload))
            # FIX: the log hardcoded "30 frames" while inference requests
            # length=150 — report the actual count from the tensor instead.
            # (assumes axis 0 of the output is the frame axis — TODO confirm)
            frame_count = motion_joints.shape[0]
            print(f"[NETWORK] ⚡ Streamed {frame_count} frames to client in {processing_time:.2f}ms")

        except Exception as e:
            # Any transport error (client closed, send failed) or inference
            # crash ends this connection's serving loop.
            print(f"[NETWORK] 🔴 WebSocket Error or Disconnect: {e}")
            break
167
 
168
+
169
+
170
if __name__ == '__main__':
    # --- PORT 7860 FOR HUGGING FACE ---
    # 0.0.0.0 exposes the server inside the Space container; debug=False
    # keeps the Werkzeug reloader/debugger off in the hosted environment.
    app.run(host='0.0.0.0', port=7860, debug=False)
175
+