HipFil98 commited on
Commit
190fffe
·
verified ·
1 Parent(s): 9be43af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -9
app.py CHANGED
@@ -12,14 +12,14 @@ class MariannaBot:
12
  self.db_keys = [key.decode("utf-8") for key, value in self.database.items()]
13
  self.reset_state()
14
 
 
15
  def initialize_encoder(self):
16
  """
17
- Inizializza il modello encoder e pre-calcola gli embedding delle chiavi del database.
18
- Questo metodo dovrebbe essere chiamato una sola volta all'avvio del bot.
19
  """
20
  try:
21
  # Initialize the encoder model
22
- encoder_model = "nickprock/sentence-bert-base-italian-uncased"
23
  cross_encoder_model = "nickprock/cross-encoder-italian-bert-stsb"
24
  self.encoder = SentenceTransformer(encoder_model)
25
  self.cross_encoder = CrossEncoder(cross_encoder_model)
@@ -63,7 +63,6 @@ class MariannaBot:
63
  if not legend_keys:
64
  return "Mi dispiace, al momento non ho leggende da raccontare."
65
 
66
- # Se abbiamo già raccontato tutte le storie, ricominciamo
67
  available_keys = [key for key in legend_keys if key.decode('utf-8') not in self.main_k]
68
  if not available_keys:
69
  self.main_k = [] # Reset della lista delle storie raccontate
@@ -81,7 +80,7 @@ class MariannaBot:
81
  except Exception:
82
  self.state = "initial"
83
  self.is_telling_stories = False
84
- return "Mi dispiace, c'è stato un problema. Vuoi provare con qualcos'altro? (sì/no)"
85
 
86
  def handle_query(self, message):
87
  """Handle user queries by searching the database"""
@@ -104,11 +103,38 @@ class MariannaBot:
104
 
105
  best_hit = reranked_hits[0]
106
  best_title = self.db_keys[best_hit['corpus_id']]
 
 
 
 
 
107
 
 
 
 
 
 
108
 
109
- # Using your existing code for handling the best match
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  if best_title is not None:
111
- best_title_bytes = best_title.encode("utf-8") # Converti la stringa in bytes
112
 
113
  if best_title_bytes in self.database:
114
  value = self.database[best_title_bytes]
@@ -138,7 +164,7 @@ class MariannaBot:
138
  if message in ["sì", "si"]:
139
  self.state = "query"
140
  self.is_telling_stories = False
141
- return "Di cosa vorresti sapere?"
142
  elif message == "no":
143
  self.state = "end"
144
  return "Va bene, grazie per aver parlato con me."
@@ -195,7 +221,7 @@ def main():
195
  gr.Markdown("## Chat con Marianna - 'La Testa di Napoli'")
196
 
197
  with gr.Row():
198
- gr.Image("marianna-102.jpeg",
199
  elem_id="marianna-image",
200
  width=250)
201
 
 
12
  self.db_keys = [key.decode("utf-8") for key, value in self.database.items()]
13
  self.reset_state()
14
 
15
+
16
  def initialize_encoder(self):
17
  """
18
+ Initialize encoder and cross-encoder model.
 
19
  """
20
  try:
21
  # Initialize the encoder model
22
+ encoder_model = "nickprock/sentence-bert-base-italian-xxl-uncased"
23
  cross_encoder_model = "nickprock/cross-encoder-italian-bert-stsb"
24
  self.encoder = SentenceTransformer(encoder_model)
25
  self.cross_encoder = CrossEncoder(cross_encoder_model)
 
63
  if not legend_keys:
64
  return "Mi dispiace, al momento non ho leggende da raccontare."
65
 
 
66
  available_keys = [key for key in legend_keys if key.decode('utf-8') not in self.main_k]
67
  if not available_keys:
68
  self.main_k = [] # Reset della lista delle storie raccontate
 
80
  except Exception:
81
  self.state = "initial"
82
  self.is_telling_stories = False
83
+ return "Mi dispiace, c'è stato un problema nel recuperare la storia. Vuoi provare con qualcos'altro? (sì/no)"
84
 
85
  def handle_query(self, message):
86
  """Handle user queries by searching the database"""
 
103
 
104
  best_hit = reranked_hits[0]
105
  best_title = self.db_keys[best_hit['corpus_id']]
106
+ best_score = best_hit['cross-score']
107
+ #print(best_title, best_score)
108
+
109
+ # Main treshold = 0.75
110
+ similarity_threshold = 0.75
111
 
112
+ # treshold granularity
113
+ if best_score < similarity_threshold:
114
+ # low confidence (< 0.35)
115
+ if best_score < 0.55:
116
+ return "Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli."
117
 
118
+
119
+ # medium confidence(0.55 - 0.75)
120
+ else:
121
+
122
+ alternative_hits = [self.db_keys[hit['corpus_id']] for hit in reranked_hits[:2]]
123
+ suggestions = ", ".join(alternative_hits)
124
+ best_title_bytes = best_title.encode("utf-8")
125
+ if best_title_bytes in self.database:
126
+ value = self.database[best_title_bytes]
127
+ deserialized_value = pickle.loads(value)
128
+ partial_info = deserialized_value.get('short_intro', deserialized_value['intro'].split('.')[0] + '.')
129
+ self.state = "query"
130
+ self.is_telling_stories = False
131
+ return f"Potrei avere alcune informazioni su {best_title}, ma non sono completamente sicura sia ciò che stai cercando. I miei suggerimenti sono {suggestions}. \n\nCosa ti interessa?"
132
+ else:
133
+ return f"Ho trovato qualcosa su {best_title}, ma non sono completamente sicura. Vuoi saperne di più? (sì/no)"
134
+
135
+ # high confidence (above the threshold)
136
  if best_title is not None:
137
+ best_title_bytes = best_title.encode("utf-8")
138
 
139
  if best_title_bytes in self.database:
140
  value = self.database[best_title_bytes]
 
164
  if message in ["sì", "si"]:
165
  self.state = "query"
166
  self.is_telling_stories = False
167
+ return "Potresti dirmi di cosa vorresti sapere?"
168
  elif message == "no":
169
  self.state = "end"
170
  return "Va bene, grazie per aver parlato con me."
 
221
  gr.Markdown("## Chat con Marianna - 'La Testa di Napoli'")
222
 
223
  with gr.Row():
224
+ gr.Image("/home/filippo/Scrivania/Marianna_head/Marianna_testa/marianna-102.jpeg",
225
  elem_id="marianna-image",
226
  width=250)
227