Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,12 +2,36 @@ import gradio as gr
|
|
| 2 |
import random
|
| 3 |
import berkeleydb
|
| 4 |
import pickle
|
|
|
|
|
|
|
| 5 |
|
| 6 |
class MariannaBot:
|
| 7 |
def __init__(self):
|
| 8 |
self.database = berkeleydb.hashopen("wiki_napoli_main.db", flag="c")
|
| 9 |
self.database_legends = berkeleydb.hashopen("wiki_naples_leggende.db", flag="c")
|
|
|
|
| 10 |
self.reset_state()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def reset_state(self):
|
| 13 |
self.state = "initial"
|
|
@@ -20,9 +44,9 @@ class MariannaBot:
|
|
| 20 |
def get_welcome_message(self):
|
| 21 |
return """Ciao, benvenuto!\n\nSono Marianna, la testa di Napoli, in napoletano 'a capa 'e Napule, una statua ritrovata per caso nel 1594. \nAll'epoca del mio ritrovamento, si pensò che fossi una rappresentazione della sirena Partenope, dalle cui spoglie, leggenda narra, nacque la città di Napoli. In seguito, diversi studiosi riconobbero in me una statua della dea Venere, probabilmente collocata in uno dei tanti templi che si trovavano nella città in epoca tardo-romana, quando ancora si chiamava Neapolis.
|
| 22 |
\nPosso raccontarti molte storie sulla città di Napoli e mostrarti le sue bellezze. \nC'è qualcosa in particolare che ti interessa?
|
| 23 |
-
\n(Rispondi con 'sì', 'no' o 'non so
|
| 24 |
|
| 25 |
-
def get_safe_example_keys(self, num_examples=
|
| 26 |
"""Safely get example keys from the database."""
|
| 27 |
try:
|
| 28 |
keys = list(self.database.keys())
|
|
@@ -62,10 +86,34 @@ class MariannaBot:
|
|
| 62 |
def handle_query(self, message):
|
| 63 |
"""Handle user queries by searching the database"""
|
| 64 |
try:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
self.state = "follow_up"
|
| 70 |
self.is_telling_stories = False
|
| 71 |
deserialized_value = pickle.loads(value)
|
|
@@ -73,8 +121,10 @@ class MariannaBot:
|
|
| 73 |
self.current_further_info_values = list(deserialized_value.get('further_info', {}).values())
|
| 74 |
self.current_index = 0
|
| 75 |
return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}? (sì/no)"
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
| 78 |
self.state = "initial"
|
| 79 |
return "Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda?"
|
| 80 |
|
|
@@ -88,14 +138,14 @@ class MariannaBot:
|
|
| 88 |
if message in ["sì", "si"]:
|
| 89 |
self.state = "query"
|
| 90 |
self.is_telling_stories = False
|
| 91 |
-
return "
|
| 92 |
elif message == "no":
|
| 93 |
self.state = "end"
|
| 94 |
return "Va bene, grazie per aver parlato con me."
|
| 95 |
-
elif message == "non so
|
| 96 |
return self.story_flow()
|
| 97 |
else:
|
| 98 |
-
return "Scusa, non ho capito. Puoi rispondere con 'sì', 'no' o 'non so
|
| 99 |
|
| 100 |
elif self.state == "query":
|
| 101 |
return self.handle_query(message)
|
|
@@ -128,6 +178,7 @@ class MariannaBot:
|
|
| 128 |
|
| 129 |
def main():
|
| 130 |
bot = MariannaBot()
|
|
|
|
| 131 |
|
| 132 |
def update_chatbot(message, history):
|
| 133 |
if not message.strip():
|
|
@@ -141,10 +192,10 @@ def main():
|
|
| 141 |
|
| 142 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
| 143 |
with gr.Row():
|
| 144 |
-
gr.Markdown("##
|
| 145 |
|
| 146 |
with gr.Row():
|
| 147 |
-
gr.Image("marianna-102.jpeg",
|
| 148 |
elem_id="marianna-image",
|
| 149 |
width=250)
|
| 150 |
|
|
|
|
| 2 |
import random
|
| 3 |
import berkeleydb
|
| 4 |
import pickle
|
| 5 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
| 6 |
+
|
| 7 |
|
| 8 |
class MariannaBot:
|
| 9 |
def __init__(self):
|
| 10 |
self.database = berkeleydb.hashopen("wiki_napoli_main.db", flag="c")
|
| 11 |
self.database_legends = berkeleydb.hashopen("wiki_naples_leggende.db", flag="c")
|
| 12 |
+
self.db_keys = [key.decode("utf-8") for key, value in self.database.items()]
|
| 13 |
self.reset_state()
|
| 14 |
+
|
| 15 |
+
def initialize_encoder(self):
|
| 16 |
+
"""
|
| 17 |
+
Inizializza il modello encoder e pre-calcola gli embedding delle chiavi del database.
|
| 18 |
+
Questo metodo dovrebbe essere chiamato una sola volta all'avvio del bot.
|
| 19 |
+
"""
|
| 20 |
+
try:
|
| 21 |
+
# Initialize the encoder model
|
| 22 |
+
encoder_model = "nickprock/sentence-bert-base-italian-uncased"
|
| 23 |
+
cross_encoder_model = "nickprock/cross-encoder-italian-bert-stsb"
|
| 24 |
+
self.encoder = SentenceTransformer(encoder_model)
|
| 25 |
+
self.cross_encoder = CrossEncoder(cross_encoder_model)
|
| 26 |
+
|
| 27 |
+
# Pre-encode all database keys
|
| 28 |
+
self.db_keys_embeddings = self.encoder.encode(self.db_keys, convert_to_tensor=True)
|
| 29 |
+
|
| 30 |
+
print(f"Encoder initialized with {len(self.db_keys)} keys.")
|
| 31 |
+
return True
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"Error initializing encoder: {str(e)}")
|
| 34 |
+
return False
|
| 35 |
|
| 36 |
def reset_state(self):
|
| 37 |
self.state = "initial"
|
|
|
|
| 44 |
def get_welcome_message(self):
|
| 45 |
return """Ciao, benvenuto!\n\nSono Marianna, la testa di Napoli, in napoletano 'a capa 'e Napule, una statua ritrovata per caso nel 1594. \nAll'epoca del mio ritrovamento, si pensò che fossi una rappresentazione della sirena Partenope, dalle cui spoglie, leggenda narra, nacque la città di Napoli. In seguito, diversi studiosi riconobbero in me una statua della dea Venere, probabilmente collocata in uno dei tanti templi che si trovavano nella città in epoca tardo-romana, quando ancora si chiamava Neapolis.
|
| 46 |
\nPosso raccontarti molte storie sulla città di Napoli e mostrarti le sue bellezze. \nC'è qualcosa in particolare che ti interessa?
|
| 47 |
+
\n(Rispondi con 'sì', 'no' o 'non so')"""
|
| 48 |
|
| 49 |
+
def get_safe_example_keys(self, num_examples=3):
|
| 50 |
"""Safely get example keys from the database."""
|
| 51 |
try:
|
| 52 |
keys = list(self.database.keys())
|
|
|
|
| 86 |
def handle_query(self, message):
|
| 87 |
"""Handle user queries by searching the database"""
|
| 88 |
try:
|
| 89 |
+
# Encode the user query
|
| 90 |
+
query_embedding = self.encoder.encode(message, convert_to_tensor=True)
|
| 91 |
+
|
| 92 |
+
# Perform semantic search on the keys
|
| 93 |
+
semantic_hits = util.semantic_search(query_embedding, self.db_keys_embeddings, top_k=3)
|
| 94 |
+
semantic_hits = semantic_hits[0]
|
| 95 |
+
|
| 96 |
+
cross_inp = [(message, self.db_keys[hit['corpus_id']]) for hit in semantic_hits]
|
| 97 |
+
cross_scores = self.cross_encoder.predict(cross_inp)
|
| 98 |
+
|
| 99 |
+
reranked_hits = sorted(
|
| 100 |
+
[{'corpus_id': hit['corpus_id'], 'cross-score': score}
|
| 101 |
+
for hit, score in zip(semantic_hits, cross_scores)],
|
| 102 |
+
key=lambda x: x['cross-score'], reverse=True
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
best_hit = reranked_hits[0]
|
| 106 |
+
best_title = self.db_keys[best_hit['corpus_id']]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Using your existing code for handling the best match
|
| 110 |
+
if best_title is not None:
|
| 111 |
+
best_title_bytes = best_title.encode("utf-8") # Converti la stringa in bytes
|
| 112 |
+
|
| 113 |
+
if best_title_bytes in self.database:
|
| 114 |
+
value = self.database[best_title_bytes]
|
| 115 |
+
key = best_title
|
| 116 |
+
self.main_k.append(key)
|
| 117 |
self.state = "follow_up"
|
| 118 |
self.is_telling_stories = False
|
| 119 |
deserialized_value = pickle.loads(value)
|
|
|
|
| 121 |
self.current_further_info_values = list(deserialized_value.get('further_info', {}).values())
|
| 122 |
self.current_index = 0
|
| 123 |
return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}? (sì/no)"
|
| 124 |
+
else:
|
| 125 |
+
return "Mi dispiace, non ho informazioni riguardo a questa domanda. Prova a chiedermi qualcos'altro sulla città di Napoli."
|
| 126 |
+
|
| 127 |
+
except Exception as e:
|
| 128 |
self.state = "initial"
|
| 129 |
return "Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda?"
|
| 130 |
|
|
|
|
| 138 |
if message in ["sì", "si"]:
|
| 139 |
self.state = "query"
|
| 140 |
self.is_telling_stories = False
|
| 141 |
+
return "Potresti dirmi di cosa vorresti sapere?"
|
| 142 |
elif message == "no":
|
| 143 |
self.state = "end"
|
| 144 |
return "Va bene, grazie per aver parlato con me."
|
| 145 |
+
elif message == "non so":
|
| 146 |
return self.story_flow()
|
| 147 |
else:
|
| 148 |
+
return "Scusa, non ho capito. Puoi rispondere con 'sì', 'no' o 'non so'."
|
| 149 |
|
| 150 |
elif self.state == "query":
|
| 151 |
return self.handle_query(message)
|
|
|
|
| 178 |
|
| 179 |
def main():
|
| 180 |
bot = MariannaBot()
|
| 181 |
+
bot.initialize_encoder()
|
| 182 |
|
| 183 |
def update_chatbot(message, history):
|
| 184 |
if not message.strip():
|
|
|
|
| 192 |
|
| 193 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
| 194 |
with gr.Row():
|
| 195 |
+
gr.Markdown("## Chat con Marianna - 'La Testa di Napoli'")
|
| 196 |
|
| 197 |
with gr.Row():
|
| 198 |
+
gr.Image("/home/filippo/Scrivania/Marianna_head/Marianna_testa/marianna-102.jpeg",
|
| 199 |
elem_id="marianna-image",
|
| 200 |
width=250)
|
| 201 |
|