HipFil98 commited on
Commit
f60fb65
·
verified ·
1 Parent(s): d89312c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -13
app.py CHANGED
@@ -2,12 +2,36 @@ import gradio as gr
2
  import random
3
  import berkeleydb
4
  import pickle
 
 
5
 
6
  class MariannaBot:
7
  def __init__(self):
8
  self.database = berkeleydb.hashopen("wiki_napoli_main.db", flag="c")
9
  self.database_legends = berkeleydb.hashopen("wiki_naples_leggende.db", flag="c")
 
10
  self.reset_state()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def reset_state(self):
13
  self.state = "initial"
@@ -20,9 +44,9 @@ class MariannaBot:
20
  def get_welcome_message(self):
21
  return """Ciao, benvenuto!\n\nSono Marianna, la testa di Napoli, in napoletano 'a capa 'e Napule, una statua ritrovata per caso nel 1594. \nAll'epoca del mio ritrovamento, si pensò che fossi una rappresentazione della sirena Partenope, dalle cui spoglie, leggenda narra, nacque la città di Napoli. In seguito, diversi studiosi riconobbero in me una statua della dea Venere, probabilmente collocata in uno dei tanti templi che si trovavano nella città in epoca tardo-romana, quando ancora si chiamava Neapolis.
22
  \nPosso raccontarti molte storie sulla città di Napoli e mostrarti le sue bellezze. \nC'è qualcosa in particolare che ti interessa?
23
- \n(Rispondi con 'sì', 'no' o 'non so, scegli tu')"""
24
 
25
- def get_safe_example_keys(self, num_examples=5):
26
  """Safely get example keys from the database."""
27
  try:
28
  keys = list(self.database.keys())
@@ -62,10 +86,34 @@ class MariannaBot:
62
  def handle_query(self, message):
63
  """Handle user queries by searching the database"""
64
  try:
65
- for key, value in self.database.items():
66
- decoded_key = key.decode("utf-8").lower()
67
- if message == decoded_key:
68
- self.main_k.append(key.decode("utf-8"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  self.state = "follow_up"
70
  self.is_telling_stories = False
71
  deserialized_value = pickle.loads(value)
@@ -73,8 +121,10 @@ class MariannaBot:
73
  self.current_further_info_values = list(deserialized_value.get('further_info', {}).values())
74
  self.current_index = 0
75
  return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}? (sì/no)"
76
- return "Mi dispiace, non ho informazioni riguardo a questa domanda. Prova a chiedermi qualcos'altro sulla città di Napoli."
77
- except Exception:
 
 
78
  self.state = "initial"
79
  return "Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda?"
80
 
@@ -88,14 +138,14 @@ class MariannaBot:
88
  if message in ["sì", "si"]:
89
  self.state = "query"
90
  self.is_telling_stories = False
91
- return "Di cosa vuoi sapere?"
92
  elif message == "no":
93
  self.state = "end"
94
  return "Va bene, grazie per aver parlato con me."
95
- elif message == "non so, scegli te":
96
  return self.story_flow()
97
  else:
98
- return "Scusa, non ho capito. Puoi rispondere con 'sì', 'no' o 'non so, scegli tu'."
99
 
100
  elif self.state == "query":
101
  return self.handle_query(message)
@@ -128,6 +178,7 @@ class MariannaBot:
128
 
129
  def main():
130
  bot = MariannaBot()
 
131
 
132
  def update_chatbot(message, history):
133
  if not message.strip():
@@ -141,10 +192,10 @@ def main():
141
 
142
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
143
  with gr.Row():
144
- gr.Markdown("## Parla con Marianna - 'La Testa di Napoli'")
145
 
146
  with gr.Row():
147
- gr.Image("marianna-102.jpeg",
148
  elem_id="marianna-image",
149
  width=250)
150
 
 
2
  import random
3
  import berkeleydb
4
  import pickle
5
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
6
+
7
 
8
  class MariannaBot:
9
  def __init__(self):
10
  self.database = berkeleydb.hashopen("wiki_napoli_main.db", flag="c")
11
  self.database_legends = berkeleydb.hashopen("wiki_naples_leggende.db", flag="c")
12
+ self.db_keys = [key.decode("utf-8") for key, value in self.database.items()]
13
  self.reset_state()
14
+
15
+ def initialize_encoder(self):
16
+ """
17
+ Inizializza il modello encoder e pre-calcola gli embedding delle chiavi del database.
18
+ Questo metodo dovrebbe essere chiamato una sola volta all'avvio del bot.
19
+ """
20
+ try:
21
+ # Initialize the encoder model
22
+ encoder_model = "nickprock/sentence-bert-base-italian-uncased"
23
+ cross_encoder_model = "nickprock/cross-encoder-italian-bert-stsb"
24
+ self.encoder = SentenceTransformer(encoder_model)
25
+ self.cross_encoder = CrossEncoder(cross_encoder_model)
26
+
27
+ # Pre-encode all database keys
28
+ self.db_keys_embeddings = self.encoder.encode(self.db_keys, convert_to_tensor=True)
29
+
30
+ print(f"Encoder initialized with {len(self.db_keys)} keys.")
31
+ return True
32
+ except Exception as e:
33
+ print(f"Error initializing encoder: {str(e)}")
34
+ return False
35
 
36
  def reset_state(self):
37
  self.state = "initial"
 
44
  def get_welcome_message(self):
45
  return """Ciao, benvenuto!\n\nSono Marianna, la testa di Napoli, in napoletano 'a capa 'e Napule, una statua ritrovata per caso nel 1594. \nAll'epoca del mio ritrovamento, si pensò che fossi una rappresentazione della sirena Partenope, dalle cui spoglie, leggenda narra, nacque la città di Napoli. In seguito, diversi studiosi riconobbero in me una statua della dea Venere, probabilmente collocata in uno dei tanti templi che si trovavano nella città in epoca tardo-romana, quando ancora si chiamava Neapolis.
46
  \nPosso raccontarti molte storie sulla città di Napoli e mostrarti le sue bellezze. \nC'è qualcosa in particolare che ti interessa?
47
+ \n(Rispondi con 'sì', 'no' o 'non so')"""
48
 
49
+ def get_safe_example_keys(self, num_examples=3):
50
  """Safely get example keys from the database."""
51
  try:
52
  keys = list(self.database.keys())
 
86
  def handle_query(self, message):
87
  """Handle user queries by searching the database"""
88
  try:
89
+ # Encode the user query
90
+ query_embedding = self.encoder.encode(message, convert_to_tensor=True)
91
+
92
+ # Perform semantic search on the keys
93
+ semantic_hits = util.semantic_search(query_embedding, self.db_keys_embeddings, top_k=3)
94
+ semantic_hits = semantic_hits[0]
95
+
96
+ cross_inp = [(message, self.db_keys[hit['corpus_id']]) for hit in semantic_hits]
97
+ cross_scores = self.cross_encoder.predict(cross_inp)
98
+
99
+ reranked_hits = sorted(
100
+ [{'corpus_id': hit['corpus_id'], 'cross-score': score}
101
+ for hit, score in zip(semantic_hits, cross_scores)],
102
+ key=lambda x: x['cross-score'], reverse=True
103
+ )
104
+
105
+ best_hit = reranked_hits[0]
106
+ best_title = self.db_keys[best_hit['corpus_id']]
107
+
108
+
109
+ # Using your existing code for handling the best match
110
+ if best_title is not None:
111
+ best_title_bytes = best_title.encode("utf-8") # Converti la stringa in bytes
112
+
113
+ if best_title_bytes in self.database:
114
+ value = self.database[best_title_bytes]
115
+ key = best_title
116
+ self.main_k.append(key)
117
  self.state = "follow_up"
118
  self.is_telling_stories = False
119
  deserialized_value = pickle.loads(value)
 
121
  self.current_further_info_values = list(deserialized_value.get('further_info', {}).values())
122
  self.current_index = 0
123
  return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}? (sì/no)"
124
+ else:
125
+ return "Mi dispiace, non ho informazioni riguardo a questa domanda. Prova a chiedermi qualcos'altro sulla città di Napoli."
126
+
127
+ except Exception as e:
128
  self.state = "initial"
129
  return "Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda?"
130
 
 
138
  if message in ["sì", "si"]:
139
  self.state = "query"
140
  self.is_telling_stories = False
141
+ return "Potresti dirmi di cosa vorresti sapere?"
142
  elif message == "no":
143
  self.state = "end"
144
  return "Va bene, grazie per aver parlato con me."
145
+ elif message == "non so":
146
  return self.story_flow()
147
  else:
148
+ return "Scusa, non ho capito. Puoi rispondere con 'sì', 'no' o 'non so'."
149
 
150
  elif self.state == "query":
151
  return self.handle_query(message)
 
178
 
179
  def main():
180
  bot = MariannaBot()
181
+ bot.initialize_encoder()
182
 
183
  def update_chatbot(message, history):
184
  if not message.strip():
 
192
 
193
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
194
  with gr.Row():
195
+ gr.Markdown("## Chat con Marianna - 'La Testa di Napoli'")
196
 
197
  with gr.Row():
198
+ gr.Image("/home/filippo/Scrivania/Marianna_head/Marianna_testa/marianna-102.jpeg",
199
  elem_id="marianna-image",
200
  width=250)
201