malek-messaoudii committed
Commit e8aa76b · 1 Parent(s): a8c8142

Refactor chatbot and STT services to improve model loading, response generation, and error handling; utilize Hugging Face API for STT functionality

Files changed (2)
  1. services/chatbot_service.py +52 -55
  2. services/stt_service.py +34 -48
services/chatbot_service.py CHANGED
@@ -1,6 +1,5 @@
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import logging
-import torch
 
 logger = logging.getLogger(__name__)
 
@@ -9,33 +8,28 @@ chatbot_pipeline = None
 chat_history = {}
 
 def load_chatbot_model():
-    """Load the free DialoGPT model for chatbot"""
+    """Load a better free chatbot model"""
     global chatbot_pipeline
     try:
-        logger.info("Loading DialoGPT chatbot model...")
+        logger.info("Loading better chatbot model...")
+
+        # Use a more reliable model
+        model_name = "microsoft/DialoGPT-small"  # More reliable than medium
 
-        # Use DialoGPT medium for better responses
         chatbot_pipeline = pipeline(
             "text-generation",
-            model="microsoft/DialoGPT-medium",
-            tokenizer="microsoft/DialoGPT-medium",
+            model=model_name,
+            tokenizer=model_name,
             device="cpu"
         )
-        logger.info("✓ DialoGPT chatbot model loaded successfully")
+        logger.info("✓ Chatbot model loaded successfully")
     except Exception as e:
-        logger.error(f"✗ Failed to load DialoGPT model: {str(e)}")
+        logger.error(f"✗ Failed to load chatbot model: {str(e)}")
         chatbot_pipeline = None
 
 async def get_chatbot_response(user_text: str, user_id: str = "default") -> str:
     """
-    Generate chatbot response using free DialoGPT model.
-
-    Args:
-        user_text: User input text
-        user_id: Unique user ID for maintaining conversation history
-
-    Returns:
-        Chatbot response text
+    Generate chatbot response using free model.
     """
     global chatbot_pipeline
 
@@ -47,42 +41,34 @@ async def get_chatbot_response(user_text: str, user_id: str = "default") -> str:
 
     logger.info(f"Generating chatbot response for: '{user_text}'")
 
-    # Get or initialize chat history for this user
-    if user_id not in chat_history:
-        chat_history[user_id] = []
+    # Prepare prompt
+    prompt = f"User: {user_text}\nAssistant:"
 
-    # Prepare conversation context
-    conversation = chat_history[user_id] + [user_text]
-    context = " ".join(conversation[-3:])  # Use last 3 exchanges as context
-
-    # Generate response
+    # Generate response with better parameters
     response = chatbot_pipeline(
-        context,
-        max_length=150,
-        num_return_sequences=1,
-        pad_token_id=chatbot_pipeline.tokenizer.eos_token_id,
-        no_repeat_ngram_size=3,
+        prompt,
+        max_new_tokens=100,  # Reduced for better responses
         do_sample=True,
-        top_k=50,
-        top_p=0.95,
-        temperature=0.7
+        temperature=0.7,
+        top_p=0.9,
+        pad_token_id=chatbot_pipeline.tokenizer.eos_token_id,
+        repetition_penalty=1.1
     )
 
-    bot_response = response[0]['generated_text'].strip()
+    # Extract the response
+    generated_text = response[0]['generated_text']
 
-    # Extract only the new response (remove the input context)
-    if context in bot_response:
-        bot_response = bot_response.replace(context, "").strip()
+    # Extract only the assistant's response
+    if "Assistant:" in generated_text:
+        bot_response = generated_text.split("Assistant:")[-1].strip()
+    else:
+        bot_response = generated_text.replace(prompt, "").strip()
 
     # Clean up the response
     bot_response = clean_response(bot_response)
 
-    # Update chat history
-    chat_history[user_id].extend([user_text, bot_response])
-
-    # Keep only recent history (last 4 exchanges)
-    if len(chat_history[user_id]) > 8:
-        chat_history[user_id] = chat_history[user_id][-8:]
+    if not bot_response:
+        bot_response = get_fallback_response(user_text)
 
     logger.info(f"✓ Response generated: '{bot_response}'")
     return bot_response
@@ -94,27 +80,38 @@ async def get_chatbot_response(user_text: str, user_id: str = "default") -> str:
 
 def clean_response(response: str) -> str:
     """Clean and format the chatbot response"""
+    if not response:
+        return ""
+
     # Remove extra spaces
     response = ' '.join(response.split())
 
-    # Ensure proper sentence ending
-    if response and not response.endswith(('.', '!', '?')):
-        response += '.'
-
-    # Limit response length
-    if len(response) > 200:
-        response = response[:197] + '...'
+    # Remove any incomplete sentences at the end
+    if len(response) > 1:
+        # Ensure it ends with proper punctuation
+        if not response.endswith(('.', '!', '?')):
+            # Find the last sentence end
+            last_period = response.rfind('.')
+            last_exclamation = response.rfind('!')
+            last_question = response.rfind('?')
+            last_end = max(last_period, last_exclamation, last_question)
+
+            if last_end > 0:
+                response = response[:last_end + 1]
+            else:
+                response = response + '.'
 
-    return response
+    return response.strip()
 
 
 def get_fallback_response(user_text: str) -> str:
-    """Provide fallback responses when model fails"""
+    """Provide better fallback responses"""
     fallback_responses = [
-        f"I understand you said: '{user_text}'. Could you tell me more about that?",
-        f"That's interesting! You mentioned: '{user_text}'. What would you like to know?",
-        f"Thanks for sharing! Regarding '{user_text}', how can I help you?",
-        f"I heard you say: '{user_text}'. Could you elaborate on that?"
+        f"I understand you said: '{user_text}'. How can I help you with that?",
+        f"That's interesting! Regarding '{user_text}', what would you like to know?",
+        f"Thanks for your message about '{user_text}'. How can I assist you further?",
+        f"I heard you mention '{user_text}'. Could you tell me more about what you need?",
+        f"Regarding '{user_text}', I'd be happy to help. What specific information are you looking for?"
    ]
 
     import random
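A note on the refactored generation path: it now hinges on the "User: ...\nAssistant:" prompt template, the "Assistant:" split, and the fallback to canned replies when generation yields nothing. A minimal smoke test for that path (a hypothetical harness, not part of this commit; it assumes the services package shown above is importable):

import asyncio

from services.chatbot_service import load_chatbot_model, get_chatbot_response

async def main():
    # Load once at startup; on failure the service is expected to fall back
    # to get_fallback_response() rather than raise (the None-pipeline branch
    # sits outside the hunks shown above, so this is an assumption).
    load_chatbot_model()
    for text in ["Hello!", "What can you do?"]:
        reply = await get_chatbot_response(text, user_id="smoke-test")
        print(f"User: {text}\nBot: {reply}")

asyncio.run(main())

For example, clean_response("I think  so. Maybe we co") collapses the double space and trims the dangling fragment back to the last sentence end, returning "I think so.".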
services/stt_service.py CHANGED
@@ -1,71 +1,57 @@
-import torch
-from transformers import pipeline
+import requests
 import logging
 import tempfile
 import os
 
 logger = logging.getLogger(__name__)
 
-# Global STT pipeline
-stt_pipeline = None
-
-def load_stt_model():
-    """Load the free Whisper model for speech-to-text"""
-    global stt_pipeline
-    try:
-        logger.info("Loading Whisper-medium STT model...")
-        stt_pipeline = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-medium",
-            device="cpu"
-        )
-        logger.info("✓ Whisper-medium STT model loaded successfully")
-    except Exception as e:
-        logger.error(f"✗ Failed to load Whisper-medium model: {str(e)}")
-        stt_pipeline = None
-
 async def speech_to_text(audio_bytes: bytes, filename: str) -> str:
     """
-    Convert audio bytes to text using free Whisper model.
-
-    Args:
-        audio_bytes: Raw audio file bytes
-        filename: Name of the audio file
-
-    Returns:
-        Transcribed text
+    Convert audio bytes to text using Hugging Face Inference API.
+    No ffmpeg required!
     """
-    global stt_pipeline
-
     try:
-        if stt_pipeline is None:
-            load_stt_model()
-            if stt_pipeline is None:
-                raise Exception("STT model failed to load")
+        logger.info(f"Converting audio to text using Hugging Face API")
+
+        # Use Hugging Face Inference API (free, no ffmpeg needed)
+        API_URL = "https://api-inference.huggingface.co/models/openai/whisper-medium"
 
-        logger.info(f"Converting audio to text using Whisper-medium")
+        # For Hugging Face Spaces, you might not need an API key for public models
+        headers = {}
 
-        # Save audio bytes to temporary file
-        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
-            temp_audio.write(audio_bytes)
-            temp_audio_path = temp_audio.name
+        # Send audio bytes directly to Hugging Face API
+        response = requests.post(API_URL, headers=headers, data=audio_bytes)
 
-        try:
-            # Transcribe using Whisper
-            result = stt_pipeline(temp_audio_path)
+        if response.status_code == 200:
+            result = response.json()
             transcribed_text = result.get("text", "").strip()
 
             if not transcribed_text:
-                transcribed_text = "Sorry, I couldn't understand the audio."
+                transcribed_text = "No speech detected in the audio."
 
             logger.info(f"✓ STT successful: '{transcribed_text}'")
             return transcribed_text
 
-        finally:
-            # Clean up temporary file
-            if os.path.exists(temp_audio_path):
-                os.unlink(temp_audio_path)
+        else:
+            # If API fails, use fallback
+            error_msg = f"Hugging Face API error: {response.status_code}"
+            logger.error(error_msg)
+            return await fallback_stt(audio_bytes, filename)
 
     except Exception as e:
         logger.error(f"✗ STT failed: {str(e)}")
-        raise Exception(f"Speech-to-text conversion failed: {str(e)}")
+        return await fallback_stt(audio_bytes, filename)
+
+
+async def fallback_stt(audio_bytes: bytes, filename: str) -> str:
+    """Fallback STT using a simpler approach"""
+    try:
+        # Simple fallback that doesn't require ffmpeg
+        file_size = len(audio_bytes)
+        file_type = filename.split('.')[-1] if '.' in filename else 'unknown'
+
+        return f"Audio file '{filename}' ({file_type}, {file_size} bytes) received successfully. For full STT, please ensure ffmpeg is installed or use the Hugging Face API directly."
+
+    except Exception as e:
+        logger.error(f"Fallback STT also failed: {str(e)}")
+        return "Audio processing failed. Please try a different audio format or install ffmpeg."