allanctan committed
Commit 94dc091 · 1 Parent(s): ea2c226

test: for websocket interface

README.md CHANGED
@@ -10,3 +10,9 @@ short_description: better-ed mini
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+curl -X POST https://allanctan-ai.hf.space/be-mini-ai/transcribe \
+  -F "file=@voice\a_projectil_is.wav"
+
+allanctan-ai/be-mini-ai
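
For reference, a rough Python equivalent of the curl call in the README change above. The `requests` package and the forward-slash file path are assumptions; the URL and the {"text": ...} response shape come from the README and the /transcribe handler in main.py in this commit:

import requests

# Hypothetical client call mirroring the curl example above
url = "https://allanctan-ai.hf.space/be-mini-ai/transcribe"
with open("voice/a_projectil_is.wav", "rb") as f:  # assumed path separator
    response = requests.post(url, files={"file": f})
print(response.json())  # expected shape: {"text": "<transcription>"}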
__pycache__/parse.cpython-313.pyc ADDED
Binary file (2.34 kB).
 
all_questions_with_audio.json ADDED
The diff for this file is too large to render.
 
main.py CHANGED
@@ -1,15 +1,50 @@
-from fastapi import FastAPI, UploadFile, File
+from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
 from unsloth import FastVisionModel
 import torch
 import shutil
 import os
+import json
+import base64
+import tempfile
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor"
 
 app = FastAPI()
 
+# Add CORS for WebSocket
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Load model at startup (same as your original)
 model, processor = FastVisionModel.from_pretrained("unsloth/gemma-3n-e2b-it", load_in_4bit=True)
 model.generation_config.cache_implementation = "static"
 
+@app.get("/")
+async def root():
+    return {"message": "API is running"}
+
+@app.get("/health")
+async def health_check():
+    try:
+        return {
+            "status": "healthy",
+            "model_loaded": model is not None,
+            "processor_loaded": processor is not None,
+            "device": str(model.device) if model else "none"
+        }
+    except Exception as e:
+        return {"status": "unhealthy", "error": str(e)}
+
 @app.post("/transcribe/")
 async def transcribe_audio(file: UploadFile = File(...)):
     filepath = f"/tmp/{file.filename}"
@@ -29,6 +64,101 @@ async def transcribe_audio(file: UploadFile = File(...)):
         tokenize=True, return_dict=True, return_tensors="pt"
     ).to(model.device, dtype=model.dtype)
 
-    outputs = model.generate(**input_ids, max_new_tokens=16)
+    outputs = model.generate(**input_ids, max_new_tokens=64, do_sample=False, temperature=0.1)
     result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+    result = result.split("model\n")[-1].split("<end_of_turn>")[0].strip()
+
+    # Cleanup
+    if os.path.exists(filepath):
+        os.remove(filepath)
+
     return {"text": result}
+
+# Simple WebSocket endpoint
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    logger.info("WebSocket client connected")
+
+    try:
+        while True:
+            # Receive message
+            data = await websocket.receive_text()
+            message = json.loads(data)
+            logger.info(f"Received message: {message}")
+
+            # Handle audio data
+            if "audio_data" in message:
+                audio_b64 = message["audio_data"]
+                mime_type = message.get("mime_type", "audio/wav")
+
+                try:
+                    # Use your exact transcribe logic
+                    transcription = await transcribe_base64_audio(audio_b64, mime_type)
+
+                    # Send response
+                    response = {
+                        "type": "transcription",
+                        "text": transcription
+                    }
+                    await websocket.send_text(json.dumps(response))
+
+                except Exception as e:
+                    logger.error(f"Transcription error: {e}")
+                    await websocket.send_text(json.dumps({
+                        "type": "error",
+                        "message": str(e)
+                    }))
+
+            # Handle ping/pong
+            elif message.get("type") == "ping":
+                await websocket.send_text(json.dumps({"type": "pong"}))
+
+            else:
+                await websocket.send_text(json.dumps({
+                    "type": "error",
+                    "message": "Unknown message format"
+                }))
+
+    except WebSocketDisconnect:
+        logger.info("WebSocket client disconnected")
+    except Exception as e:
+        logger.error(f"WebSocket error: {e}")
+
+async def transcribe_base64_audio(audio_b64: str, mime_type: str) -> str:
+    """Use your exact transcribe logic but with base64 audio data"""
+
+    # Convert base64 to file (same as your transcribe logic)
+    audio_data = base64.b64decode(audio_b64)
+
+    # Create temp file
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
+        temp_file.write(audio_data)
+        filepath = temp_file.name
+
+    try:
+        # Your exact transcribe logic
+        messages = [{
+            "role": "user",
+            "content": [
+                {"type": "audio", "audio": filepath},
+                {"type": "text", "text": "Transcribe this audio"},
+            ]
+        }]
+
+        input_ids = processor.apply_chat_template(
+            messages, add_generation_prompt=True,
+            tokenize=True, return_dict=True, return_tensors="pt"
+        ).to(model.device, dtype=model.dtype)
+
+        outputs = model.generate(**input_ids, max_new_tokens=64, do_sample=False, temperature=0.1)
+        result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+        print(result)
+        result = result.split("model\n")[-1].split("<end_of_turn>")[0].strip()
+
+        return result
+
+    finally:
+        # Cleanup temp file
+        if os.path.exists(filepath):
+            os.remove(filepath)
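
The new /ws handler above expects a JSON text message carrying base64 audio in "audio_data" (with an optional "mime_type") and replies with {"type": "transcription", "text": ...}, {"type": "error", ...}, or {"type": "pong"} for pings. A minimal client sketch, assuming the third-party `websockets` package and a wss URL at the Space's /ws path (both assumptions, not part of this commit):

import asyncio
import base64
import json

import websockets  # assumed dependency: pip install websockets

async def transcribe_over_ws(wav_path, url="wss://allanctan-ai.hf.space/ws"):
    # Encode the audio file as base64 to match the handler's "audio_data" field
    with open(wav_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("ascii")

    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"audio_data": audio_b64, "mime_type": "audio/wav"}))
        # The server answers with a transcription, an error payload, or a pong
        return json.loads(await ws.recv())

if __name__ == "__main__":
    print(asyncio.run(transcribe_over_ws("voice/a_projectil_is.wav")))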
test_start.py ADDED
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+"""
+Test script for the start() function
+"""
+
+from parse import start
+
+def test_start_function():
+    """Test the start function with different speakers"""
+
+    print("🧪 Testing start() function with different speakers\n")
+
+    # Test with 'question' speaker (first level questions)
+    print("📒 Results for speaker 'question':")
+    results = start('question')
+    print(f"Found {len(results)} questions")
+    for i, result in enumerate(results[:3]):  # Show first 3
+        print(f"  {i+1}. {result['message']}")
+        print(f"     Audio: {result['audio'] if result['audio'] else 'No audio'}")
+
+    print("\n" + "-"*50)
+
+    # Test with a non-existent speaker
+    print("📒 Results for speaker 'non_existent':")
+    results = start('non_existent')
+    print(f"Found {len(results)} results")
+
+    print("\n" + "-"*50)
+
+    # Test with empty speaker
+    print("📒 Results for speaker '':")
+    results = start('')
+    print(f"Found {len(results)} results")
+
+if __name__ == "__main__":
+    test_start_function()
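
parse.py itself is not in this commit (only its compiled __pycache__/parse.cpython-313.pyc and all_questions_with_audio.json are added), so the test above documents the expected interface: start(speaker) returns a list of dicts with 'message' and 'audio' keys. A purely hypothetical stub consistent with that usage, assuming the questions JSON is a list of entries with a 'speaker' field (the file's real schema is not shown here):

import json

def start(speaker):
    """Hypothetical stand-in for parse.start(), filtering questions by speaker."""
    with open("all_questions_with_audio.json") as f:
        questions = json.load(f)  # assumed: a list of dicts
    # Assumed fields: 'speaker', 'message', and an optional 'audio' path
    return [
        {"message": q.get("message"), "audio": q.get("audio")}
        for q in questions
        if q.get("speaker") == speaker
    ]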
working.py ADDED
@@ -0,0 +1,59 @@
+from fastapi import FastAPI, UploadFile, File
+from unsloth import FastVisionModel
+import torch
+import shutil
+import os
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor"
+
+app = FastAPI()
+
+model, processor = FastVisionModel.from_pretrained("unsloth/gemma-3n-e2b-it", load_in_4bit=True)
+model.generation_config.cache_implementation = "static"
+
+@app.get("/")
+async def root():
+    return {"message": "API is running"}
+
+@app.get("/health")
+async def health_check():
+    try:
+        return {
+            "status": "healthy",
+            "model_loaded": model is not None,
+            "processor_loaded": processor is not None,
+            "device": str(model.device) if model else "none"
+        }
+    except Exception as e:
+        return {"status": "unhealthy", "error": str(e)}
+
+@app.post("/transcribe/")
+async def transcribe_audio(file: UploadFile = File(...)):
+    filepath = f"/tmp/{file.filename}"
+    with open(filepath, "wb") as buffer:
+        shutil.copyfileobj(file.file, buffer)
+
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "audio", "audio": filepath},
+            {"type": "text", "text": "Transcribe this audio"},
+        ]
+    }]
+
+    input_ids = processor.apply_chat_template(
+        messages, add_generation_prompt=True,
+        tokenize=True, return_dict=True, return_tensors="pt"
+    ).to(model.device, dtype=model.dtype)
+
+    # Generate output from the model
+    outputs = model.generate(**input_ids, max_new_tokens=64, do_sample=False,
+                             temperature=0.1)
+
+    # decode and print the output as text
+    result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+
+    # Extract only transcription
+    result = result.split("model\n")[-1].split("<end_of_turn>")[0].strip()
+    return {"text": result}
+
+