Denis202 committed
Commit fe24b96 · verified · 1 Parent(s): bf554e7

Update chat.py

Files changed (1)
  1. chat.py +49 -124

chat.py CHANGED
@@ -1,148 +1,73 @@
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
-from typing import List, Tuple, Optional
 import logging
 import json
 import os
 import numpy as np
 import re
 
-# Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class KiswahiliChatbot:
-    def __init__(self, model_name: str = "bert-base-multilingual-cased", device: str = None):
-        """
-        BERT-based Kiswahili chatbot with response selection
-        """
-        try:
-            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-            logger.info(f"Inatumia kifaa: {self.device}")
-
-            logger.info(f"Inapakia modeli ya BERT '{model_name}'...")
-
-            # Load model and tokenizer
-            model_path = "./trained_bert_model"
-            if os.path.exists(model_path):
-                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-                self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
-                logger.info("✅ Modeli iliyofunzwa imepakika!")
-            else:
-                logger.info("ℹ️ Modeli ya msingi ya BERT inatumika")
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-                self.model = AutoModelForSequenceClassification.from_pretrained(
-                    model_name,
-                    num_labels=2
-                )
-
-            self.model.to(self.device)
-            self.model.eval()
-
-            # Load response bank
-            self.responses = self._load_response_bank()
-            logger.info(f"📋 Benki ya majibu: {len(self.responses)} majibu")
-
-        except Exception as e:
-            logger.error(f"❌ Hitilafu wakati wa kupakia modeli: {e}")
-            raise
 
     def _load_response_bank(self):
-        """Load response bank from file or use defaults"""
         response_file = "./trained_bert_model/responses.json"
-        responses = []
-
         if os.path.exists(response_file):
-            try:
-                with open(response_file, 'r', encoding='utf-8') as f:
-                    data = json.load(f)
-                    responses = data.get('responses', [])
-            except Exception as e:
-                logger.error(f"❌ Hitilafu wakati wa kusoma faili ya majibu: {e}")
-
-        # Add fallback responses if empty
-        if not responses:
-            responses = [
-                "Habari yako? Naitwa KiswahiliChetu, naweza kukusaidia na Kiswahili.",
-                "Asante kwa kuuliza! Ninafurahi kukusaidia na maswali yako ya Kiswahili.",
-                "Samahani, sielewi swali lako. Unaweza kuuliza kwa Kiswahili?",
-                "Ninajua Kiswahili vizuri. Nitaweza kukujibu maswali yako.",
-                "Tanzania ni nchi nzuri yenye utamaduni mwingi na lugha ya Kiswahili.",
-                "Hakuna matata inamaanisha 'hamna shida' kwa Kiswahili.",
-                "Unauliza kuhusu nini hasa? Ninaweza kukusaidia na Kiswahili.",
-                "Karibu katika masomo ya Kiswahili! Nianzie na swali lako."
-            ]
-
-        return responses
 
-    def _select_best_response(self, user_input: str) -> str:
-        """Select the best response using BERT scoring"""
-        if not self.responses:
-            return "Samahani, sijafunzwa majibu bado. Tafadhali fanya mafunzo kwanza."
-
-        # Score all responses
-        scores = []
         for response in self.responses:
-            # Format input for BERT
-            text = f"{user_input} [SEP] {response}"
-            inputs = self.tokenizer(
-                text,
-                return_tensors="pt",
-                truncation=True,
-                max_length=256,
-                padding=True
-            ).to(self.device)
-
-            # Get prediction
             with torch.no_grad():
                 outputs = self.model(**inputs)
-                prediction = torch.softmax(outputs.logits, dim=1)
-                score = prediction[0][1].item()  # Probability it's a good response
-
-            scores.append((response, score))
-
-        # Sort by score and return best response
-        scores.sort(key=lambda x: x[1], reverse=True)
-
-        # Return the best response
-        return scores[0][0]
 
-    def _clean_input(self, text: str) -> str:
-        """Clean user input"""
-        text = re.sub(r'[^\w\s?]', '', text)  # Remove special chars except spaces and ?
-        text = ' '.join(text.split())  # Remove extra spaces
-        return text.lower()
 
-    def _clean_response(self, response: str) -> str:
-        """Clean up the response"""
-        response = response.strip()
-        # Ensure proper punctuation
-        if response and not response.endswith(('.', '!', '?')):
-            response += '.'
-        # Capitalize first letter
-        if response:
-            response = response[0].upper() + response[1:]
-        return response
 
-    def chat(self, message: str) -> str:
-        """
-        Main chat method
-        """
-        try:
-            if not message.strip():
-                return "Tafadhali andika ujumbe..."
-
-            # Clean and preprocess input
-            cleaned_input = self._clean_input(message)
-
-            # Select best response
-            response = self._select_best_response(cleaned_input)
-
-            # Final cleanup
-            response = self._clean_response(response)
-
-            return response
-
-        except Exception as e:
-            logger.error(f"❌ Hitilafu wakati wa kukokotoa jibu: {e}")
-            return "Samahani, kuna hitilafu ya kiufundi. Tafadhali jaribu tena."
 
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 import logging
 import json
 import os
 import numpy as np
 import re
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class KiswahiliChatbot:
+    def __init__(self, model_path="./trained_bert_model", device=None, threshold=0.6):
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Using device: {self.device}")
+
+        # Load model
+        if os.path.exists(model_path):
+            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
+            logger.info("✅ Trained model loaded!")
+        else:
+            raise FileNotFoundError(f"{model_path} not found. Please train the model first.")
+
+        self.model.to(self.device)
+        self.model.eval()
+        self.threshold = threshold  # minimum probability to accept a response
+
+        # Load responses
+        self.responses = self._load_response_bank()
+        logger.info(f"📋 Loaded {len(self.responses)} responses")
 
     def _load_response_bank(self):
         response_file = "./trained_bert_model/responses.json"
         if os.path.exists(response_file):
+            with open(response_file, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+            return data.get("responses", [])
+        return []
+
+    def _clean_text(self, text: str) -> str:
+        text = re.sub(r'[^\w\s?]', '', text)
+        return ' '.join(text.split()).lower()
+
+    def chat(self, user_input: str) -> str:
+        user_input_clean = self._clean_text(user_input)
+        if not user_input_clean:
+            return "Tafadhali andika ujumbe."
+
+        best_response = None
+        best_score = 0.0
 
         for response in self.responses:
+            combined_text = f"{user_input_clean} [SEP] {response}"
+            inputs = self.tokenizer(combined_text, return_tensors="pt", truncation=True, max_length=256, padding=True).to(self.device)
+
             with torch.no_grad():
                 outputs = self.model(**inputs)
+            probs = torch.softmax(outputs.logits, dim=1)
+            score = probs[0][1].item()  # probability of being the correct response
+
+            if score > best_score:
+                best_score = score
+                best_response = response
+
+        if best_score < self.threshold:
+            return "Samahani, sielewi. Unaweza kuuliza kwa njia nyingine?"
+
+        # Capitalize first letter and ensure punctuation
+        best_response = best_response.strip()
+        if best_response and not best_response.endswith(('.', '!', '?')):
+            best_response += '.'
+        best_response = best_response[0].upper() + best_response[1:]
+        return best_response
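For reference, the updated class assumes a fine-tuned model saved under ./trained_bert_model, with a responses.json file whose top-level "responses" key holds the candidate replies that _load_response_bank reads. Below is a minimal usage sketch under that assumption; the responses.json contents and the query strings are illustrative only, not part of this commit.

# responses.json -- hypothetical example of the layout _load_response_bank expects:
# {
#   "responses": [
#     "Habari yako? Naitwa KiswahiliChetu, naweza kukusaidia na Kiswahili.",
#     "Hakuna matata inamaanisha 'hamna shida' kwa Kiswahili."
#   ]
# }

from chat import KiswahiliChatbot

bot = KiswahiliChatbot(model_path="./trained_bert_model", threshold=0.6)
print(bot.chat("Habari yako?"))  # highest-scoring reply, or the fallback if every score is below threshold
print(bot.chat(""))              # "Tafadhali andika ujumbe."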