Chananchida committed on
Commit
b2b7ea1
1 Parent(s): 14ce495

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -20
app.py CHANGED
@@ -15,7 +15,7 @@ from unstructured.partition.html import partition_html
15
 
16
  DEFAULT_MODEL = 'wangchanberta'
17
  DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
18
- EMBEDDINGS_PATH = 'data/embeddings.pkl'
19
  MODEL_DICT = {
20
  'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
21
  'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
@@ -59,14 +59,6 @@ def prepare_sentences_vector(encoded_list):
59
  encoded_list = normalize(encoded_list)
60
  return encoded_list
61
 
62
def load_embeddings(file_path=EMBEDDINGS_PATH):
    """Load pre-computed question embeddings from a pickle file.

    Args:
        file_path: Path to a pickle file whose payload is a dict with
            'sentences' and 'embeddings' keys.

    Returns:
        The stored embeddings object (the 'embeddings' entry of the pickle).
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file —
    # only load embedding files produced by this project / a trusted source.
    with open(file_path, "rb") as fIn:
        stored_data = pickle.load(fIn)
    print('Load (questions) embeddings done')
    # The pickle also holds 'sentences', but callers only need the embeddings,
    # so the unused local binding from the original was dropped.
    return stored_data['embeddings']
69
-
70
def faiss_search(index, question_vector, k=1):
    """Query *index* for the *k* nearest neighbours of *question_vector*.

    Returns:
        A (distances, indices) pair as produced by FAISS Index.search.
    """
    search_result = index.search(question_vector, k)
    distances, indices = search_result
    return distances, indices
@@ -81,25 +73,23 @@ def model_pipeline(model, tokenizer, question, context):
81
  Answer = tokenizer.decode(predict_answer_tokens)
82
  return Answer.replace('<unk>','@')
83
 
84
def predict_test(embedding_model, context, question, index):  # sent_tokenize pythainlp
    """Return the three context passages most similar to *question*.

    The question is embedded, normalised, and looked up in the FAISS
    *index*; the matching entries of *context* are concatenated into a
    'Top N: ...' formatted string, which is also printed.
    """
    start = time.time()  # timing hook; value is currently unused
    question = question.strip()
    query_vector = prepare_sentences_vector([get_embeddings(embedding_model, question)])
    distances, indices = faiss_search(index, query_vector, 3)  # top-3 hits

    pieces = []
    for rank, ctx_pos in enumerate(indices[0][:3], start=1):
        pieces.append(f"Top {rank}: {context[ctx_pos].strip()}\n\n")
    most_similar_contexts = "".join(pieces)
    print(most_similar_contexts)
    return most_similar_contexts
98
 
99
-
100
-
101
  if __name__ == "__main__":
102
-
103
  url = "https://www.dataxet.co/media-landscape/2024-th"
104
  elements = partition_html(url=url)
105
  context = [str(element) for element in elements if len(str(element)) >60]
@@ -108,7 +98,7 @@ if __name__ == "__main__":
108
  index = set_index(prepare_sentences_vector(get_embeddings(embedding_model, context)))
109
 
110
  def chat_interface(question, history):
111
- response = predict_test(embedding_model, context, question, index)
112
  return response
113
 
114
  examples=['ภูมิทัศน์สื่อไทยในปี 2567 มีแนวโน้มว่า ',
@@ -116,8 +106,7 @@ if __name__ == "__main__":
116
  'ติ๊กต๊อก คือ',
117
  'รายงานจาก Reuters Institute'
118
  ]
119
-
120
-
121
  interface = gr.ChatInterface(fn=chat_interface,
122
  examples=examples)
123
 
 
15
 
16
  DEFAULT_MODEL = 'wangchanberta'
17
  DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
18
+
19
  MODEL_DICT = {
20
  'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
21
  'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
 
59
  encoded_list = normalize(encoded_list)
60
  return encoded_list
61
 
 
 
 
 
 
 
 
 
62
def faiss_search(index, question_vector, k=1):
    """Thin wrapper over FAISS Index.search: nearest-neighbour lookup.

    Returns the (distances, indices) tuple for the top *k* matches of
    *question_vector* in *index*.
    """
    return index.search(question_vector, k)
 
73
  Answer = tokenizer.decode(predict_answer_tokens)
74
  return Answer.replace('<unk>','@')
75
 
76
def predict_test(embedding_model, context, question, index, url, k=3):
    """Return the top-*k* context passages most similar to *question*,
    each rendered as an HTML link that scrolls to the passage via a
    URL text fragment (``#:~:text=``).

    Args:
        embedding_model: Sentence-embedding model passed to get_embeddings.
        context: List of context strings indexed by the FAISS index.
        question: User question; stripped before embedding.
        index: FAISS index over the context embeddings.
        url: Page URL the text fragments should point into.
        k: Number of passages to return (default 3, the original behavior).

    Returns:
        A string of '<a href=...>' lines, one per retrieved passage
        (also printed to stdout).
    """
    import urllib.parse

    question = question.strip()
    question_vector = prepare_sentences_vector([get_embeddings(embedding_model, question)])
    distances, indices = faiss_search(index, question_vector, k)

    most_similar_contexts = ''
    # Guard against the index returning fewer than k hits.
    for i in range(min(k, len(indices[0]))):
        most_sim_context = context[indices[0][i]].strip()
        # Percent-encode the text fragment: raw Thai text (spaces, commas,
        # non-ASCII) would otherwise yield an invalid/non-working fragment.
        # This restores the intent of the commented-out quote() call.
        fragment = urllib.parse.quote(most_sim_context)
        answer_url = f"{url}#:~:text={fragment}"
        most_similar_contexts += f'<a href="{answer_url}">[ {i+1} ]: {most_sim_context}</a>\n\n'
    print(most_similar_contexts)
    return most_similar_contexts
91
 
 
 
92
  if __name__ == "__main__":
 
93
  url = "https://www.dataxet.co/media-landscape/2024-th"
94
  elements = partition_html(url=url)
95
  context = [str(element) for element in elements if len(str(element)) >60]
 
98
  index = set_index(prepare_sentences_vector(get_embeddings(embedding_model, context)))
99
 
100
  def chat_interface(question, history):
101
+ response = predict_test(embedding_model, context, question, index, url)
102
  return response
103
 
104
  examples=['ภูมิทัศน์สื่อไทยในปี 2567 มีแนวโน้มว่า ',
 
106
  'ติ๊กต๊อก คือ',
107
  'รายงานจาก Reuters Institute'
108
  ]
109
+
 
110
  interface = gr.ChatInterface(fn=chat_interface,
111
  examples=examples)
112