Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,14 +12,14 @@ class MariannaBot:
|
|
| 12 |
self.db_keys = [key.decode("utf-8") for key, value in self.database.items()]
|
| 13 |
self.reset_state()
|
| 14 |
|
|
|
|
| 15 |
def initialize_encoder(self):
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
-
Questo metodo dovrebbe essere chiamato una sola volta all'avvio del bot.
|
| 19 |
"""
|
| 20 |
try:
|
| 21 |
# Initialize the encoder model
|
| 22 |
-
encoder_model = "nickprock/sentence-bert-base-italian-uncased"
|
| 23 |
cross_encoder_model = "nickprock/cross-encoder-italian-bert-stsb"
|
| 24 |
self.encoder = SentenceTransformer(encoder_model)
|
| 25 |
self.cross_encoder = CrossEncoder(cross_encoder_model)
|
|
@@ -63,7 +63,6 @@ class MariannaBot:
|
|
| 63 |
if not legend_keys:
|
| 64 |
return "Mi dispiace, al momento non ho leggende da raccontare."
|
| 65 |
|
| 66 |
-
# Se abbiamo già raccontato tutte le storie, ricominciamo
|
| 67 |
available_keys = [key for key in legend_keys if key.decode('utf-8') not in self.main_k]
|
| 68 |
if not available_keys:
|
| 69 |
self.main_k = [] # Reset della lista delle storie raccontate
|
|
@@ -81,7 +80,7 @@ class MariannaBot:
|
|
| 81 |
except Exception:
|
| 82 |
self.state = "initial"
|
| 83 |
self.is_telling_stories = False
|
| 84 |
-
return "Mi dispiace, c'è stato un problema. Vuoi provare con qualcos'altro? (sì/no)"
|
| 85 |
|
| 86 |
def handle_query(self, message):
|
| 87 |
"""Handle user queries by searching the database"""
|
|
@@ -104,11 +103,38 @@ class MariannaBot:
|
|
| 104 |
|
| 105 |
best_hit = reranked_hits[0]
|
| 106 |
best_title = self.db_keys[best_hit['corpus_id']]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
if best_title is not None:
|
| 111 |
-
best_title_bytes = best_title.encode("utf-8")
|
| 112 |
|
| 113 |
if best_title_bytes in self.database:
|
| 114 |
value = self.database[best_title_bytes]
|
|
@@ -138,7 +164,7 @@ class MariannaBot:
|
|
| 138 |
if message in ["sì", "si"]:
|
| 139 |
self.state = "query"
|
| 140 |
self.is_telling_stories = False
|
| 141 |
-
return "
|
| 142 |
elif message == "no":
|
| 143 |
self.state = "end"
|
| 144 |
return "Va bene, grazie per aver parlato con me."
|
|
@@ -195,7 +221,7 @@ def main():
|
|
| 195 |
gr.Markdown("## Chat con Marianna - 'La Testa di Napoli'")
|
| 196 |
|
| 197 |
with gr.Row():
|
| 198 |
-
gr.Image("marianna-102.jpeg",
|
| 199 |
elem_id="marianna-image",
|
| 200 |
width=250)
|
| 201 |
|
|
|
|
| 12 |
self.db_keys = [key.decode("utf-8") for key, value in self.database.items()]
|
| 13 |
self.reset_state()
|
| 14 |
|
| 15 |
+
|
| 16 |
def initialize_encoder(self):
|
| 17 |
"""
|
| 18 |
+
Initialize encoder and cross-encoder model.
|
|
|
|
| 19 |
"""
|
| 20 |
try:
|
| 21 |
# Initialize the encoder model
|
| 22 |
+
encoder_model = "nickprock/sentence-bert-base-italian-xxl-uncased"
|
| 23 |
cross_encoder_model = "nickprock/cross-encoder-italian-bert-stsb"
|
| 24 |
self.encoder = SentenceTransformer(encoder_model)
|
| 25 |
self.cross_encoder = CrossEncoder(cross_encoder_model)
|
|
|
|
| 63 |
if not legend_keys:
|
| 64 |
return "Mi dispiace, al momento non ho leggende da raccontare."
|
| 65 |
|
|
|
|
| 66 |
available_keys = [key for key in legend_keys if key.decode('utf-8') not in self.main_k]
|
| 67 |
if not available_keys:
|
| 68 |
self.main_k = [] # Reset della lista delle storie raccontate
|
|
|
|
| 80 |
except Exception:
|
| 81 |
self.state = "initial"
|
| 82 |
self.is_telling_stories = False
|
| 83 |
+
return "Mi dispiace, c'è stato un problema nel recuperare la storia. Vuoi provare con qualcos'altro? (sì/no)"
|
| 84 |
|
| 85 |
def handle_query(self, message):
|
| 86 |
"""Handle user queries by searching the database"""
|
|
|
|
| 103 |
|
| 104 |
best_hit = reranked_hits[0]
|
| 105 |
best_title = self.db_keys[best_hit['corpus_id']]
|
| 106 |
+
best_score = best_hit['cross-score']
|
| 107 |
+
#print(best_title, best_score)
|
| 108 |
+
|
| 109 |
+
# Main treshold = 0.75
|
| 110 |
+
similarity_threshold = 0.75
|
| 111 |
|
| 112 |
+
# treshold granularity
|
| 113 |
+
if best_score < similarity_threshold:
|
| 114 |
+
# low confidence (< 0.35)
|
| 115 |
+
if best_score < 0.55:
|
| 116 |
+
return "Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli."
|
| 117 |
|
| 118 |
+
|
| 119 |
+
# medium confidence(0.55 - 0.75)
|
| 120 |
+
else:
|
| 121 |
+
|
| 122 |
+
alternative_hits = [self.db_keys[hit['corpus_id']] for hit in reranked_hits[:2]]
|
| 123 |
+
suggestions = ", ".join(alternative_hits)
|
| 124 |
+
best_title_bytes = best_title.encode("utf-8")
|
| 125 |
+
if best_title_bytes in self.database:
|
| 126 |
+
value = self.database[best_title_bytes]
|
| 127 |
+
deserialized_value = pickle.loads(value)
|
| 128 |
+
partial_info = deserialized_value.get('short_intro', deserialized_value['intro'].split('.')[0] + '.')
|
| 129 |
+
self.state = "query"
|
| 130 |
+
self.is_telling_stories = False
|
| 131 |
+
return f"Potrei avere alcune informazioni su {best_title}, ma non sono completamente sicura sia ciò che stai cercando. I miei suggerimenti sono {suggestions}. \n\nCosa ti interessa?"
|
| 132 |
+
else:
|
| 133 |
+
return f"Ho trovato qualcosa su {best_title}, ma non sono completamente sicura. Vuoi saperne di più? (sì/no)"
|
| 134 |
+
|
| 135 |
+
# high confidence (above the threshold)
|
| 136 |
if best_title is not None:
|
| 137 |
+
best_title_bytes = best_title.encode("utf-8")
|
| 138 |
|
| 139 |
if best_title_bytes in self.database:
|
| 140 |
value = self.database[best_title_bytes]
|
|
|
|
| 164 |
if message in ["sì", "si"]:
|
| 165 |
self.state = "query"
|
| 166 |
self.is_telling_stories = False
|
| 167 |
+
return "Potresti dirmi di cosa vorresti sapere?"
|
| 168 |
elif message == "no":
|
| 169 |
self.state = "end"
|
| 170 |
return "Va bene, grazie per aver parlato con me."
|
|
|
|
| 221 |
gr.Markdown("## Chat con Marianna - 'La Testa di Napoli'")
|
| 222 |
|
| 223 |
with gr.Row():
|
| 224 |
+
gr.Image("/home/filippo/Scrivania/Marianna_head/Marianna_testa/marianna-102.jpeg",
|
| 225 |
elem_id="marianna-image",
|
| 226 |
width=250)
|
| 227 |
|