Commit 1be405f
oceansweep committed
Parent(s): 9d34106

Upload 127 files

Files changed:
- App_Function_Libraries/Audio/Audio_Transcription_Lib.py +334 -283
- App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py +1 -1
- App_Function_Libraries/Books/Book_Ingestion_Lib.py +109 -36
- App_Function_Libraries/Character_Chat/Character_Chat_Lib.py +80 -14
- App_Function_Libraries/Chat.py +75 -20
- App_Function_Libraries/DB/Character_Chat_DB.py +701 -701
- App_Function_Libraries/DB/RAG_QA_Chat_DB.py +461 -0
- App_Function_Libraries/DB/SQLite_DB.py +1 -1
- App_Function_Libraries/Gradio_Related.py +423 -423
- App_Function_Libraries/Gradio_UI/Character_Chat_tab.py +119 -449
- App_Function_Libraries/Gradio_UI/Live_Recording.py +142 -0
- App_Function_Libraries/Gradio_UI/Llamafile_tab.py +276 -86
- App_Function_Libraries/Gradio_UI/MMLU_Pro_tab.py +115 -0
- App_Function_Libraries/Gradio_UI/RAG_QA_Chat_Notes.py +243 -0
- App_Function_Libraries/Gradio_UI/Utilities.py +3 -3
- App_Function_Libraries/Local_LLM/Local_LLM_Inference_Engine_Lib.py +317 -0
- App_Function_Libraries/Local_LLM/Local_LLM_huggingface.py +79 -0
- App_Function_Libraries/Local_LLM/Local_LLM_ollama.py +96 -0
- App_Function_Libraries/Metrics/__init__.py +0 -0
- App_Function_Libraries/PDF/PDF_Ingestion_Lib.py +45 -151
- App_Function_Libraries/RAG/ChromaDB_Library.py +226 -25
- App_Function_Libraries/RAG/Embeddings_Create.py +485 -69
- App_Function_Libraries/RAG/RAG_Library_2.py +396 -81
- App_Function_Libraries/RAG/RAG_Persona_Chat.py +103 -0
- App_Function_Libraries/RAG/RAG_QA_Chat.py +137 -84
- App_Function_Libraries/RAG/eval_Chroma_Embeddings.py +133 -0
- App_Function_Libraries/Summarization/Local_Summarization_Lib.py +95 -38
- App_Function_Libraries/Summarization/Summarization_General_Lib.py +6 -2
- App_Function_Libraries/Web_Scraping/Article_Extractor_Lib.py +148 -1
- App_Function_Libraries/Web_Scraping/Article_Summarization_Lib.py +1 -1
- App_Function_Libraries/html_to_markdown/__init__.py +0 -0
- App_Function_Libraries/html_to_markdown/ast_utils.py +59 -0
- App_Function_Libraries/html_to_markdown/conversion_options.py +21 -0
- App_Function_Libraries/html_to_markdown/dom_utils.py +140 -0
- App_Function_Libraries/html_to_markdown/html_to_markdown.py +46 -0
- App_Function_Libraries/html_to_markdown/html_to_markdown_ast.py +212 -0
- App_Function_Libraries/html_to_markdown/main.py +45 -0
- App_Function_Libraries/html_to_markdown/markdown_ast_to_string.py +163 -0
- App_Function_Libraries/html_to_markdown/s_types.py +126 -0
- App_Function_Libraries/html_to_markdown/url_utils.py +55 -0
- App_Function_Libraries/models/pyannote_diarization_config.yaml +13 -0
- App_Function_Libraries/models/pyannote_model_segmentation-3.0.bin +3 -0
- App_Function_Libraries/models/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin +3 -0
- App_Function_Libraries/test.gguf +0 -0
App_Function_Libraries/Audio/Audio_Transcription_Lib.py
CHANGED
@@ -1,284 +1,335 @@
-# Audio_Transcription_Lib.py
-#########################################
-# Transcription Library
-# This library is used to perform transcription of audio files.
-# Currently, uses faster_whisper for transcription.
-#
-####################
-# Function List
-#
-# 1. convert_to_wav(video_file_path, offset=0, overwrite=False)
-# 2. speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='small.en', vad_filter=False)
-#
-####################
-#
-# Import necessary libraries to run solo for testing
-import gc
-import json
-import logging
-import multiprocessing
-import os
-import queue
-import sys
-import subprocess
-import tempfile
-import threading
-import time
-# DEBUG Imports
-#from memory_profiler import profile
-
-from faster_whisper import WhisperModel as OriginalWhisperModel
-from typing import Optional, Union, List, Dict, Any
-#
-# Import Local
-from App_Function_Libraries.Utils.Utils import load_comprehensive_config
-from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
-#
-#######################################################################################################################
-# Function Definitions
-#
-
-# Convert video .m4a into .wav using ffmpeg
-# ffmpeg -i "example.mp4" -ar 16000 -ac 1 -c:a pcm_s16le "output.wav"
-# https://www.gyan.dev/ffmpeg/builds/
-#
-
-
-whisper_model_instance = None
-config = load_comprehensive_config()
-processing_choice = config.get('Processing', 'processing_choice', fallback='cpu')
-total_thread_count = multiprocessing.cpu_count()
-
-
-class WhisperModel(OriginalWhisperModel):
-    tldw_dir = os.path.dirname(os.path.dirname(__file__))
-    default_download_root = os.path.join(tldw_dir, 'models', 'Whisper')
-
-    valid_model_sizes = [
-        "tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium",
-        "large-v1", "large-v2", "large-v3", "large", "distil-large-v2", "distil-medium.en",
-        "distil-small.en", "distil-large-v3",
-    ]
-
-    def __init__(
-        self,
-        model_size_or_path: str,
-        device: str = processing_choice,
-        device_index: Union[int, List[int]] = 0,
-        compute_type: str = "default",
-        cpu_threads: int = 0,  #total_thread_count, FIXME - I think this should be 0
-        num_workers: int = 1,
-        download_root: Optional[str] = None,
-        local_files_only: bool = False,
-        files: Optional[Dict[str, Any]] = None,
-        **model_kwargs: Any
-    ):
-        if download_root is None:
-            download_root = self.default_download_root
-
-        os.makedirs(download_root, exist_ok=True)
-
-        # FIXME - validate....
-        # Also write an integration test...
-        # Check if model_size_or_path is a valid model size
-        if model_size_or_path in self.valid_model_sizes:
-            # It's a model size, so we'll use the download_root
-            model_path = os.path.join(download_root, model_size_or_path)
-            if not os.path.isdir(model_path):
-                # If it doesn't exist, we'll let the parent class download it
-                model_size_or_path = model_size_or_path  # Keep the original model size
-            else:
-                # If it exists, use the full path
-                model_size_or_path = model_path
-        else:
-            # It's not a valid model size, so assume it's a path
-            model_size_or_path = os.path.abspath(model_size_or_path)
-
-        super().__init__(
-            model_size_or_path,
-            device=device,
-            device_index=device_index,
-            compute_type=compute_type,
-            cpu_threads=cpu_threads,
-            num_workers=num_workers,
-            download_root=download_root,
-            local_files_only=local_files_only,
-            # Maybe? idk, FIXME
-            # files=files,
-            # **model_kwargs
-        )
-
-def get_whisper_model(model_name, device):
-    global whisper_model_instance
-    if whisper_model_instance is None:
-        logging.info(f"Initializing new WhisperModel with size {model_name} on device {device}")
-        whisper_model_instance = WhisperModel(model_name, device=device)
-    return whisper_model_instance
-
-# os.system(r'.\Bin\ffmpeg.exe -ss 00:00:00 -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{out_path}"')
-#DEBUG
-#@profile
-def convert_to_wav(video_file_path, offset=0, overwrite=False):
-    log_counter("convert_to_wav_attempt", labels={"file_path": video_file_path})
-    start_time = time.time()
-
-    out_path = os.path.splitext(video_file_path)[0] + ".wav"
-
-    if os.path.exists(out_path) and not overwrite:
-        print(f"File '{out_path}' already exists. Skipping conversion.")
-        logging.info(f"Skipping conversion as file already exists: {out_path}")
-        log_counter("convert_to_wav_skipped", labels={"file_path": video_file_path})
-        return out_path
-
-    print("Starting conversion process of .m4a to .WAV")
-    out_path = os.path.splitext(video_file_path)[0] + ".wav"
-
-    try:
-        if os.name == "nt":
-            logging.debug("ffmpeg being ran on windows")
-
-            if sys.platform.startswith('win'):
-                ffmpeg_cmd = ".\\Bin\\ffmpeg.exe"
-                logging.debug(f"ffmpeg_cmd: {ffmpeg_cmd}")
-            else:
-                ffmpeg_cmd = 'ffmpeg'  # Assume 'ffmpeg' is in PATH for non-Windows systems
-
-            command = [
-                ffmpeg_cmd,  # Assuming the working directory is correctly set where .\Bin exists
-                "-ss", "00:00:00",  # Start at the beginning of the video
-                "-i", video_file_path,
-                "-ar", "16000",  # Audio sample rate
-                "-ac", "1",  # Number of audio channels
-                "-c:a", "pcm_s16le",  # Audio codec
-                out_path
-            ]
-            try:
-                # Redirect stdin from null device to prevent ffmpeg from waiting for input
-                with open(os.devnull, 'rb') as null_file:
-                    result = subprocess.run(command, stdin=null_file, text=True, capture_output=True)
-                if result.returncode == 0:
-                    logging.info("FFmpeg executed successfully")
-                    logging.debug("FFmpeg output: %s", result.stdout)
-                else:
-                    logging.error("Error in running FFmpeg")
-                    logging.error("FFmpeg stderr: %s", result.stderr)
-                    raise RuntimeError(f"FFmpeg error: {result.stderr}")
-            except Exception as e:
-                logging.error("Error occurred - ffmpeg doesn't like windows")
-                raise RuntimeError("ffmpeg failed")
-        elif os.name == "posix":
-            os.system(f'ffmpeg -ss 00:00:00 -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{out_path}"')
-        else:
-            raise RuntimeError("Unsupported operating system")
-        logging.info("Conversion to WAV completed: %s", out_path)
-        log_counter("convert_to_wav_success", labels={"file_path": video_file_path})
-    except Exception as e:
-        logging.error("speech-to-text: Error transcribing audio: %s", str(e))
-        log_counter("convert_to_wav_error", labels={"file_path": video_file_path, "error": str(e)})
-        return {"error": str(e)}
-
-    conversion_time = time.time() - start_time
-    log_histogram("convert_to_wav_duration", conversion_time, labels={"file_path": video_file_path})
-
-    gc.collect()
-    return out_path
-
-
-# Transcribe .wav into .segments.json
-#DEBUG
-#@profile
-# FIXME - I feel like the `vad_filter` shoudl be enabled by default....
-def speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='medium.en', vad_filter=False, diarize=False):
-    log_counter("speech_to_text_attempt", labels={"file_path": audio_file_path, "model": whisper_model})
-    time_start = time.time()
-
-    if audio_file_path is None:
-        log_counter("speech_to_text_error", labels={"error": "No audio file provided"})
-        raise ValueError("speech-to-text: No audio file provided")
-    logging.info("speech-to-text: Audio file path: %s", audio_file_path)
-
-    try:
-        _, file_ending = os.path.splitext(audio_file_path)
-        out_file = audio_file_path.replace(file_ending, "-whisper_model-"+whisper_model+".segments.json")
-        prettified_out_file = audio_file_path.replace(file_ending, "-whisper_model-"+whisper_model+".segments_pretty.json")
-        if os.path.exists(out_file):
-            logging.info("speech-to-text: Segments file already exists: %s", out_file)
-            with open(out_file) as f:
-                global segments
-                segments = json.load(f)
-            return segments
-
-        logging.info('speech-to-text: Starting transcription...')
-        # FIXME - revisit this
-        options = dict(language=selected_source_lang, beam_size=10, best_of=10, vad_filter=vad_filter)
-        transcribe_options = dict(task="transcribe", **options)
-        # use function and config at top of file
-        logging.debug("speech-to-text: Using whisper model: %s", whisper_model)
-        whisper_model_instance = get_whisper_model(whisper_model, processing_choice)
-        # faster_whisper transcription right here - FIXME -test batching - ha
-        segments_raw, info = whisper_model_instance.transcribe(audio_file_path, **transcribe_options)
-
-        segments = []
-        for segment_chunk in segments_raw:
-            chunk = {
-                "Time_Start": segment_chunk.start,
-                "Time_End": segment_chunk.end,
-                "Text": segment_chunk.text
-            }
-            logging.debug("Segment: %s", chunk)
-            segments.append(chunk)
-            # Print to verify its working
-            logging.info(f"{segment_chunk.start:.2f}s - {segment_chunk.end:.2f}s | {segment_chunk.text}")
-
-            # Log it as well.
-            logging.debug(
-                f"Transcribed Segment: {segment_chunk.start:.2f}s - {segment_chunk.end:.2f}s | {segment_chunk.text}")
-
-        if segments:
-            segments[0]["Text"] = f"This text was transcribed using whisper model: {whisper_model}\n\n" + segments[0]["Text"]
-
-        if not segments:
-            log_counter("speech_to_text_error", labels={"error": "No transcription produced"})
-            raise RuntimeError("No transcription produced. The audio file may be invalid or empty.")
-
-        transcription_time = time.time() - time_start
-        logging.info("speech-to-text: Transcription completed in %.2f seconds", transcription_time)
-        log_histogram("speech_to_text_duration", transcription_time, labels={"file_path": audio_file_path, "model": whisper_model})
-        log_counter("speech_to_text_success", labels={"file_path": audio_file_path, "model": whisper_model})
-        # Save the segments to a JSON file - prettified and non-prettified
-        # FIXME refactor so this is an optional flag to save either the prettified json file or the normal one
-        save_json = True
-        if save_json:
-            logging.info("speech-to-text: Saving segments to JSON file")
-            output_data = {'segments': segments}
-            logging.info("speech-to-text: Saving prettified JSON to %s", prettified_out_file)
-            with open(prettified_out_file, 'w') as f:
-                json.dump(output_data, f, indent=2)
-
-            logging.info("speech-to-text: Saving JSON to %s", out_file)
-            with open(out_file, 'w') as f:
-                json.dump(output_data, f)
-
-        logging.debug(f"speech-to-text: returning {segments[:500]}")
-        gc.collect()
-        return segments
-
-    except Exception as e:
-        logging.error("speech-to-text: Error transcribing audio: %s", str(e))
-        log_counter("speech_to_text_error", labels={"file_path": audio_file_path, "model": whisper_model, "error": str(e)})
-        raise RuntimeError("speech-to-text: Error transcribing audio")
-
-
-def record_audio(duration, sample_rate=16000, chunk_size=1024):
-
-
-
-
-
-
-
-
-
-
-
+# Audio_Transcription_Lib.py
+#########################################
+# Transcription Library
+# This library is used to perform transcription of audio files.
+# Currently, uses faster_whisper for transcription.
+#
+####################
+# Function List
+#
+# 1. convert_to_wav(video_file_path, offset=0, overwrite=False)
+# 2. speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='small.en', vad_filter=False)
+#
+####################
+#
+# Import necessary libraries to run solo for testing
+import gc
+import json
+import logging
+import multiprocessing
+import os
+import queue
+import sys
+import subprocess
+import tempfile
+import threading
+import time
+# DEBUG Imports
+#from memory_profiler import profile
+import pyaudio
+from faster_whisper import WhisperModel as OriginalWhisperModel
+from typing import Optional, Union, List, Dict, Any
+#
+# Import Local
+from App_Function_Libraries.Utils.Utils import load_comprehensive_config
+from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
+#
+#######################################################################################################################
+# Function Definitions
+#
+
+# Convert video .m4a into .wav using ffmpeg
+# ffmpeg -i "example.mp4" -ar 16000 -ac 1 -c:a pcm_s16le "output.wav"
+# https://www.gyan.dev/ffmpeg/builds/
+#
+
+
+whisper_model_instance = None
+config = load_comprehensive_config()
+processing_choice = config.get('Processing', 'processing_choice', fallback='cpu')
+total_thread_count = multiprocessing.cpu_count()
+
+
+class WhisperModel(OriginalWhisperModel):
+    tldw_dir = os.path.dirname(os.path.dirname(__file__))
+    default_download_root = os.path.join(tldw_dir, 'models', 'Whisper')
+
+    valid_model_sizes = [
+        "tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium",
+        "large-v1", "large-v2", "large-v3", "large", "distil-large-v2", "distil-medium.en",
+        "distil-small.en", "distil-large-v3",
+    ]
+
+    def __init__(
+        self,
+        model_size_or_path: str,
+        device: str = processing_choice,
+        device_index: Union[int, List[int]] = 0,
+        compute_type: str = "default",
+        cpu_threads: int = 0,  #total_thread_count, FIXME - I think this should be 0
+        num_workers: int = 1,
+        download_root: Optional[str] = None,
+        local_files_only: bool = False,
+        files: Optional[Dict[str, Any]] = None,
+        **model_kwargs: Any
+    ):
+        if download_root is None:
+            download_root = self.default_download_root
+
+        os.makedirs(download_root, exist_ok=True)
+
+        # FIXME - validate....
+        # Also write an integration test...
+        # Check if model_size_or_path is a valid model size
+        if model_size_or_path in self.valid_model_sizes:
+            # It's a model size, so we'll use the download_root
+            model_path = os.path.join(download_root, model_size_or_path)
+            if not os.path.isdir(model_path):
+                # If it doesn't exist, we'll let the parent class download it
+                model_size_or_path = model_size_or_path  # Keep the original model size
+            else:
+                # If it exists, use the full path
+                model_size_or_path = model_path
+        else:
+            # It's not a valid model size, so assume it's a path
+            model_size_or_path = os.path.abspath(model_size_or_path)
+
+        super().__init__(
+            model_size_or_path,
+            device=device,
+            device_index=device_index,
+            compute_type=compute_type,
+            cpu_threads=cpu_threads,
+            num_workers=num_workers,
+            download_root=download_root,
+            local_files_only=local_files_only,
+            # Maybe? idk, FIXME
+            # files=files,
+            # **model_kwargs
+        )
+
+def get_whisper_model(model_name, device):
+    global whisper_model_instance
+    if whisper_model_instance is None:
+        logging.info(f"Initializing new WhisperModel with size {model_name} on device {device}")
+        whisper_model_instance = WhisperModel(model_name, device=device)
+    return whisper_model_instance
+
+# os.system(r'.\Bin\ffmpeg.exe -ss 00:00:00 -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{out_path}"')
+#DEBUG
+#@profile
+def convert_to_wav(video_file_path, offset=0, overwrite=False):
+    log_counter("convert_to_wav_attempt", labels={"file_path": video_file_path})
+    start_time = time.time()
+
+    out_path = os.path.splitext(video_file_path)[0] + ".wav"
+
+    if os.path.exists(out_path) and not overwrite:
+        print(f"File '{out_path}' already exists. Skipping conversion.")
+        logging.info(f"Skipping conversion as file already exists: {out_path}")
+        log_counter("convert_to_wav_skipped", labels={"file_path": video_file_path})
+        return out_path
+
+    print("Starting conversion process of .m4a to .WAV")
+    out_path = os.path.splitext(video_file_path)[0] + ".wav"
+
+    try:
+        if os.name == "nt":
+            logging.debug("ffmpeg being ran on windows")
+
+            if sys.platform.startswith('win'):
+                ffmpeg_cmd = ".\\Bin\\ffmpeg.exe"
+                logging.debug(f"ffmpeg_cmd: {ffmpeg_cmd}")
+            else:
+                ffmpeg_cmd = 'ffmpeg'  # Assume 'ffmpeg' is in PATH for non-Windows systems
+
+            command = [
+                ffmpeg_cmd,  # Assuming the working directory is correctly set where .\Bin exists
+                "-ss", "00:00:00",  # Start at the beginning of the video
+                "-i", video_file_path,
+                "-ar", "16000",  # Audio sample rate
+                "-ac", "1",  # Number of audio channels
+                "-c:a", "pcm_s16le",  # Audio codec
+                out_path
+            ]
+            try:
+                # Redirect stdin from null device to prevent ffmpeg from waiting for input
+                with open(os.devnull, 'rb') as null_file:
+                    result = subprocess.run(command, stdin=null_file, text=True, capture_output=True)
+                if result.returncode == 0:
+                    logging.info("FFmpeg executed successfully")
+                    logging.debug("FFmpeg output: %s", result.stdout)
+                else:
+                    logging.error("Error in running FFmpeg")
+                    logging.error("FFmpeg stderr: %s", result.stderr)
+                    raise RuntimeError(f"FFmpeg error: {result.stderr}")
+            except Exception as e:
+                logging.error("Error occurred - ffmpeg doesn't like windows")
+                raise RuntimeError("ffmpeg failed")
+        elif os.name == "posix":
+            os.system(f'ffmpeg -ss 00:00:00 -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{out_path}"')
+        else:
+            raise RuntimeError("Unsupported operating system")
+        logging.info("Conversion to WAV completed: %s", out_path)
+        log_counter("convert_to_wav_success", labels={"file_path": video_file_path})
+    except Exception as e:
+        logging.error("speech-to-text: Error transcribing audio: %s", str(e))
+        log_counter("convert_to_wav_error", labels={"file_path": video_file_path, "error": str(e)})
+        return {"error": str(e)}
+
+    conversion_time = time.time() - start_time
+    log_histogram("convert_to_wav_duration", conversion_time, labels={"file_path": video_file_path})
+
+    gc.collect()
+    return out_path
+
+
+# Transcribe .wav into .segments.json
+#DEBUG
+#@profile
+# FIXME - I feel like the `vad_filter` shoudl be enabled by default....
+def speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='medium.en', vad_filter=False, diarize=False):
+    log_counter("speech_to_text_attempt", labels={"file_path": audio_file_path, "model": whisper_model})
+    time_start = time.time()
+
+    if audio_file_path is None:
+        log_counter("speech_to_text_error", labels={"error": "No audio file provided"})
+        raise ValueError("speech-to-text: No audio file provided")
+    logging.info("speech-to-text: Audio file path: %s", audio_file_path)
+
+    try:
+        _, file_ending = os.path.splitext(audio_file_path)
+        out_file = audio_file_path.replace(file_ending, "-whisper_model-"+whisper_model+".segments.json")
+        prettified_out_file = audio_file_path.replace(file_ending, "-whisper_model-"+whisper_model+".segments_pretty.json")
+        if os.path.exists(out_file):
+            logging.info("speech-to-text: Segments file already exists: %s", out_file)
+            with open(out_file) as f:
+                global segments
+                segments = json.load(f)
+            return segments
+
+        logging.info('speech-to-text: Starting transcription...')
+        # FIXME - revisit this
+        options = dict(language=selected_source_lang, beam_size=10, best_of=10, vad_filter=vad_filter)
+        transcribe_options = dict(task="transcribe", **options)
+        # use function and config at top of file
+        logging.debug("speech-to-text: Using whisper model: %s", whisper_model)
+        whisper_model_instance = get_whisper_model(whisper_model, processing_choice)
+        # faster_whisper transcription right here - FIXME -test batching - ha
+        segments_raw, info = whisper_model_instance.transcribe(audio_file_path, **transcribe_options)
+
+        segments = []
+        for segment_chunk in segments_raw:
+            chunk = {
+                "Time_Start": segment_chunk.start,
+                "Time_End": segment_chunk.end,
+                "Text": segment_chunk.text
+            }
+            logging.debug("Segment: %s", chunk)
+            segments.append(chunk)
+            # Print to verify its working
+            logging.info(f"{segment_chunk.start:.2f}s - {segment_chunk.end:.2f}s | {segment_chunk.text}")
+
+            # Log it as well.
+            logging.debug(
+                f"Transcribed Segment: {segment_chunk.start:.2f}s - {segment_chunk.end:.2f}s | {segment_chunk.text}")
+
+        if segments:
+            segments[0]["Text"] = f"This text was transcribed using whisper model: {whisper_model}\n\n" + segments[0]["Text"]
+
+        if not segments:
+            log_counter("speech_to_text_error", labels={"error": "No transcription produced"})
+            raise RuntimeError("No transcription produced. The audio file may be invalid or empty.")
+
+        transcription_time = time.time() - time_start
+        logging.info("speech-to-text: Transcription completed in %.2f seconds", transcription_time)
+        log_histogram("speech_to_text_duration", transcription_time, labels={"file_path": audio_file_path, "model": whisper_model})
+        log_counter("speech_to_text_success", labels={"file_path": audio_file_path, "model": whisper_model})
+        # Save the segments to a JSON file - prettified and non-prettified
+        # FIXME refactor so this is an optional flag to save either the prettified json file or the normal one
+        save_json = True
+        if save_json:
+            logging.info("speech-to-text: Saving segments to JSON file")
+            output_data = {'segments': segments}
+            logging.info("speech-to-text: Saving prettified JSON to %s", prettified_out_file)
+            with open(prettified_out_file, 'w') as f:
+                json.dump(output_data, f, indent=2)
+
+            logging.info("speech-to-text: Saving JSON to %s", out_file)
+            with open(out_file, 'w') as f:
+                json.dump(output_data, f)
+
+        logging.debug(f"speech-to-text: returning {segments[:500]}")
+        gc.collect()
+        return segments
+
+    except Exception as e:
+        logging.error("speech-to-text: Error transcribing audio: %s", str(e))
+        log_counter("speech_to_text_error", labels={"file_path": audio_file_path, "model": whisper_model, "error": str(e)})
+        raise RuntimeError("speech-to-text: Error transcribing audio")
+
+
+def record_audio(duration, sample_rate=16000, chunk_size=1024):
+    log_counter("record_audio_attempt", labels={"duration": duration})
+    p = pyaudio.PyAudio()
+    stream = p.open(format=pyaudio.paInt16,
+                    channels=1,
+                    rate=sample_rate,
+                    input=True,
+                    frames_per_buffer=chunk_size)
+
+    print("Recording...")
+    frames = []
+    stop_recording = threading.Event()
+    audio_queue = queue.Queue()
+
+    def audio_callback():
+        for _ in range(0, int(sample_rate / chunk_size * duration)):
+            if stop_recording.is_set():
+                break
+            data = stream.read(chunk_size)
+            audio_queue.put(data)
+
+    audio_thread = threading.Thread(target=audio_callback)
+    audio_thread.start()
+
+    return p, stream, audio_queue, stop_recording, audio_thread
+
+
+def stop_recording(p, stream, audio_queue, stop_recording_event, audio_thread):
+    log_counter("stop_recording_attempt")
+    start_time = time.time()
+    stop_recording_event.set()
+    audio_thread.join()
+
+    frames = []
+    while not audio_queue.empty():
+        frames.append(audio_queue.get())
+
+    print("Recording finished.")
+
+    stream.stop_stream()
+    stream.close()
+    p.terminate()
+
+    stop_time = time.time() - start_time
+    log_histogram("stop_recording_duration", stop_time)
+    log_counter("stop_recording_success")
+    return b''.join(frames)
+
+def save_audio_temp(audio_data, sample_rate=16000):
+    log_counter("save_audio_temp_attempt")
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
+        import wave
+        wf = wave.open(temp_file.name, 'wb')
+        wf.setnchannels(1)
+        wf.setsampwidth(2)
+        wf.setframerate(sample_rate)
+        wf.writeframes(audio_data)
+        wf.close()
+    log_counter("save_audio_temp_success")
+    return temp_file.name
+
+#
+#
 #######################################################################################################################
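Taken together, the new recording helpers compose into a capture-and-transcribe pipeline: record_audio() starts a background capture thread that feeds a queue, stop_recording() signals the thread, drains the queue, and tears down PyAudio, and save_audio_temp() writes the raw frames to a 16 kHz mono WAV that speech_to_text() can consume. Note that record_audio() returns its stop event under the same name as the module-level stop_recording() function, so callers should rebind it on receipt. A minimal usage sketch, assuming a working default input device; the five-second duration is illustrative:

import time
from App_Function_Libraries.Audio.Audio_Transcription_Lib import (
    record_audio, stop_recording, save_audio_temp, speech_to_text
)

duration = 5  # seconds; illustrative value
# Start the background capture thread.
p, stream, audio_queue, stop_event, audio_thread = record_audio(duration)
time.sleep(duration)  # let the capture thread fill the queue

# Drain the queue and close the PyAudio stream.
audio_data = stop_recording(p, stream, audio_queue, stop_event, audio_thread)
wav_path = save_audio_temp(audio_data)  # 16 kHz, mono, 16-bit PCM

segments = speech_to_text(wav_path, whisper_model='small.en')
for seg in segments:
    print(f"{seg['Time_Start']:.2f}s-{seg['Time_End']:.2f}s {seg['Text']}")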
App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py
CHANGED
@@ -259,7 +259,7 @@ def run_geval(transcript: str, summary: str, api_key: str, api_name: str = None,
 
 
 def create_geval_tab():
-    with gr.Tab("G-Eval"):
+    with gr.Tab("G-Eval", id="g-eval"):
         gr.Markdown("# G-Eval Summarization Evaluation")
         with gr.Row():
             with gr.Column():
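The only change here gives the tab an explicit id, which in Gradio lets other callbacks select the tab programmatically. A minimal sketch of why that matters, assuming Gradio 4.x; the surrounding Blocks layout is illustrative, not the app's actual one:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.Tab("Home", id="home"):
            open_btn = gr.Button("Open G-Eval")
        with gr.Tab("G-Eval", id="g-eval"):
            gr.Markdown("# G-Eval Summarization Evaluation")
    # Returning a Tabs update with selected=<id> switches the visible tab.
    open_btn.click(lambda: gr.Tabs(selected="g-eval"), outputs=tabs)

demo.launch()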
App_Function_Libraries/Books/Book_Ingestion_Lib.py
CHANGED
@@ -11,32 +11,42 @@
 #
 ####################
 #
-#
+# Imports
 import os
 import re
 import tempfile
 import zipfile
 from datetime import datetime
 import logging
-
+#
+# External Imports
 import ebooklib
 from bs4 import BeautifulSoup
 from ebooklib import epub
-
-from App_Function_Libraries.Chunk_Lib import chunk_ebook_by_chapters
 #
 # Import Local
 from App_Function_Libraries.DB.DB_Manager import add_media_with_keywords, add_media_to_database
 from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_summarization
-
-
+from App_Function_Libraries.Chunk_Lib import chunk_ebook_by_chapters
+from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
 #
 #######################################################################################################################
 # Function Definitions
 #
 
-def import_epub(file_path,
-                title=None, author=None, keywords=None, custom_prompt=None, system_prompt=None, summary=None, auto_summarize=False, api_name=None, api_key=None, chunk_options=None, custom_chapter_pattern=None):
+def import_epub(file_path,
+                title=None,
+                author=None,
+                keywords=None,
+                custom_prompt=None,
+                system_prompt=None,
+                summary=None,
+                auto_summarize=False,
+                api_name=None,
+                api_key=None,
+                chunk_options=None,
+                custom_chapter_pattern=None
+                ):
     """
     Imports an EPUB file, extracts its content, chunks it, optionally summarizes it, and adds it to the database.
 
@@ -58,6 +68,9 @@ def import_epub(file_path, title=None, author=None, keywords=None, custom_prompt
     """
     try:
         logging.info(f"Importing EPUB file from {file_path}")
+        log_counter("epub_import_attempt", labels={"file_path": file_path})
+
+        start_time = datetime.now()
 
         # Convert EPUB to Markdown
         markdown_content = epub_to_markdown(file_path)
@@ -90,10 +103,11 @@ def import_epub(file_path, title=None, author=None, keywords=None, custom_prompt
         # Chunk the content by chapters
        chunks = chunk_ebook_by_chapters(markdown_content, chunk_options)
         logging.info(f"Total chunks created: {len(chunks)}")
+        log_histogram("epub_chunks_created", len(chunks), labels={"file_path": file_path})
+
         if chunks:
             logging.debug(f"Structure of first chunk: {chunks[0].keys()}")
 
-
         # Handle summarization if enabled
         if auto_summarize and api_name and api_key:
             logging.info("Auto-summarization is enabled.")
@@ -101,11 +115,15 @@ def import_epub(file_path, title=None, author=None, keywords=None, custom_prompt
             for chunk in chunks:
                 chunk_text = chunk.get('text', '')
                 if chunk_text:
-                    summary_text = perform_summarization(api_name, chunk_text, custom_prompt, api_key,
+                    summary_text = perform_summarization(api_name, chunk_text, custom_prompt, api_key,
                                                          recursive_summarization=False, temp=None,
+                                                         system_message=system_prompt
+                                                         )
                     chunk['metadata']['summary'] = summary_text
                     summarized_chunks.append(chunk)
             chunks = summarized_chunks
             logging.info("Summarization of chunks completed.")
+            log_counter("epub_chunks_summarized", value=len(chunks), labels={"file_path": file_path})
         else:
             # If not summarizing, set a default summary or use provided summary
             if summary:
@@ -137,15 +155,33 @@ def import_epub(file_path, title=None, author=None, keywords=None, custom_prompt
             overwrite=False
         )
 
+        end_time = datetime.now()
+        processing_time = (end_time - start_time).total_seconds()
+        log_histogram("epub_import_duration", processing_time, labels={"file_path": file_path})
+
        logging.info(f"Ebook '{title}' by {author} imported successfully. Database result: {result}")
+        log_counter("epub ingested into the DB successfully", labels={"file_path": file_path})
         return f"Ebook '{title}' by {author} imported successfully. Database result: {result}"
 
     except Exception as e:
         logging.exception(f"Error importing ebook: {str(e)}")
+        log_counter("epub_import_error", labels={"file_path": file_path, "error": str(e)})
         return f"Error importing ebook: {str(e)}"
 
+
 # FIXME
-def process_zip_file(zip_file, title, author, keywords, custom_prompt, system_prompt, summary, auto_summarize, api_name, api_key, chunk_options):
+def process_zip_file(zip_file,
+                     title,
+                     author,
+                     keywords,
+                     custom_prompt,
+                     system_prompt,
+                     summary,
+                     auto_summarize,
+                     api_name,
+                     api_key,
+                     chunk_options
+                     ):
     """
     Processes a ZIP file containing multiple EPUB files and imports each one.
 
@@ -169,38 +205,58 @@ def process_zip_file(zip_file, title, author, keywords, custom_prompt, system_pr
         with tempfile.TemporaryDirectory() as temp_dir:
             zip_path = zip_file.name if hasattr(zip_file, 'name') else zip_file.path
             logging.info(f"Extracting ZIP file {zip_path} to temporary directory {temp_dir}")
+            log_counter("zip_processing_attempt", labels={"zip_path": zip_path})
+
             with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                 zip_ref.extractall(temp_dir)
 
-            for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            epub_files = [f for f in os.listdir(temp_dir) if f.lower().endswith('.epub')]
+            log_histogram("epub_files_in_zip", len(epub_files), labels={"zip_path": zip_path})
+
+            for filename in epub_files:
+                file_path = os.path.join(temp_dir, filename)
+                logging.info(f"Processing EPUB file {filename} from ZIP.")
+                result = import_epub(
+                    file_path=file_path,
+                    title=title,
+                    author=author,
+                    keywords=keywords,
+                    custom_prompt=custom_prompt,
+                    summary=summary,
+                    auto_summarize=auto_summarize,
+                    api_name=api_name,
+                    api_key=api_key,
+                    chunk_options=chunk_options,
+                    custom_chapter_pattern=chunk_options.get('custom_chapter_pattern') if chunk_options else None
+                )
+                results.append(f"File: {filename} - {result}")
+
         logging.info("Completed processing all EPUB files in the ZIP.")
+        log_counter("zip_processing_success", labels={"zip_path": zip_path})
     except Exception as e:
         logging.exception(f"Error processing ZIP file: {str(e)}")
+        log_counter("zip_processing_error", labels={"zip_path": zip_path, "error": str(e)})
         return f"Error processing ZIP file: {str(e)}"
 
     return "\n".join(results)
 
 
-def import_file_handler(file,
-                        title, author, keywords, system_prompt, custom_prompt, auto_summarize, api_name, api_key, max_chunk_size, chunk_overlap, custom_chapter_pattern):
+def import_file_handler(file,
+                        title,
+                        author,
+                        keywords,
+                        system_prompt,
+                        custom_prompt,
+                        auto_summarize,
+                        api_name,
+                        api_key,
+                        max_chunk_size,
+                        chunk_overlap,
+                        custom_chapter_pattern
+                        ):
     try:
+        log_counter("file_import_attempt", labels={"file_name": file.name})
+
         # Handle max_chunk_size
         if isinstance(max_chunk_size, str):
            max_chunk_size = int(max_chunk_size) if max_chunk_size.strip() else 4000
@@ -221,12 +277,16 @@ def import_file_handler(file, title, author, keywords, system_prompt, custom_pro
        }
 
         if file is None:
+            log_counter("file_import_error", labels={"error": "No file uploaded"})
             return "No file uploaded."
 
         file_path = file.name
         if not os.path.exists(file_path):
+            log_counter("file_import_error", labels={"error": "File not found", "file_name": file.name})
             return "Uploaded file not found."
 
+        start_time = datetime.now()
+
         if file_path.lower().endswith('.epub'):
             status = import_epub(
                 file_path,
@@ -242,7 +302,8 @@ def import_file_handler(file, title, author, keywords, system_prompt, custom_pro
                 chunk_options=chunk_options,
                 custom_chapter_pattern=custom_chapter_pattern
             )
-
+            log_counter("epub_import_success", labels={"file_name": file.name})
+            result = f"📚 EPUB Imported Successfully:\n{status}"
         elif file.name.lower().endswith('.zip'):
             status = process_zip_file(
                 zip_file=file,
@@ -251,26 +312,38 @@ def import_file_handler(file, title, author, keywords, system_prompt, custom_pro
                 keywords=keywords,
                 custom_prompt=custom_prompt,
                 system_prompt=system_prompt,
-                summary=None,
+                summary=None,
                 auto_summarize=auto_summarize,
                 api_name=api_name,
                 api_key=api_key,
                 chunk_options=chunk_options
             )
-
+            log_counter("zip_import_success", labels={"file_name": file.name})
+            result = f"📦 ZIP Processed Successfully:\n{status}"
         elif file.name.lower().endswith(('.chm', '.html', '.pdf', '.xml', '.opml')):
             file_type = file.name.split('.')[-1].upper()
-
+            log_counter("unsupported_file_type", labels={"file_type": file_type})
+            result = f"{file_type} file import is not yet supported."
         else:
-
+            log_counter("unsupported_file_type", labels={"file_type": file.name.split('.')[-1]})
+            result = "❌ Unsupported file type. Please upload an `.epub` file or a `.zip` file containing `.epub` files."
+
+        end_time = datetime.now()
+        processing_time = (end_time - start_time).total_seconds()
+        log_histogram("file_import_duration", processing_time, labels={"file_name": file.name})
+
+        return result
 
     except ValueError as ve:
         logging.exception(f"Error parsing input values: {str(ve)}")
+        log_counter("file_import_error", labels={"error": "Invalid input", "file_name": file.name})
         return f"❌ Error: Invalid input for chunk size or overlap. Please enter valid numbers."
     except Exception as e:
         logging.exception(f"Error during file import: {str(e)}")
+        log_counter("file_import_error", labels={"error": str(e), "file_name": file.name})
         return f"❌ Error during import: {str(e)}"
 
+
 def read_epub(file_path):
     """
     Reads and extracts text from an EPUB file.
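Across this commit the same instrumentation pattern repeats: count an attempt, time the work, then record success or error counters plus a duration histogram around the existing logic. A hedged sketch of that pattern as a reusable decorator; the decorator itself is illustrative and not part of the repository, only log_counter and log_histogram (with the call shapes seen above) come from the codebase:

import time
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram

def instrumented(name):
    """Wrap a function with the attempt/success/error/duration pattern."""
    def wrap(fn):
        def inner(*args, **kwargs):
            log_counter(f"{name}_attempt")
            start = time.time()
            try:
                result = fn(*args, **kwargs)
                log_counter(f"{name}_success")
                return result
            except Exception as e:
                log_counter(f"{name}_error", labels={"error": str(e)})
                raise
            finally:
                # Histogram records duration whether the call succeeded or failed.
                log_histogram(f"{name}_duration", time.time() - start)
        return inner
    return wrap

Inlining the calls, as the commit does, keeps per-call labels (file paths, model names) close at hand; a decorator like this would trade that flexibility for less repetition.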
App_Function_Libraries/Character_Chat/Character_Chat_Lib.py
CHANGED
@@ -6,6 +6,7 @@ import json
 import logging
 import io
 import base64
+import time
 from typing import Dict, Any, Optional, List, Tuple
 #
 # External Imports
@@ -13,6 +14,7 @@ from PIL import Image
 #
 # Local imports
 from App_Function_Libraries.DB.DB_Manager import get_character_card_by_id, get_character_chat_by_id
+from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
 #
 # Constants
 ####################################################################################################
@@ -79,16 +81,32 @@ def replace_user_placeholder(history, user_name):
 
 #################################################################################
 #
-#
+# Functions for character card processing:
 
 def extract_character_id(choice: str) -> int:
     """Extract the character ID from the dropdown selection string."""
-    return int(choice.split('(ID: ')[1].rstrip(')'))
+    log_counter("extract_character_id_attempt")
+    try:
+        character_id = int(choice.split('(ID: ')[1].rstrip(')'))
+        log_counter("extract_character_id_success")
+        return character_id
+    except Exception as e:
+        log_counter("extract_character_id_error", labels={"error": str(e)})
+        raise
 
 def load_character_wrapper(character_id: int, user_name: str) -> Tuple[Dict[str, Any], List[Tuple[Optional[str], str]], Optional[Image.Image]]:
     """Wrapper function to load character and image using the extracted ID."""
-    char_data, chat_history, img = load_character_and_image(character_id, user_name)
-    return char_data, chat_history, img
+    log_counter("load_character_wrapper_attempt")
+    start_time = time.time()
+    try:
+        char_data, chat_history, img = load_character_and_image(character_id, user_name)
+        load_duration = time.time() - start_time
+        log_histogram("load_character_wrapper_duration", load_duration)
+        log_counter("load_character_wrapper_success")
+        return char_data, chat_history, img
+    except Exception as e:
+        log_counter("load_character_wrapper_error", labels={"error": str(e)})
+        raise
 
 def parse_character_book(book_data: Dict[str, Any]) -> Dict[str, Any]:
     """
@@ -143,9 +161,12 @@ def load_character_and_image(character_id: int, user_name: str) -> Tuple[Optiona
         Tuple[Optional[Dict[str, Any]], List[Tuple[Optional[str], str]], Optional[Image.Image]]:
         A tuple containing the character data, chat history, and character image (if available).
     """
+    log_counter("load_character_and_image_attempt")
+    start_time = time.time()
     try:
         char_data = get_character_card_by_id(character_id)
         if not char_data:
+            log_counter("load_character_and_image_no_data")
             logging.warning(f"No character data found for ID: {character_id}")
             return None, [], None
 
@@ -165,12 +186,18 @@ def load_character_and_image(character_id: int, user_name: str) -> Tuple[Optiona
         try:
             image_data = base64.b64decode(char_data['image'])
             img = Image.open(io.BytesIO(image_data)).convert("RGBA")
+            log_counter("load_character_image_success")
         except Exception as e:
+            log_counter("load_character_image_error", labels={"error": str(e)})
             logging.error(f"Error processing image for character '{char_data['name']}': {e}")
 
+        load_duration = time.time() - start_time
+        log_histogram("load_character_and_image_duration", load_duration)
+        log_counter("load_character_and_image_success")
         return char_data, chat_history, img
 
     except Exception as e:
+        log_counter("load_character_and_image_error", labels={"error": str(e)})
         logging.error(f"Error in load_character_and_image: {e}")
         return None, [], None
 
@@ -186,10 +213,13 @@ def load_chat_and_character(chat_id: int, user_name: str) -> Tuple[Optional[Dict
         Tuple[Optional[Dict[str, Any]], List[Tuple[str, str]], Optional[Image.Image]]:
         A tuple containing the character data, processed chat history, and character image (if available).
     """
+    log_counter("load_chat_and_character_attempt")
+    start_time = time.time()
     try:
         # Load the chat
         chat = get_character_chat_by_id(chat_id)
         if not chat:
+            log_counter("load_chat_and_character_no_chat")
             logging.warning(f"No chat found with ID: {chat_id}")
             return None, [], None
 
@@ -197,6 +227,7 @@ def load_chat_and_character(chat_id: int, user_name: str) -> Tuple[Optional[Dict
         character_id = chat['character_id']
         char_data = get_character_card_by_id(character_id)
         if not char_data:
+            log_counter("load_chat_and_character_no_character")
             logging.warning(f"No character found for chat ID: {chat_id}")
             return None, chat['chat_history'], None
 
@@ -209,7 +240,9 @@ def load_chat_and_character(chat_id: int, user_name: str) -> Tuple[Optional[Dict
         try:
             image_data = base64.b64decode(char_data['image'])
             img = Image.open(io.BytesIO(image_data)).convert("RGBA")
+            log_counter("load_chat_character_image_success")
         except Exception as e:
+            log_counter("load_chat_character_image_error", labels={"error": str(e)})
             logging.error(f"Error processing image for character '{char_data['name']}': {e}")
 
         # Process character data templates
@@ -217,14 +250,21 @@ def load_chat_and_character(chat_id: int, user_name: str) -> Tuple[Optional[Dict
             if field in char_data:
                 char_data[field] = replace_placeholders(char_data[field], char_data['name'], user_name)
 
+        load_duration = time.time() - start_time
+        log_histogram("load_chat_and_character_duration", load_duration)
+        log_counter("load_chat_and_character_success")
         return char_data, processed_history, img
 
     except Exception as e:
+        log_counter("load_chat_and_character_error", labels={"error": str(e)})
         logging.error(f"Error in load_chat_and_character: {e}")
         return None, [], None
 
+
 def extract_json_from_image(image_file):
     logging.debug(f"Attempting to extract JSON from image: {image_file.name}")
+    log_counter("extract_json_from_image_attempt")
+    start_time = time.time()
     try:
         with Image.open(image_file) as img:
             logging.debug("Image opened successfully")
@@ -236,16 +276,18 @@ def extract_json_from_image(image_file):
                 try:
                     decoded_content = base64.b64decode(chara_content).decode('utf-8')
                     logging.debug(f"Decoded content (first 100 chars): {decoded_content[:100]}...")
+                    log_counter("extract_json_from_image_metadata_success")
                     return decoded_content
                 except Exception as e:
                     logging.error(f"Error decoding base64 content: {e}")
+                    log_counter("extract_json_from_image_decode_error", labels={"error": str(e)})
 
             logging.warning("'chara' not found in metadata, attempting to find JSON data in image bytes")
             # Alternative method to extract embedded JSON from image bytes if metadata is not available
             img_byte_arr = io.BytesIO()
             img.save(img_byte_arr, format='PNG')
             img_bytes = img_byte_arr.getvalue()
-            img_str = img_bytes.decode('latin1')
+            img_str = img_bytes.decode('latin1')
 
             # Search for JSON-like structures in the image bytes
             json_start = img_str.find('{')
@@ -255,18 +297,26 @@
                 try:
                     json.loads(possible_json)
                     logging.debug("Found JSON data in image bytes")
                     return possible_json
                 except json.JSONDecodeError:
                     logging.debug("No valid JSON found in image bytes")
 
             logging.warning("No JSON data found in the image")
     except Exception as e:
         logging.error(f"Error extracting JSON from image: {e}")
-        return None
 
 
 
 def load_chat_history(file):
     try:
         content = file.read().decode('utf-8')
         chat_data = json.loads(content)
@@ -276,11 +326,16 @@
         character_name = chat_data.get('character') or chat_data.get('character_name')
 
         if not history or not character_name:
             logging.error("Chat history or character name missing in the imported file.")
             return None, None
 
         return history, character_name
     except Exception as e:
         logging.error(f"Error loading chat history: {e}")
         return None, None
 
@@ -297,14 +352,25 @@ def process_chat_history(chat_history: List[Tuple[str, str]], char_name: str, us
     Returns:
         List[Tuple[str, str]]: The processed chat history.
     """
-
-
-
-
-
-
-
-
 
 def validate_character_book(book_data):
     """
|
293 |
json_start = img_str.find('{')
|
|
|
297 |
try:
|
298 |
json.loads(possible_json)
|
299 |
logging.debug("Found JSON data in image bytes")
|
300 |
+
log_counter("extract_json_from_image_bytes_success")
|
301 |
return possible_json
|
302 |
except json.JSONDecodeError:
|
303 |
logging.debug("No valid JSON found in image bytes")
|
304 |
+
log_counter("extract_json_from_image_invalid_json")
|
305 |
|
306 |
logging.warning("No JSON data found in the image")
|
307 |
+
log_counter("extract_json_from_image_no_json_found")
|
308 |
except Exception as e:
|
309 |
+
log_counter("extract_json_from_image_error", labels={"error": str(e)})
|
310 |
logging.error(f"Error extracting JSON from image: {e}")
|
|
|
311 |
|
312 |
+
extract_duration = time.time() - start_time
|
313 |
+
log_histogram("extract_json_from_image_duration", extract_duration)
|
314 |
+
return None
|
315 |
|
316 |
|
317 |
def load_chat_history(file):
|
318 |
+
log_counter("load_chat_history_attempt")
|
319 |
+
start_time = time.time()
|
320 |
try:
|
321 |
content = file.read().decode('utf-8')
|
322 |
chat_data = json.loads(content)
|
|
|
326 |
character_name = chat_data.get('character') or chat_data.get('character_name')
|
327 |
|
328 |
if not history or not character_name:
|
329 |
+
log_counter("load_chat_history_incomplete_data")
|
330 |
logging.error("Chat history or character name missing in the imported file.")
|
331 |
return None, None
|
332 |
|
333 |
+
load_duration = time.time() - start_time
|
334 |
+
log_histogram("load_chat_history_duration", load_duration)
|
335 |
+
log_counter("load_chat_history_success")
|
336 |
return history, character_name
|
337 |
except Exception as e:
|
338 |
+
log_counter("load_chat_history_error", labels={"error": str(e)})
|
339 |
logging.error(f"Error loading chat history: {e}")
|
340 |
return None, None
|
341 |
|
|
|
352 |
Returns:
|
353 |
List[Tuple[str, str]]: The processed chat history.
|
354 |
"""
|
355 |
+
log_counter("process_chat_history_attempt")
|
356 |
+
start_time = time.time()
|
357 |
+
try:
|
358 |
+
processed_history = []
|
359 |
+
for user_msg, char_msg in chat_history:
|
360 |
+
if user_msg:
|
361 |
+
user_msg = replace_placeholders(user_msg, char_name, user_name)
|
362 |
+
if char_msg:
|
363 |
+
char_msg = replace_placeholders(char_msg, char_name, user_name)
|
364 |
+
processed_history.append((user_msg, char_msg))
|
365 |
+
|
366 |
+
process_duration = time.time() - start_time
|
367 |
+
log_histogram("process_chat_history_duration", process_duration)
|
368 |
+
log_counter("process_chat_history_success", labels={"message_count": len(chat_history)})
|
369 |
+
return processed_history
|
370 |
+
except Exception as e:
|
371 |
+
log_counter("process_chat_history_error", labels={"error": str(e)})
|
372 |
+
logging.error(f"Error processing chat history: {e}")
|
373 |
+
raise
|
374 |
|
375 |
def validate_character_book(book_data):
|
376 |
"""
|
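Both `process_chat_history` and `load_chat_and_character` lean on a `replace_placeholders` helper that is defined elsewhere in the library and does not appear in this diff. For orientation, a minimal sketch of what it plausibly does, assuming the conventional `{{char}}`/`{{user}}` character-card tokens (the helper's real token set is not shown here):

# Hypothetical sketch only; the project's actual replace_placeholders lives outside this diff.
def replace_placeholders(text: str, char_name: str, user_name: str) -> str:
    """Swap character-card template tokens for the concrete character/user names."""
    for token, value in {
        '{{char}}': char_name,
        '{{user}}': user_name,
    }.items():
        text = text.replace(token, value)
    return text

With that shape, `process_chat_history` above simply maps the helper over both sides of every (user, bot) tuple.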
App_Function_Libraries/Chat.py
CHANGED
@@ -7,6 +7,7 @@ import logging
 import os
 import re
 import tempfile
+import time
 from datetime import datetime
 from pathlib import Path
 #
@@ -15,20 +16,20 @@ from pathlib import Path
 # Local Imports
 from App_Function_Libraries.DB.DB_Manager import get_conversation_name, save_chat_history_to_database
 from App_Function_Libraries.LLM_API_Calls import chat_with_openai, chat_with_anthropic, chat_with_cohere, \
-    chat_with_groq, chat_with_openrouter, chat_with_deepseek, chat_with_mistral, chat_with_huggingface
+    chat_with_groq, chat_with_openrouter, chat_with_deepseek, chat_with_mistral, chat_with_huggingface
 from App_Function_Libraries.LLM_API_Calls_Local import chat_with_aphrodite, chat_with_local_llm, chat_with_ollama, \
     chat_with_kobold, chat_with_llama, chat_with_oobabooga, chat_with_tabbyapi, chat_with_vllm, chat_with_custom_openai
 from App_Function_Libraries.DB.SQLite_DB import load_media_content
 from App_Function_Libraries.Utils.Utils import generate_unique_filename, load_and_log_configs
-
-
+from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
 #
 ####################################################################################################
 #
 # Functions:

-
 def chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message=None):
+    log_counter("chat_api_call_attempt", labels={"api_endpoint": api_endpoint})
+    start_time = time.time()
     if not api_key:
         api_key = None
     model = None
@@ -105,14 +106,21 @@ def chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_messag
         else:
             raise ValueError(f"Unsupported API endpoint: {api_endpoint}")

+        call_duration = time.time() - start_time
+        log_histogram("chat_api_call_duration", call_duration, labels={"api_endpoint": api_endpoint})
+        log_counter("chat_api_call_success", labels={"api_endpoint": api_endpoint})
         return response

     except Exception as e:
+        log_counter("chat_api_call_error", labels={"api_endpoint": api_endpoint, "error": str(e)})
         logging.error(f"Error in chat function: {str(e)}")
         return f"An error occurred: {str(e)}"

+
 def chat(message, history, media_content, selected_parts, api_endpoint, api_key, prompt, temperature,
          system_message=None):
+    log_counter("chat_attempt", labels={"api_endpoint": api_endpoint})
+    start_time = time.time()
     try:
         logging.info(f"Debug - Chat Function - Message: {message}")
         logging.info(f"Debug - Chat Function - Media Content: {media_content}")
@@ -151,14 +159,19 @@ def chat(message, history, media_content, selected_parts, api_endpoint, api_key,
         # Use the existing API request code based on the selected endpoint
         response = chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message)

+        chat_duration = time.time() - start_time
+        log_histogram("chat_duration", chat_duration, labels={"api_endpoint": api_endpoint})
+        log_counter("chat_success", labels={"api_endpoint": api_endpoint})
         return response
     except Exception as e:
+        log_counter("chat_error", labels={"api_endpoint": api_endpoint, "error": str(e)})
         logging.error(f"Error in chat function: {str(e)}")
         return f"An error occurred: {str(e)}"


-
 def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, media_name=None):
+    log_counter("save_chat_history_to_db_attempt")
+    start_time = time.time()
     logging.info(f"Attempting to save chat history. Media content type: {type(media_content)}")
     try:
         # Extract the media_id and media_name from the media_content
@@ -206,14 +219,20 @@ def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, med

         new_conversation_id = save_chat_history_to_database(chatbot, conversation_id, media_id, media_name,
                                                             conversation_name)
+        save_duration = time.time() - start_time
+        log_histogram("save_chat_history_to_db_duration", save_duration)
+        log_counter("save_chat_history_to_db_success")
         return new_conversation_id, f"Chat history saved successfully as {conversation_name}!"
     except Exception as e:
+        log_counter("save_chat_history_to_db_error", labels={"error": str(e)})
         error_message = f"Failed to save chat history: {str(e)}"
         logging.error(error_message, exc_info=True)
         return conversation_id, error_message


 def save_chat_history(history, conversation_id, media_content):
+    log_counter("save_chat_history_attempt")
+    start_time = time.time()
     try:
         content, conversation_name = generate_chat_history_content(history, conversation_id, media_content)

@@ -233,8 +252,12 @@ def save_chat_history(history, conversation_id, media_content):
         # Rename the temporary file to the unique filename
         os.rename(temp_file_path, final_path)

+        save_duration = time.time() - start_time
+        log_histogram("save_chat_history_duration", save_duration)
+        log_counter("save_chat_history_success")
         return final_path
     except Exception as e:
+        log_counter("save_chat_history_error", labels={"error": str(e)})
         logging.error(f"Error saving chat history: {str(e)}")
         return None

@@ -286,6 +309,8 @@ def extract_media_name(media_content):


 def update_chat_content(selected_item, use_content, use_summary, use_prompt, item_mapping):
+    log_counter("update_chat_content_attempt")
+    start_time = time.time()
     logging.debug(f"Debug - Update Chat Content - Selected Item: {selected_item}\n")
     logging.debug(f"Debug - Update Chat Content - Use Content: {use_content}\n\n\n\n")
     logging.debug(f"Debug - Update Chat Content - Use Summary: {use_summary}\n\n")
@@ -312,17 +337,21 @@ def update_chat_content(selected_item, use_content, use_summary, use_prompt, ite
         print(f"Debug - Update Chat Content - Content(first 500 char): {str(content)[:500]}\n\n\n\n")

         print(f"Debug - Update Chat Content - Selected Parts: {selected_parts}")
+        update_duration = time.time() - start_time
+        log_histogram("update_chat_content_duration", update_duration)
+        log_counter("update_chat_content_success")
         return content, selected_parts
     else:
+        log_counter("update_chat_content_error", labels={"error": str("No item selected or item not in mapping")})
         print(f"Debug - Update Chat Content - No item selected or item not in mapping")
         return {}, []

 #
 # End of Chat functions
-
+#######################################################################################################################


-
+#######################################################################################################################
 #
 # Character Card Functions
@@ -330,6 +359,8 @@ CHARACTERS_FILE = Path('.', 'Helper_Scripts', 'Character_Cards', 'Characters.jso


 def save_character(character_data):
+    log_counter("save_character_attempt")
+    start_time = time.time()
     characters_file = os.path.join(os.path.dirname(__file__), '..', 'Helper_Scripts', 'Character_Cards', 'Characters.json')
     characters_dir = os.path.dirname(characters_file)

@@ -357,28 +388,52 @@ def save_character(character_data):
         with open(characters_file, 'w') as f:
             json.dump(characters, f, indent=2)

+        save_duration = time.time() - start_time
+        log_histogram("save_character_duration", save_duration)
+        log_counter("save_character_success")
         logging.info(f"Character '{char_name}' saved successfully.")
     except Exception as e:
+        log_counter("save_character_error", labels={"error": str(e)})
         logging.error(f"Error saving character: {str(e)}")


-
 def load_characters():
-
-
-
-
-
-
-
-
+    log_counter("load_characters_attempt")
+    start_time = time.time()
+    try:
+        characters_file = os.path.join(os.path.dirname(__file__), '..', 'Helper_Scripts', 'Character_Cards', 'Characters.json')
+        if os.path.exists(characters_file):
+            with open(characters_file, 'r') as f:
+                characters = json.load(f)
+            logging.debug(f"Loaded {len(characters)} characters from {characters_file}")
+            load_duration = time.time() - start_time
+            log_histogram("load_characters_duration", load_duration)
+            log_counter("load_characters_success", labels={"character_count": len(characters)})
+            return characters
+        else:
+            logging.warning(f"Characters file not found: {characters_file}")
+            return {}
+    except Exception as e:
+        log_counter("load_characters_error", labels={"error": str(e)})
+        return {}


-def get_character_names():
-    characters = load_characters()
-    return list(characters.keys())
+def get_character_names():
+    log_counter("get_character_names_attempt")
+    start_time = time.time()
+    try:
+        characters = load_characters()
+        names = list(characters.keys())
+        get_names_duration = time.time() - start_time
+        log_histogram("get_character_names_duration", get_names_duration)
+        log_counter("get_character_names_success", labels={"name_count": len(names)})
+        return names
+    except Exception as e:
+        log_counter("get_character_names_error", labels={"error": str(e)})
+        logging.error(f"Error getting character names: {str(e)}")
+        return []

 #
 # End of Chat.py
-##########################################################################################################################
+##########################################################################################################################
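The new `log_counter`/`log_histogram` calls threaded through both files come from `App_Function_Libraries.Metrics.metrics_logger`, a module this commit adds but this view does not display. A minimal sketch of helpers with matching signatures, purely as an assumption about their shape (the real module may emit to Prometheus, StatsD, or similar):

# Hypothetical stand-in for App_Function_Libraries/Metrics/metrics_logger.py (not shown in this view).
import logging
from collections import defaultdict

_counters = defaultdict(int)      # (name, labels) -> running count
_histograms = defaultdict(list)   # (name, labels) -> observed values

def _key(name, labels):
    return (name, tuple(sorted((labels or {}).items())))

def log_counter(name, labels=None, value=1):
    """Increment a named counter, optionally tagged with a labels dict."""
    _counters[_key(name, labels)] += value
    logging.debug("metric counter %s labels=%s", name, labels)

def log_histogram(name, value, labels=None):
    """Record one observation (here, durations in seconds) under a named histogram."""
    _histograms[_key(name, labels)].append(value)
    logging.debug("metric histogram %s=%0.4f labels=%s", name, value, labels)

Whatever the real backend, the calling convention visible in the diff is consistent: a `*_attempt` counter on entry, a `*_duration` histogram plus `*_success` counter on the happy path, and a `*_error` counter carrying `str(e)` as a label on failure.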
App_Function_Libraries/DB/Character_Chat_DB.py
CHANGED
@@ -1,701 +1,701 @@
|
|
1 |
-
# character_chat_db.py
|
2 |
-
# Database functions for managing character cards and chat histories.
|
3 |
-
# #
|
4 |
-
# Imports
|
5 |
-
import configparser
|
6 |
-
import sqlite3
|
7 |
-
import
|
8 |
-
import
|
9 |
-
import
|
10 |
-
import
|
11 |
-
|
12 |
-
|
13 |
-
from
|
14 |
-
|
15 |
-
#
|
16 |
-
#######################################################################################################################
|
17 |
-
#
|
18 |
-
#
|
19 |
-
|
20 |
-
def ensure_database_directory():
|
21 |
-
os.makedirs(get_database_dir(), exist_ok=True)
|
22 |
-
|
23 |
-
ensure_database_directory()
|
24 |
-
|
25 |
-
|
26 |
-
# Construct the path to the config file
|
27 |
-
config_path = get_project_relative_path('Config_Files/config.txt')
|
28 |
-
|
29 |
-
# Read the config file
|
30 |
-
config = configparser.ConfigParser()
|
31 |
-
config.read(config_path)
|
32 |
-
|
33 |
-
# Get the chat db path from the config, or use the default if not specified
|
34 |
-
chat_DB_PATH = config.get('Database', 'chatDB_path', fallback=get_database_path('chatDB.db'))
|
35 |
-
print(f"Chat Database path: {chat_DB_PATH}")
|
36 |
-
|
37 |
-
########################################################################################################
|
38 |
-
#
|
39 |
-
# Functions
|
40 |
-
|
41 |
-
# FIXME - Setup properly and test/add documentation for its existence...
|
42 |
-
def initialize_database():
|
43 |
-
"""Initialize the SQLite database with required tables and FTS5 virtual tables."""
|
44 |
-
conn = None
|
45 |
-
try:
|
46 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
47 |
-
cursor = conn.cursor()
|
48 |
-
|
49 |
-
# Enable foreign key constraints
|
50 |
-
cursor.execute("PRAGMA foreign_keys = ON;")
|
51 |
-
|
52 |
-
# Create CharacterCards table with V2 fields
|
53 |
-
cursor.execute("""
|
54 |
-
CREATE TABLE IF NOT EXISTS CharacterCards (
|
55 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
56 |
-
name TEXT UNIQUE NOT NULL,
|
57 |
-
description TEXT,
|
58 |
-
personality TEXT,
|
59 |
-
scenario TEXT,
|
60 |
-
image BLOB,
|
61 |
-
post_history_instructions TEXT,
|
62 |
-
first_mes TEXT,
|
63 |
-
mes_example TEXT,
|
64 |
-
creator_notes TEXT,
|
65 |
-
system_prompt TEXT,
|
66 |
-
alternate_greetings TEXT,
|
67 |
-
tags TEXT,
|
68 |
-
creator TEXT,
|
69 |
-
character_version TEXT,
|
70 |
-
extensions TEXT,
|
71 |
-
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
72 |
-
);
|
73 |
-
""")
|
74 |
-
|
75 |
-
# Create CharacterChats table
|
76 |
-
cursor.execute("""
|
77 |
-
CREATE TABLE IF NOT EXISTS CharacterChats (
|
78 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
79 |
-
character_id INTEGER NOT NULL,
|
80 |
-
conversation_name TEXT,
|
81 |
-
chat_history TEXT,
|
82 |
-
is_snapshot BOOLEAN DEFAULT FALSE,
|
83 |
-
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
84 |
-
FOREIGN KEY (character_id) REFERENCES CharacterCards(id) ON DELETE CASCADE
|
85 |
-
);
|
86 |
-
""")
|
87 |
-
|
88 |
-
# Create FTS5 virtual table for CharacterChats
|
89 |
-
cursor.execute("""
|
90 |
-
CREATE VIRTUAL TABLE IF NOT EXISTS CharacterChats_fts USING fts5(
|
91 |
-
conversation_name,
|
92 |
-
chat_history,
|
93 |
-
content='CharacterChats',
|
94 |
-
content_rowid='id'
|
95 |
-
);
|
96 |
-
""")
|
97 |
-
|
98 |
-
# Create triggers to keep FTS5 table in sync with CharacterChats
|
99 |
-
cursor.executescript("""
|
100 |
-
CREATE TRIGGER IF NOT EXISTS CharacterChats_ai AFTER INSERT ON CharacterChats BEGIN
|
101 |
-
INSERT INTO CharacterChats_fts(rowid, conversation_name, chat_history)
|
102 |
-
VALUES (new.id, new.conversation_name, new.chat_history);
|
103 |
-
END;
|
104 |
-
|
105 |
-
CREATE TRIGGER IF NOT EXISTS CharacterChats_ad AFTER DELETE ON CharacterChats BEGIN
|
106 |
-
DELETE FROM CharacterChats_fts WHERE rowid = old.id;
|
107 |
-
END;
|
108 |
-
|
109 |
-
CREATE TRIGGER IF NOT EXISTS CharacterChats_au AFTER UPDATE ON CharacterChats BEGIN
|
110 |
-
UPDATE CharacterChats_fts SET conversation_name = new.conversation_name, chat_history = new.chat_history
|
111 |
-
WHERE rowid = new.id;
|
112 |
-
END;
|
113 |
-
""")
|
114 |
-
|
115 |
-
# Create ChatKeywords table
|
116 |
-
cursor.execute("""
|
117 |
-
CREATE TABLE IF NOT EXISTS ChatKeywords (
|
118 |
-
chat_id INTEGER NOT NULL,
|
119 |
-
keyword TEXT NOT NULL,
|
120 |
-
FOREIGN KEY (chat_id) REFERENCES CharacterChats(id) ON DELETE CASCADE
|
121 |
-
);
|
122 |
-
""")
|
123 |
-
|
124 |
-
# Create indexes for faster searches
|
125 |
-
cursor.execute("""
|
126 |
-
CREATE INDEX IF NOT EXISTS idx_chatkeywords_keyword ON ChatKeywords(keyword);
|
127 |
-
""")
|
128 |
-
cursor.execute("""
|
129 |
-
CREATE INDEX IF NOT EXISTS idx_chatkeywords_chat_id ON ChatKeywords(chat_id);
|
130 |
-
""")
|
131 |
-
|
132 |
-
conn.commit()
|
133 |
-
logging.info("Database initialized successfully.")
|
134 |
-
except sqlite3.Error as e:
|
135 |
-
logging.error(f"SQLite error occurred during database initialization: {e}")
|
136 |
-
if conn:
|
137 |
-
conn.rollback()
|
138 |
-
raise
|
139 |
-
except Exception as e:
|
140 |
-
logging.error(f"Unexpected error occurred during database initialization: {e}")
|
141 |
-
if conn:
|
142 |
-
conn.rollback()
|
143 |
-
raise
|
144 |
-
finally:
|
145 |
-
if conn:
|
146 |
-
conn.close()
|
147 |
-
|
148 |
-
# Call initialize_database() at the start of your application
|
149 |
-
def setup_chat_database():
|
150 |
-
try:
|
151 |
-
initialize_database()
|
152 |
-
except Exception as e:
|
153 |
-
logging.critical(f"Failed to initialize database: {e}")
|
154 |
-
sys.exit(1)
|
155 |
-
|
156 |
-
setup_chat_database()
|
157 |
-
|
158 |
-
########################################################################################################
|
159 |
-
#
|
160 |
-
# Character Card handling
|
161 |
-
|
162 |
-
def parse_character_card(card_data: Dict[str, Any]) -> Dict[str, Any]:
|
163 |
-
"""Parse and validate a character card according to V2 specification."""
|
164 |
-
v2_data = {
|
165 |
-
'name': card_data.get('name', ''),
|
166 |
-
'description': card_data.get('description', ''),
|
167 |
-
'personality': card_data.get('personality', ''),
|
168 |
-
'scenario': card_data.get('scenario', ''),
|
169 |
-
'first_mes': card_data.get('first_mes', ''),
|
170 |
-
'mes_example': card_data.get('mes_example', ''),
|
171 |
-
'creator_notes': card_data.get('creator_notes', ''),
|
172 |
-
'system_prompt': card_data.get('system_prompt', ''),
|
173 |
-
'post_history_instructions': card_data.get('post_history_instructions', ''),
|
174 |
-
'alternate_greetings': json.dumps(card_data.get('alternate_greetings', [])),
|
175 |
-
'tags': json.dumps(card_data.get('tags', [])),
|
176 |
-
'creator': card_data.get('creator', ''),
|
177 |
-
'character_version': card_data.get('character_version', ''),
|
178 |
-
'extensions': json.dumps(card_data.get('extensions', {}))
|
179 |
-
}
|
180 |
-
|
181 |
-
# Handle 'image' separately as it might be binary data
|
182 |
-
if 'image' in card_data:
|
183 |
-
v2_data['image'] = card_data['image']
|
184 |
-
|
185 |
-
return v2_data
|
186 |
-
|
187 |
-
|
188 |
-
def add_character_card(card_data: Dict[str, Any]) -> Optional[int]:
|
189 |
-
"""Add or update a character card in the database."""
|
190 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
191 |
-
cursor = conn.cursor()
|
192 |
-
try:
|
193 |
-
parsed_card = parse_character_card(card_data)
|
194 |
-
|
195 |
-
# Check if character already exists
|
196 |
-
cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (parsed_card['name'],))
|
197 |
-
row = cursor.fetchone()
|
198 |
-
|
199 |
-
if row:
|
200 |
-
# Update existing character
|
201 |
-
character_id = row[0]
|
202 |
-
update_query = """
|
203 |
-
UPDATE CharacterCards
|
204 |
-
SET description = ?, personality = ?, scenario = ?, image = ?,
|
205 |
-
post_history_instructions = ?, first_mes = ?, mes_example = ?,
|
206 |
-
creator_notes = ?, system_prompt = ?, alternate_greetings = ?,
|
207 |
-
tags = ?, creator = ?, character_version = ?, extensions = ?
|
208 |
-
WHERE id = ?
|
209 |
-
"""
|
210 |
-
cursor.execute(update_query, (
|
211 |
-
parsed_card['description'], parsed_card['personality'], parsed_card['scenario'],
|
212 |
-
parsed_card['image'], parsed_card['post_history_instructions'], parsed_card['first_mes'],
|
213 |
-
parsed_card['mes_example'], parsed_card['creator_notes'], parsed_card['system_prompt'],
|
214 |
-
parsed_card['alternate_greetings'], parsed_card['tags'], parsed_card['creator'],
|
215 |
-
parsed_card['character_version'], parsed_card['extensions'], character_id
|
216 |
-
))
|
217 |
-
else:
|
218 |
-
# Insert new character
|
219 |
-
insert_query = """
|
220 |
-
INSERT INTO CharacterCards (name, description, personality, scenario, image,
|
221 |
-
post_history_instructions, first_mes, mes_example, creator_notes, system_prompt,
|
222 |
-
alternate_greetings, tags, creator, character_version, extensions)
|
223 |
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
224 |
-
"""
|
225 |
-
cursor.execute(insert_query, (
|
226 |
-
parsed_card['name'], parsed_card['description'], parsed_card['personality'],
|
227 |
-
parsed_card['scenario'], parsed_card['image'], parsed_card['post_history_instructions'],
|
228 |
-
parsed_card['first_mes'], parsed_card['mes_example'], parsed_card['creator_notes'],
|
229 |
-
parsed_card['system_prompt'], parsed_card['alternate_greetings'], parsed_card['tags'],
|
230 |
-
parsed_card['creator'], parsed_card['character_version'], parsed_card['extensions']
|
231 |
-
))
|
232 |
-
character_id = cursor.lastrowid
|
233 |
-
|
234 |
-
conn.commit()
|
235 |
-
return character_id
|
236 |
-
except sqlite3.IntegrityError as e:
|
237 |
-
logging.error(f"Error adding character card: {e}")
|
238 |
-
return None
|
239 |
-
except Exception as e:
|
240 |
-
logging.error(f"Unexpected error adding character card: {e}")
|
241 |
-
return None
|
242 |
-
finally:
|
243 |
-
conn.close()
|
244 |
-
|
245 |
-
# def add_character_card(card_data: Dict) -> Optional[int]:
|
246 |
-
# """Add or update a character card in the database.
|
247 |
-
#
|
248 |
-
# Returns the ID of the inserted character or None if failed.
|
249 |
-
# """
|
250 |
-
# conn = sqlite3.connect(chat_DB_PATH)
|
251 |
-
# cursor = conn.cursor()
|
252 |
-
# try:
|
253 |
-
# # Ensure all required fields are present
|
254 |
-
# required_fields = ['name', 'description', 'personality', 'scenario', 'image', 'post_history_instructions', 'first_message']
|
255 |
-
# for field in required_fields:
|
256 |
-
# if field not in card_data:
|
257 |
-
# card_data[field] = '' # Assign empty string if field is missing
|
258 |
-
#
|
259 |
-
# # Check if character already exists
|
260 |
-
# cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (card_data['name'],))
|
261 |
-
# row = cursor.fetchone()
|
262 |
-
#
|
263 |
-
# if row:
|
264 |
-
# # Update existing character
|
265 |
-
# character_id = row[0]
|
266 |
-
# cursor.execute("""
|
267 |
-
# UPDATE CharacterCards
|
268 |
-
# SET description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_message = ?
|
269 |
-
# WHERE id = ?
|
270 |
-
# """, (
|
271 |
-
# card_data['description'],
|
272 |
-
# card_data['personality'],
|
273 |
-
# card_data['scenario'],
|
274 |
-
# card_data['image'],
|
275 |
-
# card_data['post_history_instructions'],
|
276 |
-
# card_data['first_message'],
|
277 |
-
# character_id
|
278 |
-
# ))
|
279 |
-
# else:
|
280 |
-
# # Insert new character
|
281 |
-
# cursor.execute("""
|
282 |
-
# INSERT INTO CharacterCards (name, description, personality, scenario, image, post_history_instructions, first_message)
|
283 |
-
# VALUES (?, ?, ?, ?, ?, ?, ?)
|
284 |
-
# """, (
|
285 |
-
# card_data['name'],
|
286 |
-
# card_data['description'],
|
287 |
-
# card_data['personality'],
|
288 |
-
# card_data['scenario'],
|
289 |
-
# card_data['image'],
|
290 |
-
# card_data['post_history_instructions'],
|
291 |
-
# card_data['first_message']
|
292 |
-
# ))
|
293 |
-
# character_id = cursor.lastrowid
|
294 |
-
#
|
295 |
-
# conn.commit()
|
296 |
-
# return cursor.lastrowid
|
297 |
-
# except sqlite3.IntegrityError as e:
|
298 |
-
# logging.error(f"Error adding character card: {e}")
|
299 |
-
# return None
|
300 |
-
# except Exception as e:
|
301 |
-
# logging.error(f"Unexpected error adding character card: {e}")
|
302 |
-
# return None
|
303 |
-
# finally:
|
304 |
-
# conn.close()
|
305 |
-
|
306 |
-
|
307 |
-
def get_character_cards() -> List[Dict]:
|
308 |
-
"""Retrieve all character cards from the database."""
|
309 |
-
logging.debug(f"Fetching characters from DB: {chat_DB_PATH}")
|
310 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
311 |
-
cursor = conn.cursor()
|
312 |
-
cursor.execute("SELECT * FROM CharacterCards")
|
313 |
-
rows = cursor.fetchall()
|
314 |
-
columns = [description[0] for description in cursor.description]
|
315 |
-
conn.close()
|
316 |
-
characters = [dict(zip(columns, row)) for row in rows]
|
317 |
-
#logging.debug(f"Characters fetched from DB: {characters}")
|
318 |
-
return characters
|
319 |
-
|
320 |
-
|
321 |
-
def get_character_card_by_id(character_id: Union[int, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
322 |
-
"""
|
323 |
-
Retrieve a single character card by its ID.
|
324 |
-
|
325 |
-
Args:
|
326 |
-
character_id: Can be either an integer ID or a dictionary containing character data.
|
327 |
-
|
328 |
-
Returns:
|
329 |
-
A dictionary containing the character card data, or None if not found.
|
330 |
-
"""
|
331 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
332 |
-
cursor = conn.cursor()
|
333 |
-
try:
|
334 |
-
if isinstance(character_id, dict):
|
335 |
-
# If a dictionary is passed, assume it's already a character card
|
336 |
-
return character_id
|
337 |
-
elif isinstance(character_id, int):
|
338 |
-
# If an integer is passed, fetch the character from the database
|
339 |
-
cursor.execute("SELECT * FROM CharacterCards WHERE id = ?", (character_id,))
|
340 |
-
row = cursor.fetchone()
|
341 |
-
if row:
|
342 |
-
columns = [description[0] for description in cursor.description]
|
343 |
-
return dict(zip(columns, row))
|
344 |
-
else:
|
345 |
-
logging.warning(f"Invalid type for character_id: {type(character_id)}")
|
346 |
-
return None
|
347 |
-
except Exception as e:
|
348 |
-
logging.error(f"Error in get_character_card_by_id: {e}")
|
349 |
-
return None
|
350 |
-
finally:
|
351 |
-
conn.close()
|
352 |
-
|
353 |
-
|
354 |
-
def update_character_card(character_id: int, card_data: Dict) -> bool:
|
355 |
-
"""Update an existing character card."""
|
356 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
357 |
-
cursor = conn.cursor()
|
358 |
-
try:
|
359 |
-
cursor.execute("""
|
360 |
-
UPDATE CharacterCards
|
361 |
-
SET name = ?, description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_message = ?
|
362 |
-
WHERE id = ?
|
363 |
-
""", (
|
364 |
-
card_data.get('name'),
|
365 |
-
card_data.get('description'),
|
366 |
-
card_data.get('personality'),
|
367 |
-
card_data.get('scenario'),
|
368 |
-
card_data.get('image'),
|
369 |
-
card_data.get('post_history_instructions', ''),
|
370 |
-
card_data.get('first_message', "Hello! I'm ready to chat."),
|
371 |
-
character_id
|
372 |
-
))
|
373 |
-
conn.commit()
|
374 |
-
return cursor.rowcount > 0
|
375 |
-
except sqlite3.IntegrityError as e:
|
376 |
-
logging.error(f"Error updating character card: {e}")
|
377 |
-
return False
|
378 |
-
finally:
|
379 |
-
conn.close()
|
380 |
-
|
381 |
-
|
382 |
-
def delete_character_card(character_id: int) -> bool:
|
383 |
-
"""Delete a character card and its associated chats."""
|
384 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
385 |
-
cursor = conn.cursor()
|
386 |
-
try:
|
387 |
-
# Delete associated chats first due to foreign key constraint
|
388 |
-
cursor.execute("DELETE FROM CharacterChats WHERE character_id = ?", (character_id,))
|
389 |
-
cursor.execute("DELETE FROM CharacterCards WHERE id = ?", (character_id,))
|
390 |
-
conn.commit()
|
391 |
-
return cursor.rowcount > 0
|
392 |
-
except sqlite3.Error as e:
|
393 |
-
logging.error(f"Error deleting character card: {e}")
|
394 |
-
return False
|
395 |
-
finally:
|
396 |
-
conn.close()
|
397 |
-
|
398 |
-
|
399 |
-
def add_character_chat(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]], keywords: Optional[List[str]] = None, is_snapshot: bool = False) -> Optional[int]:
|
400 |
-
"""
|
401 |
-
Add a new chat history for a character, optionally associating keywords.
|
402 |
-
|
403 |
-
Args:
|
404 |
-
character_id (int): The ID of the character.
|
405 |
-
conversation_name (str): Name of the conversation.
|
406 |
-
chat_history (List[Tuple[str, str]]): List of (user, bot) message tuples.
|
407 |
-
keywords (Optional[List[str]]): List of keywords to associate with this chat.
|
408 |
-
is_snapshot (bool, optional): Whether this chat is a snapshot.
|
409 |
-
|
410 |
-
Returns:
|
411 |
-
Optional[int]: The ID of the inserted chat or None if failed.
|
412 |
-
"""
|
413 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
414 |
-
cursor = conn.cursor()
|
415 |
-
try:
|
416 |
-
chat_history_json = json.dumps(chat_history)
|
417 |
-
cursor.execute("""
|
418 |
-
INSERT INTO CharacterChats (character_id, conversation_name, chat_history, is_snapshot)
|
419 |
-
VALUES (?, ?, ?, ?)
|
420 |
-
""", (
|
421 |
-
character_id,
|
422 |
-
conversation_name,
|
423 |
-
chat_history_json,
|
424 |
-
is_snapshot
|
425 |
-
))
|
426 |
-
chat_id = cursor.lastrowid
|
427 |
-
|
428 |
-
if keywords:
|
429 |
-
# Insert keywords into ChatKeywords table
|
430 |
-
keyword_records = [(chat_id, keyword.strip().lower()) for keyword in keywords]
|
431 |
-
cursor.executemany("""
|
432 |
-
INSERT INTO ChatKeywords (chat_id, keyword)
|
433 |
-
VALUES (?, ?)
|
434 |
-
""", keyword_records)
|
435 |
-
|
436 |
-
conn.commit()
|
437 |
-
return chat_id
|
438 |
-
except sqlite3.Error as e:
|
439 |
-
logging.error(f"Error adding character chat: {e}")
|
440 |
-
return None
|
441 |
-
finally:
|
442 |
-
conn.close()
|
443 |
-
|
444 |
-
|
445 |
-
def get_character_chats(character_id: Optional[int] = None) -> List[Dict]:
|
446 |
-
"""Retrieve all chats, or chats for a specific character if character_id is provided."""
|
447 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
448 |
-
cursor = conn.cursor()
|
449 |
-
if character_id is not None:
|
450 |
-
cursor.execute("SELECT * FROM CharacterChats WHERE character_id = ?", (character_id,))
|
451 |
-
else:
|
452 |
-
cursor.execute("SELECT * FROM CharacterChats")
|
453 |
-
rows = cursor.fetchall()
|
454 |
-
columns = [description[0] for description in cursor.description]
|
455 |
-
conn.close()
|
456 |
-
return [dict(zip(columns, row)) for row in rows]
|
457 |
-
|
458 |
-
|
459 |
-
def get_character_chat_by_id(chat_id: int) -> Optional[Dict]:
|
460 |
-
"""Retrieve a single chat by its ID."""
|
461 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
462 |
-
cursor = conn.cursor()
|
463 |
-
cursor.execute("SELECT * FROM CharacterChats WHERE id = ?", (chat_id,))
|
464 |
-
row = cursor.fetchone()
|
465 |
-
conn.close()
|
466 |
-
if row:
|
467 |
-
columns = [description[0] for description in cursor.description]
|
468 |
-
chat = dict(zip(columns, row))
|
469 |
-
chat['chat_history'] = json.loads(chat['chat_history'])
|
470 |
-
return chat
|
471 |
-
return None
|
472 |
-
|
473 |
-
|
474 |
-
def search_character_chats(query: str, character_id: Optional[int] = None) -> Tuple[List[Dict], str]:
|
475 |
-
"""
|
476 |
-
Search for character chats using FTS5, optionally filtered by character_id.
|
477 |
-
|
478 |
-
Args:
|
479 |
-
query (str): The search query.
|
480 |
-
character_id (Optional[int]): The ID of the character to filter chats by.
|
481 |
-
|
482 |
-
Returns:
|
483 |
-
Tuple[List[Dict], str]: A list of matching chats and a status message.
|
484 |
-
"""
|
485 |
-
if not query.strip():
|
486 |
-
return [], "Please enter a search query."
|
487 |
-
|
488 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
489 |
-
cursor = conn.cursor()
|
490 |
-
try:
|
491 |
-
if character_id is not None:
|
492 |
-
# Search with character_id filter
|
493 |
-
cursor.execute("""
|
494 |
-
SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
|
495 |
-
FROM CharacterChats_fts
|
496 |
-
JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
|
497 |
-
WHERE CharacterChats_fts MATCH ? AND CharacterChats.character_id = ?
|
498 |
-
ORDER BY rank
|
499 |
-
""", (query, character_id))
|
500 |
-
else:
|
501 |
-
# Search without character_id filter
|
502 |
-
cursor.execute("""
|
503 |
-
SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
|
504 |
-
FROM CharacterChats_fts
|
505 |
-
JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
|
506 |
-
WHERE CharacterChats_fts MATCH ?
|
507 |
-
ORDER BY rank
|
508 |
-
""", (query,))
|
509 |
-
|
510 |
-
rows = cursor.fetchall()
|
511 |
-
columns = [description[0] for description in cursor.description]
|
512 |
-
results = [dict(zip(columns, row)) for row in rows]
|
513 |
-
|
514 |
-
if character_id is not None:
|
515 |
-
status_message = f"Found {len(results)} chat(s) matching '{query}' for the selected character."
|
516 |
-
else:
|
517 |
-
status_message = f"Found {len(results)} chat(s) matching '{query}' across all characters."
|
518 |
-
|
519 |
-
return results, status_message
|
520 |
-
except Exception as e:
|
521 |
-
logging.error(f"Error searching chats with FTS5: {e}")
|
522 |
-
return [], f"Error occurred during search: {e}"
|
523 |
-
finally:
|
524 |
-
conn.close()
|
525 |
-
|
526 |
-
def update_character_chat(chat_id: int, chat_history: List[Tuple[str, str]]) -> bool:
|
527 |
-
"""Update an existing chat history."""
|
528 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
529 |
-
cursor = conn.cursor()
|
530 |
-
try:
|
531 |
-
chat_history_json = json.dumps(chat_history)
|
532 |
-
cursor.execute("""
|
533 |
-
UPDATE CharacterChats
|
534 |
-
SET chat_history = ?
|
535 |
-
WHERE id = ?
|
536 |
-
""", (
|
537 |
-
chat_history_json,
|
538 |
-
chat_id
|
539 |
-
))
|
540 |
-
conn.commit()
|
541 |
-
return cursor.rowcount > 0
|
542 |
-
except sqlite3.Error as e:
|
543 |
-
logging.error(f"Error updating character chat: {e}")
|
544 |
-
return False
|
545 |
-
finally:
|
546 |
-
conn.close()
|
547 |
-
|
548 |
-
|
549 |
-
def delete_character_chat(chat_id: int) -> bool:
|
550 |
-
"""Delete a specific chat."""
|
551 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
552 |
-
cursor = conn.cursor()
|
553 |
-
try:
|
554 |
-
cursor.execute("DELETE FROM CharacterChats WHERE id = ?", (chat_id,))
|
555 |
-
conn.commit()
|
556 |
-
return cursor.rowcount > 0
|
557 |
-
except sqlite3.Error as e:
|
558 |
-
logging.error(f"Error deleting character chat: {e}")
|
559 |
-
return False
|
560 |
-
finally:
|
561 |
-
conn.close()
|
562 |
-
|
563 |
-
def fetch_keywords_for_chats(keywords: List[str]) -> List[int]:
|
564 |
-
"""
|
565 |
-
Fetch chat IDs associated with any of the specified keywords.
|
566 |
-
|
567 |
-
Args:
|
568 |
-
keywords (List[str]): List of keywords to search for.
|
569 |
-
|
570 |
-
Returns:
|
571 |
-
List[int]: List of chat IDs associated with the keywords.
|
572 |
-
"""
|
573 |
-
if not keywords:
|
574 |
-
return []
|
575 |
-
|
576 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
577 |
-
cursor = conn.cursor()
|
578 |
-
try:
|
579 |
-
# Construct the WHERE clause to search for each keyword
|
580 |
-
keyword_clauses = " OR ".join(["keyword = ?"] * len(keywords))
|
581 |
-
sql_query = f"SELECT DISTINCT chat_id FROM ChatKeywords WHERE {keyword_clauses}"
|
582 |
-
cursor.execute(sql_query, keywords)
|
583 |
-
rows = cursor.fetchall()
|
584 |
-
chat_ids = [row[0] for row in rows]
|
585 |
-
return chat_ids
|
586 |
-
except Exception as e:
|
587 |
-
logging.error(f"Error in fetch_keywords_for_chats: {e}")
|
588 |
-
return []
|
589 |
-
finally:
|
590 |
-
conn.close()
|
591 |
-
|
592 |
-
def save_chat_history_to_character_db(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]]) -> Optional[int]:
|
593 |
-
"""Save chat history to the CharacterChats table.
|
594 |
-
|
595 |
-
Returns the ID of the inserted chat or None if failed.
|
596 |
-
"""
|
597 |
-
return add_character_chat(character_id, conversation_name, chat_history)
|
598 |
-
|
599 |
-
def migrate_chat_to_media_db():
|
600 |
-
pass
|
601 |
-
|
602 |
-
|
603 |
-
def search_db(query: str, fields: List[str], where_clause: str = "", page: int = 1, results_per_page: int = 5) -> List[Dict[str, Any]]:
|
604 |
-
"""
|
605 |
-
Perform a full-text search on specified fields with optional filtering and pagination.
|
606 |
-
|
607 |
-
Args:
|
608 |
-
query (str): The search query.
|
609 |
-
fields (List[str]): List of fields to search in.
|
610 |
-
where_clause (str, optional): Additional SQL WHERE clause to filter results.
|
611 |
-
page (int, optional): Page number for pagination.
|
612 |
-
results_per_page (int, optional): Number of results per page.
|
613 |
-
|
614 |
-
Returns:
|
615 |
-
List[Dict[str, Any]]: List of matching chat records with content and metadata.
|
616 |
-
"""
|
617 |
-
if not query.strip():
|
618 |
-
return []
|
619 |
-
|
620 |
-
conn = sqlite3.connect(chat_DB_PATH)
|
621 |
-
cursor = conn.cursor()
|
622 |
-
try:
|
623 |
-
# Construct the MATCH query for FTS5
|
624 |
-
match_query = " AND ".join(fields) + f" MATCH ?"
|
625 |
-
# Adjust the query with the fields
|
626 |
-
fts_query = f"""
|
627 |
-
SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
|
628 |
-
FROM CharacterChats_fts
|
629 |
-
JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
|
630 |
-
WHERE {match_query}
|
631 |
-
"""
|
632 |
-
if where_clause:
|
633 |
-
fts_query += f" AND ({where_clause})"
|
634 |
-
fts_query += " ORDER BY rank LIMIT ? OFFSET ?"
|
635 |
-
offset = (page - 1) * results_per_page
|
636 |
-
cursor.execute(fts_query, (query, results_per_page, offset))
|
637 |
-
rows = cursor.fetchall()
|
638 |
-
columns = [description[0] for description in cursor.description]
|
639 |
-
results = [dict(zip(columns, row)) for row in rows]
|
640 |
-
return results
|
641 |
-
except Exception as e:
|
642 |
-
logging.error(f"Error in search_db: {e}")
|
643 |
-
return []
|
644 |
-
finally:
|
645 |
-
conn.close()
|
646 |
-
|
647 |
-
|
648 |
-
def perform_full_text_search_chat(query: str, relevant_chat_ids: List[int], page: int = 1, results_per_page: int = 5) -> \
|
649 |
-
List[Dict[str, Any]]:
|
650 |
-
"""
|
651 |
-
Perform a full-text search within the specified chat IDs using FTS5.
|
652 |
-
|
653 |
-
Args:
|
654 |
-
query (str): The user's query.
|
655 |
-
relevant_chat_ids (List[int]): List of chat IDs to search within.
|
656 |
-
page (int): Pagination page number.
|
657 |
-
results_per_page (int): Number of results per page.
|
658 |
-
|
659 |
-
Returns:
|
660 |
-
List[Dict[str, Any]]: List of search results with content and metadata.
|
661 |
-
"""
|
662 |
-
try:
|
663 |
-
# Construct a WHERE clause to limit the search to relevant chat IDs
|
664 |
-
where_clause = " OR ".join([f"media_id = {chat_id}" for chat_id in relevant_chat_ids])
|
665 |
-
if not where_clause:
|
666 |
-
where_clause = "1" # No restriction if no chat IDs
|
667 |
-
|
668 |
-
# Perform full-text search using FTS5
|
669 |
-
fts_results = search_db(query, ["content"], where_clause, page=page, results_per_page=results_per_page)
|
670 |
-
|
671 |
-
filtered_fts_results = [
|
672 |
-
{
|
673 |
-
"content": result['content'],
|
674 |
-
"metadata": {"media_id": result['id']}
|
675 |
-
}
|
676 |
-
for result in fts_results
|
677 |
-
if result['id'] in relevant_chat_ids
|
678 |
-
]
|
679 |
-
return filtered_fts_results
|
680 |
-
except Exception as e:
|
681 |
-
logging.error(f"Error in perform_full_text_search_chat: {str(e)}")
|
682 |
-
return []
|
683 |
-
|
684 |
-
|
685 |
-
def fetch_all_chats() -> List[Dict[str, Any]]:
|
686 |
-
"""
|
687 |
-
Fetch all chat messages from the database.
|
688 |
-
|
689 |
-
Returns:
|
690 |
-
List[Dict[str, Any]]: List of chat messages with relevant metadata.
|
691 |
-
"""
|
692 |
-
try:
|
693 |
-
chats = get_character_chats() # Modify this function to retrieve all chats
|
694 |
-
return chats
|
695 |
-
except Exception as e:
|
696 |
-
logging.error(f"Error fetching all chats: {str(e)}")
|
697 |
-
return []
|
698 |
-
|
699 |
-
#
|
700 |
-
# End of Character_Chat_DB.py
|
701 |
-
#######################################################################################################################
|
|
|
1 |
+
# character_chat_db.py
|
2 |
+
# Database functions for managing character cards and chat histories.
|
3 |
+
# #
|
4 |
+
# Imports
|
5 |
+
import configparser
|
6 |
+
import sqlite3
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
from typing import List, Dict, Optional, Tuple, Any, Union
|
11 |
+
|
12 |
+
from App_Function_Libraries.Utils.Utils import get_database_dir, get_project_relative_path, get_database_path
|
13 |
+
from Tests.Chat_APIs.Chat_APIs_Integration_test import logging
|
14 |
+
|
15 |
+
#
|
16 |
+
#######################################################################################################################
|
17 |
+
#
|
18 |
+
#
|
19 |
+
|
20 |
+
def ensure_database_directory():
|
21 |
+
os.makedirs(get_database_dir(), exist_ok=True)
|
22 |
+
|
23 |
+
ensure_database_directory()
|
24 |
+
|
25 |
+
|
26 |
+
# Construct the path to the config file
|
27 |
+
config_path = get_project_relative_path('Config_Files/config.txt')
|
28 |
+
|
29 |
+
# Read the config file
|
30 |
+
config = configparser.ConfigParser()
|
31 |
+
config.read(config_path)
|
32 |
+
|
33 |
+
# Get the chat db path from the config, or use the default if not specified
|
34 |
+
chat_DB_PATH = config.get('Database', 'chatDB_path', fallback=get_database_path('chatDB.db'))
|
35 |
+
print(f"Chat Database path: {chat_DB_PATH}")
|
36 |
+
|
37 |
+
########################################################################################################
|
38 |
+
#
|
39 |
+
# Functions
|
40 |
+
|
41 |
+
# FIXME - Setup properly and test/add documentation for its existence...
|
42 |
+
def initialize_database():
|
43 |
+
"""Initialize the SQLite database with required tables and FTS5 virtual tables."""
|
44 |
+
conn = None
|
45 |
+
try:
|
46 |
+
conn = sqlite3.connect(chat_DB_PATH)
|
47 |
+
cursor = conn.cursor()
|
48 |
+
|
49 |
+
# Enable foreign key constraints
|
50 |
+
cursor.execute("PRAGMA foreign_keys = ON;")
|
51 |
+
|
52 |
+
# Create CharacterCards table with V2 fields
|
53 |
+
cursor.execute("""
|
54 |
+
CREATE TABLE IF NOT EXISTS CharacterCards (
|
55 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
56 |
+
name TEXT UNIQUE NOT NULL,
|
57 |
+
description TEXT,
|
58 |
+
personality TEXT,
|
59 |
+
scenario TEXT,
|
60 |
+
image BLOB,
|
61 |
+
post_history_instructions TEXT,
|
62 |
+
first_mes TEXT,
|
63 |
+
mes_example TEXT,
|
64 |
+
creator_notes TEXT,
|
65 |
+
system_prompt TEXT,
|
66 |
+
alternate_greetings TEXT,
|
67 |
+
tags TEXT,
|
68 |
+
creator TEXT,
|
69 |
+
character_version TEXT,
|
70 |
+
extensions TEXT,
|
71 |
+
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
72 |
+
);
|
73 |
+
""")
|
74 |
+
|
75 |
+
# Create CharacterChats table
|
76 |
+
cursor.execute("""
|
77 |
+
CREATE TABLE IF NOT EXISTS CharacterChats (
|
78 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
79 |
+
character_id INTEGER NOT NULL,
|
80 |
+
conversation_name TEXT,
|
81 |
+
chat_history TEXT,
|
82 |
+
is_snapshot BOOLEAN DEFAULT FALSE,
|
83 |
+
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
84 |
+
FOREIGN KEY (character_id) REFERENCES CharacterCards(id) ON DELETE CASCADE
|
85 |
+
);
|
86 |
+
""")
|
87 |
+
|
88 |
+
# Create FTS5 virtual table for CharacterChats
|
89 |
+
cursor.execute("""
|
90 |
+
CREATE VIRTUAL TABLE IF NOT EXISTS CharacterChats_fts USING fts5(
|
91 |
+
conversation_name,
|
92 |
+
chat_history,
|
93 |
+
content='CharacterChats',
|
94 |
+
content_rowid='id'
|
95 |
+
);
|
96 |
+
""")
|
97 |
+
|
98 |
+
# Create triggers to keep FTS5 table in sync with CharacterChats
|
99 |
+
cursor.executescript("""
|
100 |
+
CREATE TRIGGER IF NOT EXISTS CharacterChats_ai AFTER INSERT ON CharacterChats BEGIN
|
101 |
+
INSERT INTO CharacterChats_fts(rowid, conversation_name, chat_history)
|
102 |
+
VALUES (new.id, new.conversation_name, new.chat_history);
|
103 |
+
END;
|
104 |
+
|
105 |
+
CREATE TRIGGER IF NOT EXISTS CharacterChats_ad AFTER DELETE ON CharacterChats BEGIN
|
106 |
+
DELETE FROM CharacterChats_fts WHERE rowid = old.id;
|
107 |
+
END;
|
108 |
+
|
109 |
+
CREATE TRIGGER IF NOT EXISTS CharacterChats_au AFTER UPDATE ON CharacterChats BEGIN
|
110 |
+
UPDATE CharacterChats_fts SET conversation_name = new.conversation_name, chat_history = new.chat_history
|
111 |
+
WHERE rowid = new.id;
|
112 |
+
END;
|
113 |
+
""")
|
114 |
+
|
115 |
+
# Create ChatKeywords table
|
116 |
+
cursor.execute("""
|
117 |
+
CREATE TABLE IF NOT EXISTS ChatKeywords (
|
118 |
+
chat_id INTEGER NOT NULL,
|
119 |
+
keyword TEXT NOT NULL,
|
120 |
+
FOREIGN KEY (chat_id) REFERENCES CharacterChats(id) ON DELETE CASCADE
|
121 |
+
);
|
122 |
+
""")
|
123 |
+
|
124 |
+
# Create indexes for faster searches
|
125 |
+
cursor.execute("""
|
126 |
+
CREATE INDEX IF NOT EXISTS idx_chatkeywords_keyword ON ChatKeywords(keyword);
|
127 |
+
""")
|
128 |
+
cursor.execute("""
|
129 |
+
CREATE INDEX IF NOT EXISTS idx_chatkeywords_chat_id ON ChatKeywords(chat_id);
|
130 |
+
""")
|
131 |
+
|
132 |
+
conn.commit()
|
133 |
+
logging.info("Database initialized successfully.")
|
134 |
+
except sqlite3.Error as e:
|
135 |
+
logging.error(f"SQLite error occurred during database initialization: {e}")
|
136 |
+
if conn:
|
137 |
+
conn.rollback()
|
138 |
+
raise
|
139 |
+
except Exception as e:
|
140 |
+
logging.error(f"Unexpected error occurred during database initialization: {e}")
|
141 |
+
if conn:
|
142 |
+
conn.rollback()
|
143 |
+
raise
|
144 |
+
finally:
|
145 |
+
if conn:
|
146 |
+
conn.close()
|
147 |
+
|
148 |
+
# Call initialize_database() at the start of your application
|
149 |
+
def setup_chat_database():
|
150 |
+
try:
|
151 |
+
initialize_database()
|
152 |
+
except Exception as e:
|
153 |
+
logging.critical(f"Failed to initialize database: {e}")
|
154 |
+
sys.exit(1)
|
155 |
+
|
156 |
+
setup_chat_database()
|
157 |
+
|
158 |
+
########################################################################################################
|
159 |
+
#
|
160 |
+
# Character Card handling
|
161 |
+
|
162 |
+
+def parse_character_card(card_data: Dict[str, Any]) -> Dict[str, Any]:
+    """Parse and validate a character card according to the V2 specification."""
+    v2_data = {
+        'name': card_data.get('name', ''),
+        'description': card_data.get('description', ''),
+        'personality': card_data.get('personality', ''),
+        'scenario': card_data.get('scenario', ''),
+        'first_mes': card_data.get('first_mes', ''),
+        'mes_example': card_data.get('mes_example', ''),
+        'creator_notes': card_data.get('creator_notes', ''),
+        'system_prompt': card_data.get('system_prompt', ''),
+        'post_history_instructions': card_data.get('post_history_instructions', ''),
+        'alternate_greetings': json.dumps(card_data.get('alternate_greetings', [])),
+        'tags': json.dumps(card_data.get('tags', [])),
+        'creator': card_data.get('creator', ''),
+        'character_version': card_data.get('character_version', ''),
+        'extensions': json.dumps(card_data.get('extensions', {}))
+    }
+
+    # Handle 'image' separately as it might be binary data
+    if 'image' in card_data:
+        v2_data['image'] = card_data['image']
+
+    return v2_data
+
+
+def add_character_card(card_data: Dict[str, Any]) -> Optional[int]:
+    """Add or update a character card in the database. Returns the character's ID, or None on failure."""
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        parsed_card = parse_character_card(card_data)
+
+        # 'image' may be absent from the parsed card, so use .get() to avoid a KeyError
+        image = parsed_card.get('image')
+
+        # Check if the character already exists
+        cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (parsed_card['name'],))
+        row = cursor.fetchone()
+
+        if row:
+            # Update existing character
+            character_id = row[0]
+            update_query = """
+                UPDATE CharacterCards
+                SET description = ?, personality = ?, scenario = ?, image = ?,
+                    post_history_instructions = ?, first_mes = ?, mes_example = ?,
+                    creator_notes = ?, system_prompt = ?, alternate_greetings = ?,
+                    tags = ?, creator = ?, character_version = ?, extensions = ?
+                WHERE id = ?
+            """
+            cursor.execute(update_query, (
+                parsed_card['description'], parsed_card['personality'], parsed_card['scenario'],
+                image, parsed_card['post_history_instructions'], parsed_card['first_mes'],
+                parsed_card['mes_example'], parsed_card['creator_notes'], parsed_card['system_prompt'],
+                parsed_card['alternate_greetings'], parsed_card['tags'], parsed_card['creator'],
+                parsed_card['character_version'], parsed_card['extensions'], character_id
+            ))
+        else:
+            # Insert new character
+            insert_query = """
+                INSERT INTO CharacterCards (name, description, personality, scenario, image,
+                    post_history_instructions, first_mes, mes_example, creator_notes, system_prompt,
+                    alternate_greetings, tags, creator, character_version, extensions)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            """
+            cursor.execute(insert_query, (
+                parsed_card['name'], parsed_card['description'], parsed_card['personality'],
+                parsed_card['scenario'], image, parsed_card['post_history_instructions'],
+                parsed_card['first_mes'], parsed_card['mes_example'], parsed_card['creator_notes'],
+                parsed_card['system_prompt'], parsed_card['alternate_greetings'], parsed_card['tags'],
+                parsed_card['creator'], parsed_card['character_version'], parsed_card['extensions']
+            ))
+            character_id = cursor.lastrowid
+
+        conn.commit()
+        return character_id
+    except sqlite3.IntegrityError as e:
+        logging.error(f"Error adding character card: {e}")
+        return None
+    except Exception as e:
+        logging.error(f"Unexpected error adding character card: {e}")
+        return None
+    finally:
+        conn.close()
+
+
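A minimal, editor-added usage sketch for the function above. The field values are invented for illustration; any field missing from the dict simply defaults to an empty value via parse_character_card:

from App_Function_Libraries.DB.Character_Chat_DB import add_character_card

example_card = {
    'name': 'Test Character',  # hypothetical example data
    'description': 'A character used for smoke-testing the DB layer.',
    'personality': 'Curious and terse.',
    'first_mes': 'Hello!',
    'tags': ['test', 'example'],
}
character_id = add_character_card(example_card)
print(f"Stored character with id={character_id}")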
+def get_character_cards() -> List[Dict]:
+    """Retrieve all character cards from the database."""
+    logging.debug(f"Fetching characters from DB: {chat_DB_PATH}")
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    cursor.execute("SELECT * FROM CharacterCards")
+    rows = cursor.fetchall()
+    columns = [description[0] for description in cursor.description]
+    conn.close()
+    characters = [dict(zip(columns, row)) for row in rows]
+    # logging.debug(f"Characters fetched from DB: {characters}")
+    return characters
+
+
+def get_character_card_by_id(character_id: Union[int, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    """
+    Retrieve a single character card by its ID.
+
+    Args:
+        character_id: Either an integer ID or a dictionary that already contains character data.
+
+    Returns:
+        A dictionary containing the character card data, or None if not found.
+    """
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        if isinstance(character_id, dict):
+            # If a dictionary is passed, assume it's already a character card
+            return character_id
+        elif isinstance(character_id, int):
+            # If an integer is passed, fetch the character from the database
+            cursor.execute("SELECT * FROM CharacterCards WHERE id = ?", (character_id,))
+            row = cursor.fetchone()
+            if row:
+                columns = [description[0] for description in cursor.description]
+                return dict(zip(columns, row))
+            return None
+        else:
+            logging.warning(f"Invalid type for character_id: {type(character_id)}")
+            return None
+    except Exception as e:
+        logging.error(f"Error in get_character_card_by_id: {e}")
+        return None
+    finally:
+        conn.close()
+
+
+def update_character_card(character_id: int, card_data: Dict) -> bool:
+    """Update an existing character card."""
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        # Column names follow the V2 card schema used by add_character_card (first_mes, not first_message)
+        cursor.execute("""
+            UPDATE CharacterCards
+            SET name = ?, description = ?, personality = ?, scenario = ?, image = ?,
+                post_history_instructions = ?, first_mes = ?
+            WHERE id = ?
+        """, (
+            card_data.get('name'),
+            card_data.get('description'),
+            card_data.get('personality'),
+            card_data.get('scenario'),
+            card_data.get('image'),
+            card_data.get('post_history_instructions', ''),
+            card_data.get('first_mes', "Hello! I'm ready to chat."),
+            character_id
+        ))
+        conn.commit()
+        return cursor.rowcount > 0
+    except sqlite3.IntegrityError as e:
+        logging.error(f"Error updating character card: {e}")
+        return False
+    finally:
+        conn.close()
+
+
+def delete_character_card(character_id: int) -> bool:
+    """Delete a character card and its associated chats."""
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        # Delete associated chats first due to the foreign key constraint
+        cursor.execute("DELETE FROM CharacterChats WHERE character_id = ?", (character_id,))
+        cursor.execute("DELETE FROM CharacterCards WHERE id = ?", (character_id,))
+        conn.commit()
+        return cursor.rowcount > 0
+    except sqlite3.Error as e:
+        logging.error(f"Error deleting character card: {e}")
+        return False
+    finally:
+        conn.close()
+
+
+def add_character_chat(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]], keywords: Optional[List[str]] = None, is_snapshot: bool = False) -> Optional[int]:
+    """
+    Add a new chat history for a character, optionally associating keywords.
+
+    Args:
+        character_id (int): The ID of the character.
+        conversation_name (str): Name of the conversation.
+        chat_history (List[Tuple[str, str]]): List of (user, bot) message tuples.
+        keywords (Optional[List[str]]): List of keywords to associate with this chat.
+        is_snapshot (bool, optional): Whether this chat is a snapshot.
+
+    Returns:
+        Optional[int]: The ID of the inserted chat, or None if the insert failed.
+    """
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        chat_history_json = json.dumps(chat_history)
+        cursor.execute("""
+            INSERT INTO CharacterChats (character_id, conversation_name, chat_history, is_snapshot)
+            VALUES (?, ?, ?, ?)
+        """, (
+            character_id,
+            conversation_name,
+            chat_history_json,
+            is_snapshot
+        ))
+        chat_id = cursor.lastrowid
+
+        if keywords:
+            # Normalize and insert keywords into the ChatKeywords table
+            keyword_records = [(chat_id, keyword.strip().lower()) for keyword in keywords]
+            cursor.executemany("""
+                INSERT INTO ChatKeywords (chat_id, keyword)
+                VALUES (?, ?)
+            """, keyword_records)
+
+        conn.commit()
+        return chat_id
+    except sqlite3.Error as e:
+        logging.error(f"Error adding character chat: {e}")
+        return None
+    finally:
+        conn.close()
+
+
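An editor-added usage sketch for add_character_chat; the character ID and messages are hypothetical and assume a character with id=1 already exists:

from App_Function_Libraries.DB.Character_Chat_DB import add_character_chat

history = [("Hi there", "Hello! How can I help?"), ("Tell me a story", "Once upon a time...")]
chat_id = add_character_chat(
    character_id=1,  # hypothetical existing character
    conversation_name="First chat",
    chat_history=history,
    keywords=["intro", "story"],
)
print(f"Saved chat {chat_id}")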
+def get_character_chats(character_id: Optional[int] = None) -> List[Dict]:
+    """Retrieve all chats, or chats for a specific character if character_id is provided."""
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    if character_id is not None:
+        cursor.execute("SELECT * FROM CharacterChats WHERE character_id = ?", (character_id,))
+    else:
+        cursor.execute("SELECT * FROM CharacterChats")
+    rows = cursor.fetchall()
+    columns = [description[0] for description in cursor.description]
+    conn.close()
+    return [dict(zip(columns, row)) for row in rows]
+
+
+def get_character_chat_by_id(chat_id: int) -> Optional[Dict]:
+    """Retrieve a single chat by its ID."""
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    cursor.execute("SELECT * FROM CharacterChats WHERE id = ?", (chat_id,))
+    row = cursor.fetchone()
+    # Read the column names before closing the connection; the cursor is unusable afterwards
+    columns = [description[0] for description in cursor.description]
+    conn.close()
+    if row:
+        chat = dict(zip(columns, row))
+        chat['chat_history'] = json.loads(chat['chat_history'])
+        return chat
+    return None
+
+
+def search_character_chats(query: str, character_id: Optional[int] = None) -> Tuple[List[Dict], str]:
+    """
+    Search for character chats using FTS5, optionally filtered by character_id.
+
+    Args:
+        query (str): The search query.
+        character_id (Optional[int]): The ID of the character to filter chats by.
+
+    Returns:
+        Tuple[List[Dict], str]: A list of matching chats and a status message.
+    """
+    if not query.strip():
+        return [], "Please enter a search query."
+
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        if character_id is not None:
+            # Search with character_id filter
+            cursor.execute("""
+                SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
+                FROM CharacterChats_fts
+                JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
+                WHERE CharacterChats_fts MATCH ? AND CharacterChats.character_id = ?
+                ORDER BY rank
+            """, (query, character_id))
+        else:
+            # Search without character_id filter
+            cursor.execute("""
+                SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
+                FROM CharacterChats_fts
+                JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
+                WHERE CharacterChats_fts MATCH ?
+                ORDER BY rank
+            """, (query,))
+
+        rows = cursor.fetchall()
+        columns = [description[0] for description in cursor.description]
+        results = [dict(zip(columns, row)) for row in rows]
+
+        if character_id is not None:
+            status_message = f"Found {len(results)} chat(s) matching '{query}' for the selected character."
+        else:
+            status_message = f"Found {len(results)} chat(s) matching '{query}' across all characters."
+
+        return results, status_message
+    except Exception as e:
+        logging.error(f"Error searching chats with FTS5: {e}")
+        return [], f"Error occurred during search: {e}"
+    finally:
+        conn.close()
+
+
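An editor-added sketch of calling the FTS5-backed search above; the query string and character ID are illustrative only:

from App_Function_Libraries.DB.Character_Chat_DB import search_character_chats

results, status = search_character_chats("story", character_id=1)  # hypothetical values
print(status)
for chat in results:
    print(chat['id'], chat['conversation_name'])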
+def update_character_chat(chat_id: int, chat_history: List[Tuple[str, str]]) -> bool:
+    """Update an existing chat history."""
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        chat_history_json = json.dumps(chat_history)
+        cursor.execute("""
+            UPDATE CharacterChats
+            SET chat_history = ?
+            WHERE id = ?
+        """, (
+            chat_history_json,
+            chat_id
+        ))
+        conn.commit()
+        return cursor.rowcount > 0
+    except sqlite3.Error as e:
+        logging.error(f"Error updating character chat: {e}")
+        return False
+    finally:
+        conn.close()
+
+
+def delete_character_chat(chat_id: int) -> bool:
+    """Delete a specific chat."""
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        cursor.execute("DELETE FROM CharacterChats WHERE id = ?", (chat_id,))
+        conn.commit()
+        return cursor.rowcount > 0
+    except sqlite3.Error as e:
+        logging.error(f"Error deleting character chat: {e}")
+        return False
+    finally:
+        conn.close()
+
+
+def fetch_keywords_for_chats(keywords: List[str]) -> List[int]:
+    """
+    Fetch chat IDs associated with any of the specified keywords.
+
+    Args:
+        keywords (List[str]): List of keywords to search for.
+
+    Returns:
+        List[int]: List of chat IDs associated with the keywords.
+    """
+    if not keywords:
+        return []
+
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        # Construct the WHERE clause to search for each keyword
+        keyword_clauses = " OR ".join(["keyword = ?"] * len(keywords))
+        sql_query = f"SELECT DISTINCT chat_id FROM ChatKeywords WHERE {keyword_clauses}"
+        cursor.execute(sql_query, keywords)
+        rows = cursor.fetchall()
+        chat_ids = [row[0] for row in rows]
+        return chat_ids
+    except Exception as e:
+        logging.error(f"Error in fetch_keywords_for_chats: {e}")
+        return []
+    finally:
+        conn.close()
+
+
+def save_chat_history_to_character_db(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]]) -> Optional[int]:
+    """Save chat history to the CharacterChats table.
+
+    Returns the ID of the inserted chat, or None if the insert failed.
+    """
+    return add_character_chat(character_id, conversation_name, chat_history)
+
+
+def migrate_chat_to_media_db():
+    # Placeholder: migrating chats into the media database is not yet implemented
+    pass
+
+
+def search_db(query: str, fields: List[str], where_clause: str = "", page: int = 1, results_per_page: int = 5) -> List[Dict[str, Any]]:
+    """
+    Perform a full-text search on specified fields with optional filtering and pagination.
+
+    Args:
+        query (str): The search query.
+        fields (List[str]): FTS5 columns to search in (e.g. 'conversation_name', 'chat_history').
+        where_clause (str, optional): Additional SQL WHERE clause to filter results.
+        page (int, optional): Page number for pagination.
+        results_per_page (int, optional): Number of results per page.
+
+    Returns:
+        List[Dict[str, Any]]: List of matching chat records with content and metadata.
+    """
+    if not query.strip():
+        return []
+
+    conn = sqlite3.connect(chat_DB_PATH)
+    cursor = conn.cursor()
+    try:
+        # Restrict the FTS5 match to the requested columns with a column filter,
+        # e.g. '{conversation_name chat_history} : <query>'
+        match_expression = f"{{{' '.join(fields)}}} : {query}" if fields else query
+        fts_query = """
+            SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
+            FROM CharacterChats_fts
+            JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
+            WHERE CharacterChats_fts MATCH ?
+        """
+        if where_clause:
+            fts_query += f" AND ({where_clause})"
+        fts_query += " ORDER BY rank LIMIT ? OFFSET ?"
+        offset = (page - 1) * results_per_page
+        cursor.execute(fts_query, (match_expression, results_per_page, offset))
+        rows = cursor.fetchall()
+        columns = [description[0] for description in cursor.description]
+        results = [dict(zip(columns, row)) for row in rows]
+        return results
+    except Exception as e:
+        logging.error(f"Error in search_db: {e}")
+        return []
+    finally:
+        conn.close()
+
+
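The column-filter syntax used above ('{col1 col2} : query') is standard FTS5. Here is an editor-added, fully self-contained demonstration against a throwaway in-memory table, so the behavior can be checked independently of this module:

import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("CREATE VIRTUAL TABLE docs USING fts5(title, body)")
conn.execute("INSERT INTO docs VALUES ('alpha', 'the quick brown fox')")
conn.execute("INSERT INTO docs VALUES ('fox', 'nothing relevant here')")
# The '{title} : fox' filter matches only within the title column
rows = conn.execute("SELECT title FROM docs WHERE docs MATCH ?", ("{title} : fox",)).fetchall()
print(rows)  # [('fox',)] - the body-only match is excluded
conn.close()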
+def perform_full_text_search_chat(query: str, relevant_chat_ids: List[int], page: int = 1, results_per_page: int = 5) -> \
+        List[Dict[str, Any]]:
+    """
+    Perform a full-text search within the specified chat IDs using FTS5.
+
+    Args:
+        query (str): The user's query.
+        relevant_chat_ids (List[int]): List of chat IDs to search within.
+        page (int): Pagination page number.
+        results_per_page (int): Number of results per page.
+
+    Returns:
+        List[Dict[str, Any]]: List of search results with content and metadata.
+    """
+    try:
+        # Construct a WHERE clause that limits the search to the relevant chat IDs.
+        # The IDs are cast to int to keep the interpolated clause safe.
+        if relevant_chat_ids:
+            id_list = ", ".join(str(int(chat_id)) for chat_id in relevant_chat_ids)
+            where_clause = f"CharacterChats.id IN ({id_list})"
+        else:
+            where_clause = "1"  # No restriction if no chat IDs
+
+        # Perform full-text search over the chat history using FTS5
+        fts_results = search_db(query, ["chat_history"], where_clause, page=page, results_per_page=results_per_page)
+
+        filtered_fts_results = [
+            {
+                "content": result['chat_history'],
+                "metadata": {"media_id": result['id']}
+            }
+            for result in fts_results
+            if not relevant_chat_ids or result['id'] in relevant_chat_ids
+        ]
+        return filtered_fts_results
+    except Exception as e:
+        logging.error(f"Error in perform_full_text_search_chat: {str(e)}")
+        return []
+
+
+def fetch_all_chats() -> List[Dict[str, Any]]:
+    """
+    Fetch all chat messages from the database.
+
+    Returns:
+        List[Dict[str, Any]]: List of chat records with relevant metadata.
+    """
+    try:
+        # get_character_chats() with no argument returns every stored chat
+        chats = get_character_chats()
+        return chats
+    except Exception as e:
+        logging.error(f"Error fetching all chats: {str(e)}")
+        return []
+
+#
+# End of Character_Chat_DB.py
+#######################################################################################################################
App_Function_Libraries/DB/RAG_QA_Chat_DB.py
ADDED
@@ -0,0 +1,461 @@
+# RAG_QA_Chat_DB.py
+# Description: This file contains the database operations for the RAG QA Chat + Notes system.
+#
+# Imports
+import logging
+import re
+import sqlite3
+from contextlib import contextmanager
+from datetime import datetime
+#
+# External Imports
+#
+# Local Imports
+#
+########################################################################################################################
+#
+# Functions:
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Database schema
+SCHEMA_SQL = '''
+-- Table for storing chat messages
+CREATE TABLE IF NOT EXISTS rag_qa_chats (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    conversation_id TEXT NOT NULL,
+    timestamp DATETIME NOT NULL,
+    role TEXT NOT NULL,
+    content TEXT NOT NULL
+);
+
+-- Table for storing conversation metadata
+CREATE TABLE IF NOT EXISTS conversation_metadata (
+    conversation_id TEXT PRIMARY KEY,
+    created_at DATETIME NOT NULL,
+    last_updated DATETIME NOT NULL,
+    title TEXT NOT NULL
+);
+
+-- Table for storing keywords
+CREATE TABLE IF NOT EXISTS rag_qa_keywords (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    keyword TEXT NOT NULL UNIQUE
+);
+
+-- Table for linking keywords to conversations
+CREATE TABLE IF NOT EXISTS rag_qa_conversation_keywords (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    conversation_id TEXT NOT NULL,
+    keyword_id INTEGER NOT NULL,
+    FOREIGN KEY (conversation_id) REFERENCES conversation_metadata(conversation_id),
+    FOREIGN KEY (keyword_id) REFERENCES rag_qa_keywords(id)
+);
+
+-- Table for storing keyword collections
+CREATE TABLE IF NOT EXISTS rag_qa_keyword_collections (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    name TEXT NOT NULL UNIQUE,
+    parent_id INTEGER,
+    FOREIGN KEY (parent_id) REFERENCES rag_qa_keyword_collections(id)
+);
+
+-- Table for linking keywords to collections
+CREATE TABLE IF NOT EXISTS rag_qa_collection_keywords (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    collection_id INTEGER NOT NULL,
+    keyword_id INTEGER NOT NULL,
+    FOREIGN KEY (collection_id) REFERENCES rag_qa_keyword_collections(id),
+    FOREIGN KEY (keyword_id) REFERENCES rag_qa_keywords(id)
+);
+
+-- Table for storing notes
+CREATE TABLE IF NOT EXISTS rag_qa_notes (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    conversation_id TEXT NOT NULL,
+    content TEXT NOT NULL,
+    timestamp DATETIME NOT NULL,
+    FOREIGN KEY (conversation_id) REFERENCES conversation_metadata(conversation_id)
+);
+
+-- Table for linking notes to keywords
+CREATE TABLE IF NOT EXISTS rag_qa_note_keywords (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    note_id INTEGER NOT NULL,
+    keyword_id INTEGER NOT NULL,
+    FOREIGN KEY (note_id) REFERENCES rag_qa_notes(id),
+    FOREIGN KEY (keyword_id) REFERENCES rag_qa_keywords(id)
+);
+
+-- Indexes for improved query performance
+CREATE INDEX IF NOT EXISTS idx_rag_qa_chats_conversation_id ON rag_qa_chats(conversation_id);
+CREATE INDEX IF NOT EXISTS idx_rag_qa_chats_timestamp ON rag_qa_chats(timestamp);
+CREATE INDEX IF NOT EXISTS idx_rag_qa_keywords_keyword ON rag_qa_keywords(keyword);
+CREATE INDEX IF NOT EXISTS idx_rag_qa_conversation_keywords_conversation_id ON rag_qa_conversation_keywords(conversation_id);
+CREATE INDEX IF NOT EXISTS idx_rag_qa_conversation_keywords_keyword_id ON rag_qa_conversation_keywords(keyword_id);
+CREATE INDEX IF NOT EXISTS idx_rag_qa_keyword_collections_parent_id ON rag_qa_keyword_collections(parent_id);
+CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_collection_id ON rag_qa_collection_keywords(collection_id);
+CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_keyword_id ON rag_qa_collection_keywords(keyword_id);
+
+-- Full-text search virtual table for chat content
+CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5(conversation_id, timestamp, role, content);
+
+-- Trigger to keep the FTS table up to date
+CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ai AFTER INSERT ON rag_qa_chats BEGIN
+    INSERT INTO rag_qa_chats_fts(conversation_id, timestamp, role, content) VALUES (new.conversation_id, new.timestamp, new.role, new.content);
+END;
+'''
+
+# Database connection management
+@contextmanager
+def get_db_connection():
+    conn = sqlite3.connect('rag_qa_chat.db')
+    try:
+        yield conn
+    finally:
+        conn.close()
+
+
+@contextmanager
+def transaction():
+    with get_db_connection() as conn:
+        try:
+            conn.execute('BEGIN TRANSACTION')
+            yield conn
+            conn.commit()
+        except Exception:
+            conn.rollback()
+            raise
+
+
+def execute_query(query, params=None, transaction_conn=None):
+    if transaction_conn:
+        cursor = transaction_conn.cursor()
+        if params:
+            cursor.execute(query, params)
+        else:
+            cursor.execute(query)
+        return cursor.fetchall()
+    else:
+        with get_db_connection() as conn:
+            cursor = conn.cursor()
+            if params:
+                cursor.execute(query, params)
+            else:
+                cursor.execute(query)
+            conn.commit()
+            return cursor.fetchall()
+
+
+def create_tables():
+    with get_db_connection() as conn:
+        conn.executescript(SCHEMA_SQL)
+    logger.info("All RAG QA Chat tables created successfully")
+
+
+# Initialize the database
+create_tables()
+
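An editor-added sketch of the connection helpers above: statements grouped under transaction() either all commit or all roll back, and passing the transaction connection to execute_query keeps everything on one handle. The keyword value is illustrative:

from App_Function_Libraries.DB.RAG_QA_Chat_DB import transaction, execute_query

with transaction() as conn:
    # Both statements run inside one transaction; a failure rolls both back
    execute_query("INSERT OR IGNORE INTO rag_qa_keywords (keyword) VALUES (?)", ("example",), conn)
    rows = execute_query("SELECT id FROM rag_qa_keywords WHERE keyword = ?", ("example",), conn)
print(rows)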
+# Input validation
+def validate_keyword(keyword):
+    if not isinstance(keyword, str):
+        raise ValueError("Keyword must be a string")
+    if not keyword.strip():
+        raise ValueError("Keyword cannot be empty or just whitespace")
+    if len(keyword) > 100:
+        raise ValueError("Keyword is too long (max 100 characters)")
+    if not re.match(r'^[a-zA-Z0-9\s\-_]+$', keyword):
+        raise ValueError("Keyword contains invalid characters")
+    return keyword.strip()
+
+
+def validate_collection_name(name):
+    if not isinstance(name, str):
+        raise ValueError("Collection name must be a string")
+    if not name.strip():
+        raise ValueError("Collection name cannot be empty or just whitespace")
+    if len(name) > 100:
+        raise ValueError("Collection name is too long (max 100 characters)")
+    if not re.match(r'^[a-zA-Z0-9\s\-_]+$', name):
+        raise ValueError("Collection name contains invalid characters")
+    return name.strip()
+
+# Core functions
+def add_keyword(keyword):
+    try:
+        validated_keyword = validate_keyword(keyword)
+        query = "INSERT OR IGNORE INTO rag_qa_keywords (keyword) VALUES (?)"
+        execute_query(query, (validated_keyword,))
+        logger.info(f"Keyword '{validated_keyword}' added successfully")
+    except ValueError as e:
+        logger.error(f"Invalid keyword: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Error adding keyword '{keyword}': {e}")
+        raise
+
+
+def create_keyword_collection(name, parent_id=None):
+    try:
+        validated_name = validate_collection_name(name)
+        query = "INSERT INTO rag_qa_keyword_collections (name, parent_id) VALUES (?, ?)"
+        execute_query(query, (validated_name, parent_id))
+        logger.info(f"Keyword collection '{validated_name}' created successfully")
+    except ValueError as e:
+        logger.error(f"Invalid collection name: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Error creating keyword collection '{name}': {e}")
+        raise
+
+
+def add_keyword_to_collection(collection_name, keyword):
+    try:
+        validated_collection_name = validate_collection_name(collection_name)
+        validated_keyword = validate_keyword(keyword)
+
+        with transaction() as conn:
+            # Insert the keyword on this connection so the whole operation stays in one transaction
+            query = "INSERT OR IGNORE INTO rag_qa_keywords (keyword) VALUES (?)"
+            execute_query(query, (validated_keyword,), conn)
+
+            query = '''
+            INSERT INTO rag_qa_collection_keywords (collection_id, keyword_id)
+            SELECT c.id, k.id
+            FROM rag_qa_keyword_collections c, rag_qa_keywords k
+            WHERE c.name = ? AND k.keyword = ?
+            '''
+            execute_query(query, (validated_collection_name, validated_keyword), conn)
+
+        logger.info(f"Keyword '{validated_keyword}' added to collection '{validated_collection_name}' successfully")
+    except ValueError as e:
+        logger.error(f"Invalid input: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Error adding keyword '{keyword}' to collection '{collection_name}': {e}")
+        raise
+
+
+def add_keywords_to_conversation(conversation_id, keywords):
+    if not isinstance(keywords, (list, tuple)):
+        raise ValueError("Keywords must be a list or tuple")
+    try:
+        with transaction() as conn:
+            for keyword in keywords:
+                validated_keyword = validate_keyword(keyword)
+
+                query = "INSERT OR IGNORE INTO rag_qa_keywords (keyword) VALUES (?)"
+                execute_query(query, (validated_keyword,), conn)
+
+                query = '''
+                INSERT INTO rag_qa_conversation_keywords (conversation_id, keyword_id)
+                SELECT ?, id FROM rag_qa_keywords WHERE keyword = ?
+                '''
+                execute_query(query, (conversation_id, validated_keyword), conn)
+
+        logger.info(f"Keywords added to conversation '{conversation_id}' successfully")
+    except ValueError as e:
+        logger.error(f"Invalid keyword: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Error adding keywords to conversation '{conversation_id}': {e}")
+        raise
+
+
+def get_keywords_for_conversation(conversation_id):
+    try:
+        query = '''
+        SELECT k.keyword
+        FROM rag_qa_keywords k
+        JOIN rag_qa_conversation_keywords ck ON k.id = ck.keyword_id
+        WHERE ck.conversation_id = ?
+        '''
+        result = execute_query(query, (conversation_id,))
+        keywords = [row[0] for row in result]
+        logger.info(f"Retrieved {len(keywords)} keywords for conversation '{conversation_id}'")
+        return keywords
+    except Exception as e:
+        logger.error(f"Error getting keywords for conversation '{conversation_id}': {e}")
+        raise
+
+
+def get_keywords_for_collection(collection_name):
+    try:
+        query = '''
+        SELECT k.keyword
+        FROM rag_qa_keywords k
+        JOIN rag_qa_collection_keywords ck ON k.id = ck.keyword_id
+        JOIN rag_qa_keyword_collections c ON ck.collection_id = c.id
+        WHERE c.name = ?
+        '''
+        result = execute_query(query, (collection_name,))
+        keywords = [row[0] for row in result]
+        logger.info(f"Retrieved {len(keywords)} keywords for collection '{collection_name}'")
+        return keywords
+    except Exception as e:
+        logger.error(f"Error getting keywords for collection '{collection_name}': {e}")
+        raise
+
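An editor-added walkthrough of the keyword/collection API above; the collection name is hypothetical and the example assumes a fresh database (creating the same collection twice raises an IntegrityError because of the UNIQUE constraint):

from App_Function_Libraries.DB.RAG_QA_Chat_DB import (
    create_keyword_collection, add_keyword_to_collection, get_keywords_for_collection)

create_keyword_collection("research")  # hypothetical collection name
add_keyword_to_collection("research", "rag")
add_keyword_to_collection("research", "sqlite")
print(get_keywords_for_collection("research"))  # e.g. ['rag', 'sqlite']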
+def save_notes(conversation_id, content):
+    """Save notes to the database."""
+    try:
+        query = "INSERT INTO rag_qa_notes (conversation_id, content, timestamp) VALUES (?, ?, ?)"
+        timestamp = datetime.now().isoformat()
+        execute_query(query, (conversation_id, content, timestamp))
+        logger.info(f"Notes saved for conversation '{conversation_id}'")
+    except Exception as e:
+        logger.error(f"Error saving notes for conversation '{conversation_id}': {e}")
+        raise
+
+
+def get_notes(conversation_id):
+    """Retrieve notes for a given conversation."""
+    try:
+        query = "SELECT content FROM rag_qa_notes WHERE conversation_id = ?"
+        result = execute_query(query, (conversation_id,))
+        notes = [row[0] for row in result]
+        logger.info(f"Retrieved {len(notes)} notes for conversation '{conversation_id}'")
+        return notes
+    except Exception as e:
+        logger.error(f"Error getting notes for conversation '{conversation_id}': {e}")
+        raise
+
+
+def clear_notes(conversation_id):
+    """Clear all notes for a given conversation."""
+    try:
+        query = "DELETE FROM rag_qa_notes WHERE conversation_id = ?"
+        execute_query(query, (conversation_id,))
+        logger.info(f"Cleared notes for conversation '{conversation_id}'")
+    except Exception as e:
+        logger.error(f"Error clearing notes for conversation '{conversation_id}': {e}")
+        raise
+
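A short, editor-added round-trip through the notes API above; the conversation ID is an invented placeholder:

from App_Function_Libraries.DB.RAG_QA_Chat_DB import save_notes, get_notes, clear_notes

conv_id = "2024-01-01T00:00:00"  # hypothetical conversation id
save_notes(conv_id, "Remember to re-check the embeddings settings.")
print(get_notes(conv_id))  # ['Remember to re-check the embeddings settings.']
clear_notes(conv_id)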
+def add_keywords_to_note(note_id, keywords):
+    """Associate keywords with a note."""
+    try:
+        with transaction() as conn:
+            for keyword in keywords:
+                validated_keyword = validate_keyword(keyword)
+
+                # Insert the keyword into the rag_qa_keywords table if it doesn't exist
+                query = "INSERT OR IGNORE INTO rag_qa_keywords (keyword) VALUES (?)"
+                execute_query(query, (validated_keyword,), conn)
+
+                # Retrieve the keyword ID
+                query = "SELECT id FROM rag_qa_keywords WHERE keyword = ?"
+                keyword_id = execute_query(query, (validated_keyword,), conn)[0][0]
+
+                # Link the note and keyword
+                query = "INSERT INTO rag_qa_note_keywords (note_id, keyword_id) VALUES (?, ?)"
+                execute_query(query, (note_id, keyword_id), conn)
+
+        logger.info(f"Keywords added to note ID '{note_id}' successfully")
+    except Exception as e:
+        logger.error(f"Error adding keywords to note ID '{note_id}': {e}")
+        raise
+
+
+def get_keywords_for_note(note_id):
+    """Retrieve keywords associated with a given note."""
+    try:
+        query = '''
+        SELECT k.keyword
+        FROM rag_qa_keywords k
+        JOIN rag_qa_note_keywords nk ON k.id = nk.keyword_id
+        WHERE nk.note_id = ?
+        '''
+        result = execute_query(query, (note_id,))
+        keywords = [row[0] for row in result]
+        logger.info(f"Retrieved {len(keywords)} keywords for note ID '{note_id}'")
+        return keywords
+    except Exception as e:
+        logger.error(f"Error getting keywords for note ID '{note_id}': {e}")
+        raise
+
+
+def clear_keywords_from_note(note_id):
+    """Clear all keywords from a given note."""
+    try:
+        query = "DELETE FROM rag_qa_note_keywords WHERE note_id = ?"
+        execute_query(query, (note_id,))
+        logger.info(f"Cleared keywords for note ID '{note_id}'")
+    except Exception as e:
+        logger.error(f"Error clearing keywords for note ID '{note_id}': {e}")
+        raise
+
+def save_message(conversation_id, role, content):
+    try:
+        query = "INSERT INTO rag_qa_chats (conversation_id, timestamp, role, content) VALUES (?, ?, ?, ?)"
+        timestamp = datetime.now().isoformat()
+        execute_query(query, (conversation_id, timestamp, role, content))
+        logger.info(f"Message saved for conversation '{conversation_id}'")
+    except Exception as e:
+        logger.error(f"Error saving message for conversation '{conversation_id}': {e}")
+        raise
+
+
+def start_new_conversation(title="Untitled Conversation"):
+    try:
+        conversation_id = datetime.now().isoformat()
+        query = "INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title) VALUES (?, ?, ?, ?)"
+        now = datetime.now()
+        execute_query(query, (conversation_id, now, now, title))
+        logger.info(f"New conversation '{conversation_id}' started with title '{title}'")
+        return conversation_id
+    except Exception as e:
+        logger.error(f"Error starting new conversation: {e}")
+        raise
+
+# Pagination helper function
+def get_paginated_results(query, params=None, page=1, page_size=20):
+    try:
+        offset = (page - 1) * page_size
+        paginated_query = f"{query} LIMIT ? OFFSET ?"
+        if params:
+            params = tuple(params) + (page_size, offset)
+        else:
+            params = (page_size, offset)
+
+        result = execute_query(paginated_query, params)
+
+        count_query = f"SELECT COUNT(*) FROM ({query})"
+        total_count = execute_query(count_query, params[:-2] if params else None)[0][0]
+
+        total_pages = (total_count + page_size - 1) // page_size
+
+        logger.info(f"Retrieved page {page} of {total_pages} (total items: {total_count})")
+        return result, total_pages, total_count
+    except Exception as e:
+        logger.error(f"Error retrieving paginated results: {e}")
+        raise
+
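Worked example of the pagination arithmetic above, added by the editor: with page_size=20 and page=3 the helper issues LIMIT 20 OFFSET 40, and 45 total rows yield (45 + 20 - 1) // 20 = 3 pages. The query below is illustrative:

from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_paginated_results

rows, total_pages, total_count = get_paginated_results(
    "SELECT role, content FROM rag_qa_chats", page=3, page_size=20)
print(f"{len(rows)} rows on this page, {total_pages} pages, {total_count} rows total")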
+def get_all_collections(page=1, page_size=20):
+    try:
+        query = "SELECT name FROM rag_qa_keyword_collections"
+        results, total_pages, total_count = get_paginated_results(query, page=page, page_size=page_size)
+        collections = [row[0] for row in results]
+        logger.info(f"Retrieved {len(collections)} keyword collections (page {page} of {total_pages})")
+        return collections, total_pages, total_count
+    except Exception as e:
+        logger.error(f"Error getting collections: {e}")
+        raise
+
+
+def search_conversations_by_keywords(keywords, page=1, page_size=20):
+    try:
+        placeholders = ','.join(['?' for _ in keywords])
+        query = f'''
+        SELECT DISTINCT cm.conversation_id, cm.title
+        FROM conversation_metadata cm
+        JOIN rag_qa_conversation_keywords ck ON cm.conversation_id = ck.conversation_id
+        JOIN rag_qa_keywords k ON ck.keyword_id = k.id
+        WHERE k.keyword IN ({placeholders})
+        '''
+        results, total_pages, total_count = get_paginated_results(query, keywords, page, page_size)
+        logger.info(
+            f"Found {total_count} conversations matching keywords: {', '.join(keywords)} (page {page} of {total_pages})")
+        return results, total_pages, total_count
+    except Exception as e:
+        logger.error(f"Error searching conversations by keywords {keywords}: {e}")
+        raise
+
+
+def load_chat_history(conversation_id, page=1, page_size=50):
+    try:
+        query = "SELECT role, content FROM rag_qa_chats WHERE conversation_id = ? ORDER BY timestamp"
+        results, total_pages, total_count = get_paginated_results(query, (conversation_id,), page, page_size)
+        # Each message becomes a (user, bot) tuple with the unused slot left as None,
+        # presumably matching the paired format the chat UI expects
+        history = [(msg[1] if msg[0] == 'human' else None, msg[1] if msg[0] == 'ai' else None) for msg in results]
+        logger.info(
+            f"Loaded {len(history)} messages for conversation '{conversation_id}' (page {page} of {total_pages})")
+        return history, total_pages, total_count
+    except Exception as e:
+        logger.error(f"Error loading chat history for conversation '{conversation_id}': {e}")
+        raise
+
+#
+# End of RAG_QA_Chat_DB.py
+####################################################################################################
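To tie the module together, here is an editor-added end-to-end sketch: start a conversation, record both sides of an exchange, tag it, and page the history back out. All message text is invented for illustration:

from App_Function_Libraries.DB.RAG_QA_Chat_DB import (
    start_new_conversation, save_message, add_keywords_to_conversation, load_chat_history)

conv_id = start_new_conversation(title="Demo chat")
save_message(conv_id, "human", "What does this module store?")
save_message(conv_id, "ai", "Chat messages, notes, and keyword links.")
add_keywords_to_conversation(conv_id, ["demo"])
history, total_pages, total_count = load_chat_history(conv_id)
print(history)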
App_Function_Libraries/DB/SQLite_DB.py
CHANGED
@@ -1000,7 +1000,6 @@ def add_media_version(conn, media_id: int, prompt: str, summary: str) -> None:
 
 
 # Function to search the database with advanced options, including keyword search and full-text search
-
 def sqlite_search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10, connection=None):
     if page < 1:
         raise ValueError("Page number must be 1 or greater.")
@@ -1053,6 +1052,7 @@ def sqlite_search_db(search_query: str, search_fields: List[str], keywords: str,
     with db.get_connection() as conn:
         return execute_query(conn)
 
+
 # Gradio function to handle user input and display results with pagination, with better feedback
 def search_and_display(search_query, search_fields, keywords, page):
     results = sqlite_search_db(search_query, search_fields, keywords, page)
App_Function_Libraries/Gradio_Related.py
CHANGED
@@ -1,423 +1,423 @@
-# Gradio_Related.py
-#########################################
-# Gradio UI Functions Library
-# I fucking hate Gradio.
-#
-#########################################
-#
-# Built-In Imports
-import logging
-import os
-import webbrowser
-
-#
-# Import 3rd-Party Libraries
-import gradio as gr
-#
-# Local Imports
-from App_Function_Libraries.DB.DB_Manager import get_db_config
-from App_Function_Libraries.Gradio_UI.Arxiv_tab import create_arxiv_tab
-from App_Function_Libraries.Gradio_UI.Audio_ingestion_tab import create_audio_processing_tab
-from App_Function_Libraries.Gradio_UI.Book_Ingestion_tab import create_import_book_tab
-
-
-
-
-from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_management_tab, \
-    create_chat_interface_four, create_chat_interface_multi_api, create_chat_interface_stacked, create_chat_interface
-from App_Function_Libraries.Gradio_UI.Config_tab import create_config_editor_tab
-from App_Function_Libraries.Gradio_UI.Explain_summarize_tab import create_summarize_explain_tab
-from App_Function_Libraries.Gradio_UI.Export_Functionality import create_export_tab
-from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \
-    create_restore_backup_tab
-from App_Function_Libraries.Gradio_UI.Import_Functionality import create_import_single_prompt_tab, \
-    create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab
-from App_Function_Libraries.Gradio_UI.Introduction_tab import create_introduction_tab
-from App_Function_Libraries.Gradio_UI.Keywords import create_view_keywords_tab, create_add_keyword_tab, \
-    create_delete_keyword_tab, create_export_keywords_tab
-
-from App_Function_Libraries.Gradio_UI.Llamafile_tab import create_chat_with_llamafile_tab
-#from App_Function_Libraries.Gradio_UI.MMLU_Pro_tab import create_mmlu_pro_tab
-from App_Function_Libraries.Gradio_UI.Media_edit import create_prompt_clone_tab, create_prompt_edit_tab, \
-    create_media_edit_and_clone_tab, create_media_edit_tab
-from App_Function_Libraries.Gradio_UI.Media_wiki_tab import create_mediawiki_import_tab, create_mediawiki_config_tab
-from App_Function_Libraries.Gradio_UI.PDF_ingestion_tab import create_pdf_ingestion_tab, create_pdf_ingestion_test_tab
-from App_Function_Libraries.Gradio_UI.Plaintext_tab_import import create_plain_text_import_tab
-from App_Function_Libraries.Gradio_UI.Podcast_tab import create_podcast_tab
-from App_Function_Libraries.Gradio_UI.Prompt_Suggestion_tab import create_prompt_suggestion_tab
-from App_Function_Libraries.Gradio_UI.RAG_QA_Chat_tab import create_rag_qa_chat_tab
-from App_Function_Libraries.Gradio_UI.Re_summarize_tab import create_resummary_tab
-from App_Function_Libraries.Gradio_UI.Search_Tab import create_prompt_search_tab, \
-    create_search_summaries_tab, create_search_tab
-from App_Function_Libraries.Gradio_UI.RAG_Chat_tab import create_rag_tab
-from App_Function_Libraries.Gradio_UI.Embeddings_tab import create_embeddings_tab, create_view_embeddings_tab, \
-    create_purge_embeddings_tab
-from App_Function_Libraries.Gradio_UI.Trash import create_view_trash_tab, create_empty_trash_tab, \
-    create_delete_trash_tab, create_search_and_mark_trash_tab
-from App_Function_Libraries.Gradio_UI.Utilities import create_utilities_yt_timestamp_tab, create_utilities_yt_audio_tab, \
-    create_utilities_yt_video_tab
-from App_Function_Libraries.Gradio_UI.Video_transcription_tab import create_video_transcription_tab
-from App_Function_Libraries.Gradio_UI.View_tab import create_manage_items_tab
-from App_Function_Libraries.Gradio_UI.Website_scraping_tab import create_website_scraping_tab
-from App_Function_Libraries.Gradio_UI.Chat_Workflows import chat_workflows_tab
-from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_prompt_view_tab, \
-    create_view_all_with_versions_tab, create_viewing_tab
-#
-# Gradio UI Imports
-from App_Function_Libraries.Gradio_UI.Evaluations_Benchmarks_tab import create_geval_tab, create_infinite_bench_tab
-#from App_Function_Libraries.Local_LLM.Local_LLM_huggingface import create_huggingface_tab
-from App_Function_Libraries.Local_LLM.Local_LLM_ollama import create_ollama_tab
-from App_Function_Libraries.Gradio_UI.RAG_QA_Chat_Notes import create_rag_qa_chat_notes_tab
-
-#
-#######################################################################################################################
-# Function Definitions
-#
-
-
-# Disable Gradio Analytics
-os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
-
-
-custom_prompt_input = None
-server_mode = False
-share_public = False
-custom_prompt_summarize_bulleted_notes = ("""
-    <s>You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST]
-    **Bulleted Note Creation Guidelines**
-
-**Headings**:
-- Based on referenced topics, not categories like quotes or terms
-- Surrounded by **bold** formatting
-- Not listed as bullet points
-- No space between headings and list items underneath
-
-**Emphasis**:
-- **Important terms** set in bold font
-- **Text ending in a colon**: also bolded
-
-**Review**:
-- Ensure adherence to specified format
-- Do not reference these instructions in your response.</s>[INST] {{ .Prompt }} [/INST]
-""")
-#
-# End of globals
-#######################################################################################################################
-#
-# Start of Video/Audio Transcription and Summarization Functions
-#
-# Functions:
-# FIXME
-#
-#
-################################################################################################################
-# Functions for Re-Summarization
-#
-# Functions:
-# FIXME
-# End of Re-Summarization Functions
-#
-############################################################################################################################################################################################################################
-#
-# Explain/Summarize This Tab
-#
-# Functions:
-# FIXME
-#
-#
-############################################################################################################################################################################################################################
-#
-# Transcript Comparison Tab
-#
-# Functions:
-# FIXME
-#
-#
-###########################################################################################################################################################################################################################
-#
-# Search Tab
-#
-# Functions:
-# FIXME
-#
-# End of Search Tab Functions
-#
-##############################################################################################################################################################################################################################
-#
-# Llamafile Tab
-#
-# Functions:
-# FIXME
-#
-# End of Llamafile Tab Functions
-##############################################################################################################################################################################################################################
-#
-# Chat Interface Tab Functions
-#
-# Functions:
-# FIXME
-#
-#
-# End of Chat Interface Tab Functions
-################################################################################################################################################################################################################################
-#
-# Media Edit Tab Functions
-# Functions:
-# Fixme
-# create_media_edit_tab():
-##### Trash Tab
-# FIXME
-# Functions:
-#
-# End of Media Edit Tab Functions
-################################################################################################################
-#
-# Import Items Tab Functions
-#
-# Functions:
-#FIXME
-# End of Import Items Tab Functions
-################################################################################################################
-#
-# Export Items Tab Functions
-#
-# Functions:
-# FIXME
-#
-#
-# End of Export Items Tab Functions
-################################################################################################################
-#
-# Keyword Management Tab Functions
-#
-# Functions:
-# create_view_keywords_tab():
-# FIXME
-#
-# End of Keyword Management Tab Functions
-################################################################################################################
-#
-# Document Editing Tab Functions
-#
-# Functions:
-# #FIXME
-#
-#
-################################################################################################################
-#
-# Utilities Tab Functions
-# Functions:
-# create_utilities_yt_video_tab():
-# #FIXME
-
-#
-# End of Utilities Tab Functions
-################################################################################################################
-
-# FIXME - Prompt sample box
-#
-# # Sample data
-# prompts_category_1 = [
-#     "What are the key points discussed in the video?",
-#     "Summarize the main arguments made by the speaker.",
-#     "Describe the conclusions of the study presented."
-# ]
-#
-# prompts_category_2 = [
-#     "How does the proposed solution address the problem?",
-#     "What are the implications of the findings?",
-#     "Can you explain the theory behind the observed phenomenon?"
-# ]
-#
-# all_prompts2 = prompts_category_1 + prompts_category_2
-
-
-def launch_ui(share_public=None, server_mode=False):
-    webbrowser.open_new_tab('http://127.0.0.1:7860/?__theme=dark')
-    share=share_public
-    css = """
-    .result-box {
-        margin-bottom: 20px;
-        border: 1px solid #ddd;
-        padding: 10px;
-    }
-    .result-box.error {
-        border-color: #ff0000;
-        background-color: #ffeeee;
-    }
-    .transcription, .summary {
-        max-height: 800px;
-        overflow-y: auto;
-        border: 1px solid #eee;
-        padding: 10px;
-        margin-top: 10px;
-    }
-    """
-
-    with gr.Blocks(theme='bethecloud/storj_theme',css=css) as iface:
-        gr.HTML(
-            """
-            <script>
-            document.addEventListener('DOMContentLoaded', (event) => {
-                document.body.classList.add('dark');
-                document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary)';
-            });
-            </script>
-            """
-        )
-        db_config = get_db_config()
-        db_type = db_config['type']
-        gr.Markdown(f"# tl/dw: Your LLM-powered Research Multi-tool")
-        gr.Markdown(f"(Using {db_type.capitalize()} Database)")
-        with gr.Tabs():
-            with gr.TabItem("Transcription / Summarization / Ingestion", id="ingestion-grouping"):
-                with gr.Tabs():
-                    create_video_transcription_tab()
-                    create_audio_processing_tab()
-                    create_podcast_tab()
-                    create_import_book_tab()
-                    create_plain_text_import_tab()
-                    create_website_scraping_tab()
-                    create_pdf_ingestion_tab()
-                    create_pdf_ingestion_test_tab()
-                    create_resummary_tab()
-                    create_summarize_explain_tab()
-                    create_live_recording_tab()
-                    create_arxiv_tab()
-
-            with gr.TabItem("Text Search", id="text search"):
-                create_search_tab()
-                create_search_summaries_tab()
-
-            with gr.TabItem("RAG Chat+Notes", id="RAG Chat Notes group"):
-                create_rag_qa_chat_notes_tab()
|
294 |
-
|
295 |
-
with gr.TabItem("RAG Search", id="RAG Search grou"):
|
296 |
-
create_rag_tab()
|
297 |
-
create_rag_qa_chat_tab()
|
298 |
-
|
299 |
-
with gr.TabItem("Chat with an LLM", id="LLM Chat group"):
|
300 |
-
create_chat_interface()
|
301 |
-
create_chat_interface_stacked()
|
302 |
-
create_chat_interface_multi_api()
|
303 |
-
create_chat_interface_four()
|
304 |
-
create_chat_with_llamafile_tab()
|
305 |
-
create_chat_management_tab()
|
306 |
-
chat_workflows_tab()
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
with gr.TabItem("View DB Items", id="view db items group"):
|
321 |
-
# This one works
|
322 |
-
create_view_all_with_versions_tab()
|
323 |
-
# This one is WIP
|
324 |
-
create_viewing_tab()
|
325 |
-
create_prompt_view_tab()
|
326 |
-
|
327 |
-
|
328 |
-
with gr.TabItem("Prompts", id='view prompts group'):
|
329 |
-
create_prompt_view_tab()
|
330 |
-
create_prompt_search_tab()
|
331 |
-
create_prompt_edit_tab()
|
332 |
-
create_prompt_clone_tab()
|
333 |
-
create_prompt_suggestion_tab()
|
334 |
-
|
335 |
-
|
336 |
-
with gr.TabItem("Manage / Edit Existing Items", id="manage group"):
|
337 |
-
create_media_edit_tab()
|
338 |
-
create_manage_items_tab()
|
339 |
-
create_media_edit_and_clone_tab()
|
340 |
-
# FIXME
|
341 |
-
#create_compare_transcripts_tab()
|
342 |
-
|
343 |
-
|
344 |
-
with gr.TabItem("Embeddings Management", id="embeddings group"):
|
345 |
-
create_embeddings_tab()
|
346 |
-
create_view_embeddings_tab()
|
347 |
-
create_purge_embeddings_tab()
|
348 |
-
|
349 |
-
with gr.TabItem("Writing Tools", id="writing_tools group"):
|
350 |
-
from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab
|
351 |
-
create_document_feedback_tab()
|
352 |
-
from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab
|
353 |
-
create_grammar_style_check_tab()
|
354 |
-
from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab
|
355 |
-
create_tone_adjustment_tab()
|
356 |
-
from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab
|
357 |
-
create_creative_writing_tab()
|
358 |
-
from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab
|
359 |
-
create_mikupad_tab()
|
360 |
-
|
361 |
-
|
362 |
-
with gr.TabItem("Keywords", id="keywords group"):
|
363 |
-
create_view_keywords_tab()
|
364 |
-
create_add_keyword_tab()
|
365 |
-
create_delete_keyword_tab()
|
366 |
-
create_export_keywords_tab()
|
367 |
-
|
368 |
-
with gr.TabItem("Import", id="import group"):
|
369 |
-
create_import_item_tab()
|
370 |
-
create_import_obsidian_vault_tab()
|
371 |
-
create_import_single_prompt_tab()
|
372 |
-
create_import_multiple_prompts_tab()
|
373 |
-
create_mediawiki_import_tab()
|
374 |
-
create_mediawiki_config_tab()
|
375 |
-
|
376 |
-
with gr.TabItem("Export", id="export group"):
|
377 |
-
create_export_tab()
|
378 |
-
|
379 |
-
with gr.TabItem("Backup Management", id="backup group"):
|
380 |
-
create_backup_tab()
|
381 |
-
create_view_backups_tab()
|
382 |
-
create_restore_backup_tab()
|
383 |
-
|
384 |
-
with gr.TabItem("Utilities", id="util group"):
|
385 |
-
create_utilities_yt_video_tab()
|
386 |
-
create_utilities_yt_audio_tab()
|
387 |
-
create_utilities_yt_timestamp_tab()
|
388 |
-
|
389 |
-
with gr.TabItem("Local LLM", id="local llm group"):
|
390 |
-
create_chat_with_llamafile_tab()
|
391 |
-
create_ollama_tab()
|
392 |
-
#create_huggingface_tab()
|
393 |
-
|
394 |
-
with gr.TabItem("Trashcan", id="trashcan group"):
|
395 |
-
create_search_and_mark_trash_tab()
|
396 |
-
create_view_trash_tab()
|
397 |
-
create_delete_trash_tab()
|
398 |
-
create_empty_trash_tab()
|
399 |
-
|
400 |
-
with gr.TabItem("Evaluations", id="eval"):
|
401 |
-
create_geval_tab()
|
402 |
-
create_infinite_bench_tab()
|
403 |
-
# FIXME
|
404 |
-
#create_mmlu_pro_tab()
|
405 |
-
|
406 |
-
with gr.TabItem("Introduction/Help", id="introduction group"):
|
407 |
-
create_introduction_tab()
|
408 |
-
|
409 |
-
with gr.TabItem("Config Editor", id="config group"):
|
410 |
-
create_config_editor_tab()
|
411 |
-
|
412 |
-
# Launch the interface
|
413 |
-
server_port_variable = 7860
|
414 |
-
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
415 |
-
if share==True:
|
416 |
-
iface.launch(share=True)
|
417 |
-
elif server_mode and not share_public:
|
418 |
-
iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
|
419 |
-
else:
|
420 |
-
try:
|
421 |
-
iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
|
422 |
-
except Exception as e:
|
423 |
-
logging.error(f"Error launching interface: {str(e)}")
|
|
|
+# Gradio_Related.py
+#########################################
+# Gradio UI Functions Library
+# I fucking hate Gradio.
+#
+#########################################
+#
+# Built-In Imports
+import logging
+import os
+import webbrowser
+
+#
+# Import 3rd-Party Libraries
+import gradio as gr
+#
+# Local Imports
+from App_Function_Libraries.DB.DB_Manager import get_db_config
+from App_Function_Libraries.Gradio_UI.Arxiv_tab import create_arxiv_tab
+from App_Function_Libraries.Gradio_UI.Audio_ingestion_tab import create_audio_processing_tab
+from App_Function_Libraries.Gradio_UI.Book_Ingestion_tab import create_import_book_tab
+from App_Function_Libraries.Gradio_UI.Character_Chat_tab import create_character_card_interaction_tab, create_character_chat_mgmt_tab, create_custom_character_card_tab, \
+    create_character_card_validation_tab, create_export_characters_tab
+from App_Function_Libraries.Gradio_UI.Character_interaction_tab import create_narrator_controlled_conversation_tab, \
+    create_multiple_character_chat_tab
+from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_management_tab, \
+    create_chat_interface_four, create_chat_interface_multi_api, create_chat_interface_stacked, create_chat_interface
+from App_Function_Libraries.Gradio_UI.Config_tab import create_config_editor_tab
+from App_Function_Libraries.Gradio_UI.Explain_summarize_tab import create_summarize_explain_tab
+from App_Function_Libraries.Gradio_UI.Export_Functionality import create_export_tab
+from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \
+    create_restore_backup_tab
+from App_Function_Libraries.Gradio_UI.Import_Functionality import create_import_single_prompt_tab, \
+    create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab
+from App_Function_Libraries.Gradio_UI.Introduction_tab import create_introduction_tab
+from App_Function_Libraries.Gradio_UI.Keywords import create_view_keywords_tab, create_add_keyword_tab, \
+    create_delete_keyword_tab, create_export_keywords_tab
+from App_Function_Libraries.Gradio_UI.Live_Recording import create_live_recording_tab
+from App_Function_Libraries.Gradio_UI.Llamafile_tab import create_chat_with_llamafile_tab
+#from App_Function_Libraries.Gradio_UI.MMLU_Pro_tab import create_mmlu_pro_tab
+from App_Function_Libraries.Gradio_UI.Media_edit import create_prompt_clone_tab, create_prompt_edit_tab, \
+    create_media_edit_and_clone_tab, create_media_edit_tab
+from App_Function_Libraries.Gradio_UI.Media_wiki_tab import create_mediawiki_import_tab, create_mediawiki_config_tab
+from App_Function_Libraries.Gradio_UI.PDF_ingestion_tab import create_pdf_ingestion_tab, create_pdf_ingestion_test_tab
+from App_Function_Libraries.Gradio_UI.Plaintext_tab_import import create_plain_text_import_tab
+from App_Function_Libraries.Gradio_UI.Podcast_tab import create_podcast_tab
+from App_Function_Libraries.Gradio_UI.Prompt_Suggestion_tab import create_prompt_suggestion_tab
+from App_Function_Libraries.Gradio_UI.RAG_QA_Chat_tab import create_rag_qa_chat_tab
+from App_Function_Libraries.Gradio_UI.Re_summarize_tab import create_resummary_tab
+from App_Function_Libraries.Gradio_UI.Search_Tab import create_prompt_search_tab, \
+    create_search_summaries_tab, create_search_tab
+from App_Function_Libraries.Gradio_UI.RAG_Chat_tab import create_rag_tab
+from App_Function_Libraries.Gradio_UI.Embeddings_tab import create_embeddings_tab, create_view_embeddings_tab, \
+    create_purge_embeddings_tab
+from App_Function_Libraries.Gradio_UI.Trash import create_view_trash_tab, create_empty_trash_tab, \
+    create_delete_trash_tab, create_search_and_mark_trash_tab
+from App_Function_Libraries.Gradio_UI.Utilities import create_utilities_yt_timestamp_tab, create_utilities_yt_audio_tab, \
+    create_utilities_yt_video_tab
+from App_Function_Libraries.Gradio_UI.Video_transcription_tab import create_video_transcription_tab
+from App_Function_Libraries.Gradio_UI.View_tab import create_manage_items_tab
+from App_Function_Libraries.Gradio_UI.Website_scraping_tab import create_website_scraping_tab
+from App_Function_Libraries.Gradio_UI.Chat_Workflows import chat_workflows_tab
+from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_prompt_view_tab, \
+    create_view_all_with_versions_tab, create_viewing_tab
+#
+# Gradio UI Imports
+from App_Function_Libraries.Gradio_UI.Evaluations_Benchmarks_tab import create_geval_tab, create_infinite_bench_tab
+#from App_Function_Libraries.Local_LLM.Local_LLM_huggingface import create_huggingface_tab
+from App_Function_Libraries.Local_LLM.Local_LLM_ollama import create_ollama_tab
+from App_Function_Libraries.Gradio_UI.RAG_QA_Chat_Notes import create_rag_qa_chat_notes_tab
+
+#
+#######################################################################################################################
+# Function Definitions
+#
+
+
+# Disable Gradio Analytics
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
+
+custom_prompt_input = None
+server_mode = False
+share_public = False
+custom_prompt_summarize_bulleted_notes = ("""
+    <s>You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST]
+**Bulleted Note Creation Guidelines**
+
+**Headings**:
+- Based on referenced topics, not categories like quotes or terms
+- Surrounded by **bold** formatting
+- Not listed as bullet points
+- No space between headings and list items underneath
+
+**Emphasis**:
+- **Important terms** set in bold font
+- **Text ending in a colon**: also bolded
+
+**Review**:
+- Ensure adherence to specified format
+- Do not reference these instructions in your response.</s>[INST] {{ .Prompt }} [/INST]
+""")
+#
+# End of globals
+#######################################################################################################################
+#
+# Start of Video/Audio Transcription and Summarization Functions
+#
+# Functions:
+# FIXME
+#
+#
+################################################################################################################
+# Functions for Re-Summarization
+#
+# Functions:
+# FIXME
+# End of Re-Summarization Functions
+#
+############################################################################################################################################################################################################################
+#
+# Explain/Summarize This Tab
+#
+# Functions:
+# FIXME
+#
+#
+############################################################################################################################################################################################################################
+#
+# Transcript Comparison Tab
+#
+# Functions:
+# FIXME
+#
+#
+###########################################################################################################################################################################################################################
+#
+# Search Tab
+#
+# Functions:
+# FIXME
+#
+# End of Search Tab Functions
+#
+##############################################################################################################################################################################################################################
+#
+# Llamafile Tab
+#
+# Functions:
+# FIXME
+#
+# End of Llamafile Tab Functions
+##############################################################################################################################################################################################################################
+#
+# Chat Interface Tab Functions
+#
+# Functions:
+# FIXME
+#
+#
+# End of Chat Interface Tab Functions
+################################################################################################################################################################################################################################
+#
+# Media Edit Tab Functions
+# Functions:
+# Fixme
+# create_media_edit_tab():
+##### Trash Tab
+# FIXME
+# Functions:
+#
+# End of Media Edit Tab Functions
+################################################################################################################
+#
+# Import Items Tab Functions
+#
+# Functions:
+#FIXME
+# End of Import Items Tab Functions
+################################################################################################################
+#
+# Export Items Tab Functions
+#
+# Functions:
+# FIXME
+#
+#
+# End of Export Items Tab Functions
+################################################################################################################
+#
+# Keyword Management Tab Functions
+#
+# Functions:
+# create_view_keywords_tab():
+# FIXME
+#
+# End of Keyword Management Tab Functions
+################################################################################################################
+#
+# Document Editing Tab Functions
+#
+# Functions:
+# #FIXME
+#
+#
+################################################################################################################
+#
+# Utilities Tab Functions
+# Functions:
+# create_utilities_yt_video_tab():
+# #FIXME
+
+#
+# End of Utilities Tab Functions
+################################################################################################################
+
+# FIXME - Prompt sample box
+#
+# # Sample data
+# prompts_category_1 = [
+#     "What are the key points discussed in the video?",
+#     "Summarize the main arguments made by the speaker.",
+#     "Describe the conclusions of the study presented."
+# ]
+#
+# prompts_category_2 = [
+#     "How does the proposed solution address the problem?",
+#     "What are the implications of the findings?",
+#     "Can you explain the theory behind the observed phenomenon?"
+# ]
+#
+# all_prompts2 = prompts_category_1 + prompts_category_2
+
+
+def launch_ui(share_public=None, server_mode=False):
+    webbrowser.open_new_tab('http://127.0.0.1:7860/?__theme=dark')
+    share=share_public
+    css = """
+    .result-box {
+        margin-bottom: 20px;
+        border: 1px solid #ddd;
+        padding: 10px;
+    }
+    .result-box.error {
+        border-color: #ff0000;
+        background-color: #ffeeee;
+    }
+    .transcription, .summary {
+        max-height: 800px;
+        overflow-y: auto;
+        border: 1px solid #eee;
+        padding: 10px;
+        margin-top: 10px;
+    }
+    """
+
+    with gr.Blocks(theme='bethecloud/storj_theme',css=css) as iface:
+        gr.HTML(
+            """
+            <script>
+            document.addEventListener('DOMContentLoaded', (event) => {
+                document.body.classList.add('dark');
+                document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary)';
+            });
+            </script>
+            """
+        )
+        db_config = get_db_config()
+        db_type = db_config['type']
+        gr.Markdown(f"# tl/dw: Your LLM-powered Research Multi-tool")
+        gr.Markdown(f"(Using {db_type.capitalize()} Database)")
+        with gr.Tabs():
+            with gr.TabItem("Transcription / Summarization / Ingestion", id="ingestion-grouping"):
+                with gr.Tabs():
+                    create_video_transcription_tab()
+                    create_audio_processing_tab()
+                    create_podcast_tab()
+                    create_import_book_tab()
+                    create_plain_text_import_tab()
+                    create_website_scraping_tab()
+                    create_pdf_ingestion_tab()
+                    create_pdf_ingestion_test_tab()
+                    create_resummary_tab()
+                    create_summarize_explain_tab()
+                    create_live_recording_tab()
+                    create_arxiv_tab()
+
+            with gr.TabItem("Text Search", id="text search"):
+                create_search_tab()
+                create_search_summaries_tab()
+
+            with gr.TabItem("RAG Chat+Notes", id="RAG Chat Notes group"):
+                create_rag_qa_chat_notes_tab()
+
+            with gr.TabItem("RAG Search", id="RAG Search grou"):
+                create_rag_tab()
+                create_rag_qa_chat_tab()
+
+            with gr.TabItem("Chat with an LLM", id="LLM Chat group"):
+                create_chat_interface()
+                create_chat_interface_stacked()
+                create_chat_interface_multi_api()
+                create_chat_interface_four()
+                create_chat_with_llamafile_tab()
+                create_chat_management_tab()
+                chat_workflows_tab()
+
+
+            with gr.TabItem("Character Chat", id="character chat group"):
+                with gr.Tabs():
+                    create_character_card_interaction_tab()
+                    create_character_chat_mgmt_tab()
+                    create_custom_character_card_tab()
+                    create_character_card_validation_tab()
+                    create_multiple_character_chat_tab()
+                    create_narrator_controlled_conversation_tab()
+                    create_export_characters_tab()
+
+
+            with gr.TabItem("View DB Items", id="view db items group"):
+                # This one works
+                create_view_all_with_versions_tab()
+                # This one is WIP
+                create_viewing_tab()
+                create_prompt_view_tab()
+
+
+            with gr.TabItem("Prompts", id='view prompts group'):
+                create_prompt_view_tab()
+                create_prompt_search_tab()
+                create_prompt_edit_tab()
+                create_prompt_clone_tab()
+                create_prompt_suggestion_tab()
+
+
+            with gr.TabItem("Manage / Edit Existing Items", id="manage group"):
+                create_media_edit_tab()
+                create_manage_items_tab()
+                create_media_edit_and_clone_tab()
+                # FIXME
+                #create_compare_transcripts_tab()
+
+
+            with gr.TabItem("Embeddings Management", id="embeddings group"):
+                create_embeddings_tab()
+                create_view_embeddings_tab()
+                create_purge_embeddings_tab()
+
+            with gr.TabItem("Writing Tools", id="writing_tools group"):
+                from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab
+                create_document_feedback_tab()
+                from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab
+                create_grammar_style_check_tab()
+                from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab
+                create_tone_adjustment_tab()
+                from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab
+                create_creative_writing_tab()
+                from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab
+                create_mikupad_tab()
+
+
+            with gr.TabItem("Keywords", id="keywords group"):
+                create_view_keywords_tab()
+                create_add_keyword_tab()
+                create_delete_keyword_tab()
+                create_export_keywords_tab()
+
+            with gr.TabItem("Import", id="import group"):
+                create_import_item_tab()
+                create_import_obsidian_vault_tab()
+                create_import_single_prompt_tab()
+                create_import_multiple_prompts_tab()
+                create_mediawiki_import_tab()
+                create_mediawiki_config_tab()
+
+            with gr.TabItem("Export", id="export group"):
+                create_export_tab()
+
+            with gr.TabItem("Backup Management", id="backup group"):
+                create_backup_tab()
+                create_view_backups_tab()
+                create_restore_backup_tab()
+
+            with gr.TabItem("Utilities", id="util group"):
+                create_utilities_yt_video_tab()
+                create_utilities_yt_audio_tab()
+                create_utilities_yt_timestamp_tab()
+
+            with gr.TabItem("Local LLM", id="local llm group"):
+                create_chat_with_llamafile_tab()
+                create_ollama_tab()
+                #create_huggingface_tab()
+
+            with gr.TabItem("Trashcan", id="trashcan group"):
+                create_search_and_mark_trash_tab()
+                create_view_trash_tab()
+                create_delete_trash_tab()
+                create_empty_trash_tab()
+
+            with gr.TabItem("Evaluations", id="eval"):
+                create_geval_tab()
+                create_infinite_bench_tab()
+                # FIXME
+                #create_mmlu_pro_tab()
+
+            with gr.TabItem("Introduction/Help", id="introduction group"):
+                create_introduction_tab()
+
+            with gr.TabItem("Config Editor", id="config group"):
+                create_config_editor_tab()
+
+    # Launch the interface
+    server_port_variable = 7860
+    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+    if share==True:
+        iface.launch(share=True)
+    elif server_mode and not share_public:
+        iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
+    else:
+        try:
+            iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
+        except Exception as e:
+            logging.error(f"Error launching interface: {str(e)}")
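
Aside from the new Character Chat group, the rewritten launch_ui() above keeps the same launch flow: open a browser tab up front, then pick launch flags from share_public/server_mode. A minimal, self-contained sketch of just that branching pattern; build_demo() here is a hypothetical stand-in for the gr.Blocks() construction done in launch_ui():

import logging
import os

import gradio as gr

os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'

def build_demo() -> gr.Blocks:
    # Hypothetical stand-in for the tab construction in launch_ui().
    with gr.Blocks() as demo:
        gr.Markdown("placeholder UI")
    return demo

def launch(share_public=None, server_mode=False, port=7860):
    demo = build_demo()
    if share_public:
        # Tunnel through Gradio's public share-link service.
        demo.launch(share=True)
    else:
        # Bind to all interfaces for local/server use; the try/except only
        # adds error logging, mirroring the fallback branch above.
        try:
            demo.launch(share=False, server_name="0.0.0.0", server_port=port)
        except Exception as e:
            logging.error(f"Error launching interface: {str(e)}")
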
App_Function_Libraries/Gradio_UI/Character_Chat_tab.py
CHANGED
@@ -12,6 +12,7 @@ import logging
 import io
 import base64
 from typing import Dict, Any, Optional, List, Tuple, Union, cast
+import zipfile
 #
 # External Imports
 from PIL import Image
@@ -186,6 +187,67 @@ def parse_v1_card(card_data: Dict[str, Any]) -> Dict[str, Any]:
 # End of Character card import functions
 ####################################################

+####################################################
+#
+# Character card export functions
+
+def export_character_as_json(character_id):
+    character = get_character_card_by_id(character_id)
+    if character:
+        # Remove the 'id' field from the character data
+        character_data = {k: v for k, v in character.items() if k != 'id'}
+
+        # Convert image to base64 if it exists
+        if 'image' in character_data and character_data['image']:
+            image_data = base64.b64decode(character_data['image'])
+            img = Image.open(io.BytesIO(image_data))
+            buffered = io.BytesIO()
+            img.save(buffered, format="PNG")
+            character_data['image'] = base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+        json_data = json.dumps(character_data, indent=2)
+        return json_data
+    return None
+
+def export_all_characters_as_zip():
+    characters = get_character_cards()
+    with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.zip') as temp_zip:
+        with zipfile.ZipFile(temp_zip, 'w') as zf:
+            for character in characters:
+                character_data = {k: v for k, v in character.items() if k != 'id'}
+
+                # Convert image to base64 if it exists
+                if 'image' in character_data and character_data['image']:
+                    image_data = base64.b64decode(character_data['image'])
+                    img = Image.open(io.BytesIO(image_data))
+                    buffered = io.BytesIO()
+                    img.save(buffered, format="PNG")
+                    character_data['image'] = base64.b64encode(buffered.getvalue()).decode('utf-8')
+                json_data = json.dumps(character_data, indent=2)
+                zf.writestr(f"{character['name']}.json", json_data)
+        return temp_zip.name
+
+def export_single_character(character_selection):
+    if not character_selection:
+        return None, "No character selected."
+
+    character_id = int(character_selection.split('(ID: ')[1].rstrip(')'))
+    json_data = export_character_as_json(character_id)
+
+    if json_data:
+        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', encoding='utf-8') as temp_file:
+            temp_file.write(json_data)
+            return temp_file.name, f"Character '{character_selection.split(' (ID:')[0]}' exported successfully."
+    else:
+        return None, f"Failed to export character '{character_selection.split(' (ID:')[0]}'."
+
+def export_all_characters():
+    zip_path = export_all_characters_as_zip()
+    return zip_path, "All characters exported successfully."
+
+#
+# End of Character card export functions
+####################################################

 ####################################################
 #
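
The export helpers added in the hunk above share one pattern: decode the stored base64 image, re-encode it as PNG, and write the result to a NamedTemporaryFile created with delete=False, so the path outlives the with-block and can be served by a gr.File output. A standalone sketch of that pattern, with a plain list standing in for get_character_cards() (field names mirror the diff; nothing here touches the real database):

import base64
import io
import json
import tempfile
import zipfile
from typing import Dict, List

from PIL import Image

def _reencode_png(b64_image: str) -> str:
    # Decode a stored base64 image and normalize it to PNG base64.
    img = Image.open(io.BytesIO(base64.b64decode(b64_image)))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")

def export_cards_as_zip(characters: List[Dict]) -> str:
    # delete=False keeps the archive on disk after the with-block, so the
    # returned path stays valid for a gr.File component to serve.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".zip") as tmp:
        with zipfile.ZipFile(tmp, "w") as zf:
            for char in characters:
                data = {k: v for k, v in char.items() if k != "id"}
                if data.get("image"):
                    data["image"] = _reencode_png(data["image"])
                zf.writestr(f"{char['name']}.json", json.dumps(data, indent=2))
        return tmp.name

# Usage with two minimal card dicts in place of database rows:
print(export_cards_as_zip([{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]))
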
@@ -1721,457 +1783,65 @@ def create_character_card_validation_tab():
|
|
1721 |
inputs=[file_upload],
|
1722 |
outputs=[validation_output]
|
1723 |
)
|
1724 |
-
# v2-not-working-on-export-def create_character_card_validation_tab():
|
1725 |
-
# with gr.TabItem("Validate and Edit Character Card"):
|
1726 |
-
# gr.Markdown("# Validate and Edit Character Card (v2)")
|
1727 |
-
# gr.Markdown("Upload a character card (PNG, WEBP, or JSON) to validate and modify it.")
|
1728 |
-
#
|
1729 |
-
# with gr.Row():
|
1730 |
-
# with gr.Column():
|
1731 |
-
# # File uploader
|
1732 |
-
# file_upload = gr.File(
|
1733 |
-
# label="Upload Character Card (PNG, WEBP, JSON)",
|
1734 |
-
# file_types=[".png", ".webp", ".json"]
|
1735 |
-
# )
|
1736 |
-
# # Validation button
|
1737 |
-
# validate_button = gr.Button("Validate and Load Character Card")
|
1738 |
-
# # Output area for validation results
|
1739 |
-
# validation_output = gr.Markdown("")
|
1740 |
-
#
|
1741 |
-
# # Input fields for character card data (duplicated from the create tab)
|
1742 |
-
# with gr.Row():
|
1743 |
-
# with gr.Column():
|
1744 |
-
# name_input = gr.Textbox(label="Name", placeholder="Enter character name")
|
1745 |
-
# description_input = gr.TextArea(label="Description", placeholder="Enter character description")
|
1746 |
-
# personality_input = gr.TextArea(label="Personality", placeholder="Enter character personality")
|
1747 |
-
# scenario_input = gr.TextArea(label="Scenario", placeholder="Enter character scenario")
|
1748 |
-
# first_mes_input = gr.TextArea(label="First Message", placeholder="Enter the first message")
|
1749 |
-
# mes_example_input = gr.TextArea(label="Example Messages", placeholder="Enter example messages")
|
1750 |
-
# creator_notes_input = gr.TextArea(label="Creator Notes", placeholder="Enter notes for the creator")
|
1751 |
-
# system_prompt_input = gr.TextArea(label="System Prompt", placeholder="Enter system prompt")
|
1752 |
-
# post_history_instructions_input = gr.TextArea(label="Post History Instructions", placeholder="Enter post history instructions")
|
1753 |
-
# alternate_greetings_input = gr.TextArea(
|
1754 |
-
# label="Alternate Greetings (one per line)",
|
1755 |
-
# placeholder="Enter alternate greetings, one per line"
|
1756 |
-
# )
|
1757 |
-
# tags_input = gr.Textbox(label="Tags", placeholder="Enter tags, separated by commas")
|
1758 |
-
# creator_input = gr.Textbox(label="Creator", placeholder="Enter creator name")
|
1759 |
-
# character_version_input = gr.Textbox(label="Character Version", placeholder="Enter character version")
|
1760 |
-
# extensions_input = gr.TextArea(
|
1761 |
-
# label="Extensions (JSON)",
|
1762 |
-
# placeholder="Enter extensions as JSON (optional)"
|
1763 |
-
# )
|
1764 |
-
# image_input = gr.Image(label="Character Image", type="pil")
|
1765 |
-
#
|
1766 |
-
# # Buttons
|
1767 |
-
# save_button = gr.Button("Save Character Card")
|
1768 |
-
# download_button = gr.Button("Download Character Card")
|
1769 |
-
# download_image_button = gr.Button("Download Character Card as Image")
|
1770 |
-
#
|
1771 |
-
# # Output status and outputs
|
1772 |
-
# save_status = gr.Markdown("")
|
1773 |
-
# download_output = gr.File(label="Download Character Card", interactive=False)
|
1774 |
-
# download_image_output = gr.File(label="Download Character Card as Image", interactive=False)
|
1775 |
-
#
|
1776 |
-
# # Callback Functions
|
1777 |
-
# def extract_json_from_image(file):
|
1778 |
-
# try:
|
1779 |
-
# image = Image.open(file.name)
|
1780 |
-
# if "chara" in image.info:
|
1781 |
-
# json_data = image.info["chara"]
|
1782 |
-
# # Decode base64 if necessary
|
1783 |
-
# try:
|
1784 |
-
# json_data = base64.b64decode(json_data).decode('utf-8')
|
1785 |
-
# except Exception:
|
1786 |
-
# pass # Assume it's already in plain text
|
1787 |
-
# return json_data
|
1788 |
-
# else:
|
1789 |
-
# return None
|
1790 |
-
# except Exception as e:
|
1791 |
-
# logging.error(f"Error extracting JSON from image: {e}")
|
1792 |
-
# return None
|
1793 |
-
#
|
1794 |
-
# def validate_v2_card(card_data):
|
1795 |
-
# """
|
1796 |
-
# Validate a character card according to the V2 specification.
|
1797 |
-
#
|
1798 |
-
# Args:
|
1799 |
-
# card_data (dict): The parsed character card data.
|
1800 |
-
#
|
1801 |
-
# Returns:
|
1802 |
-
# Tuple[bool, List[str]]: A tuple containing a boolean indicating validity and a list of validation messages.
|
1803 |
-
# """
|
1804 |
-
# validation_messages = []
|
1805 |
-
#
|
1806 |
-
# # Check top-level fields
|
1807 |
-
# if 'spec' not in card_data:
|
1808 |
-
# validation_messages.append("Missing 'spec' field.")
|
1809 |
-
# elif card_data['spec'] != 'chara_card_v2':
|
1810 |
-
# validation_messages.append(f"Invalid 'spec' value: {card_data['spec']}. Expected 'chara_card_v2'.")
|
1811 |
-
#
|
1812 |
-
# if 'spec_version' not in card_data:
|
1813 |
-
# validation_messages.append("Missing 'spec_version' field.")
|
1814 |
-
# else:
|
1815 |
-
# # Ensure 'spec_version' is '2.0' or higher
|
1816 |
-
# try:
|
1817 |
-
# spec_version = float(card_data['spec_version'])
|
1818 |
-
# if spec_version < 2.0:
|
1819 |
-
# validation_messages.append(
|
1820 |
-
# f"'spec_version' must be '2.0' or higher. Found '{card_data['spec_version']}'.")
|
1821 |
-
# except ValueError:
|
1822 |
-
# validation_messages.append(
|
1823 |
-
# f"Invalid 'spec_version' format: {card_data['spec_version']}. Must be a number as a string.")
|
1824 |
-
#
|
1825 |
-
# if 'data' not in card_data:
|
1826 |
-
# validation_messages.append("Missing 'data' field.")
|
1827 |
-
# return False, validation_messages # Cannot proceed without 'data' field
|
1828 |
-
#
|
1829 |
-
# data = card_data['data']
|
1830 |
-
#
|
1831 |
-
# # Required fields in 'data'
|
1832 |
-
# required_fields = ['name', 'description', 'personality', 'scenario', 'first_mes', 'mes_example']
|
1833 |
-
# for field in required_fields:
|
1834 |
-
# if field not in data:
|
1835 |
-
# validation_messages.append(f"Missing required field in 'data': '{field}'.")
|
1836 |
-
# elif not isinstance(data[field], str):
|
1837 |
-
# validation_messages.append(f"Field '{field}' must be a string.")
|
1838 |
-
# elif not data[field].strip():
|
1839 |
-
# validation_messages.append(f"Field '{field}' cannot be empty.")
|
1840 |
-
#
|
1841 |
-
# # Optional fields with expected types
|
1842 |
-
# optional_fields = {
|
1843 |
-
# 'creator_notes': str,
|
1844 |
-
# 'system_prompt': str,
|
1845 |
-
# 'post_history_instructions': str,
|
1846 |
-
# 'alternate_greetings': list,
|
1847 |
-
# 'tags': list,
|
1848 |
-
# 'creator': str,
|
1849 |
-
# 'character_version': str,
|
1850 |
-
# 'extensions': dict,
|
1851 |
-
# 'character_book': dict # If present, should be a dict
|
1852 |
-
# }
|
1853 |
-
#
|
1854 |
-
# for field, expected_type in optional_fields.items():
|
1855 |
-
# if field in data:
|
1856 |
-
# if not isinstance(data[field], expected_type):
|
1857 |
-
# validation_messages.append(f"Field '{field}' must be of type '{expected_type.__name__}'.")
|
1858 |
-
# elif field == 'extensions':
|
1859 |
-
# # Validate that extensions keys are properly namespaced
|
1860 |
-
# for key in data[field].keys():
|
1861 |
-
# if '/' not in key and '_' not in key:
|
1862 |
-
# validation_messages.append(
|
1863 |
-
# f"Extension key '{key}' in 'extensions' should be namespaced to prevent conflicts.")
|
1864 |
-
#
|
1865 |
-
# # If 'alternate_greetings' is present, check that it's a list of non-empty strings
|
1866 |
-
# if 'alternate_greetings' in data and isinstance(data['alternate_greetings'], list):
|
1867 |
-
# for idx, greeting in enumerate(data['alternate_greetings']):
|
1868 |
-
# if not isinstance(greeting, str) or not greeting.strip():
|
1869 |
-
# validation_messages.append(
|
1870 |
-
# f"Element {idx} in 'alternate_greetings' must be a non-empty string.")
|
1871 |
-
#
|
1872 |
-
# # If 'tags' is present, check that it's a list of non-empty strings
|
1873 |
-
# if 'tags' in data and isinstance(data['tags'], list):
|
1874 |
-
# for idx, tag in enumerate(data['tags']):
|
1875 |
-
# if not isinstance(tag, str) or not tag.strip():
|
1876 |
-
# validation_messages.append(f"Element {idx} in 'tags' must be a non-empty string.")
|
1877 |
-
#
|
1878 |
-
# # Validate 'extensions' field
|
1879 |
-
# if 'extensions' in data and not isinstance(data['extensions'], dict):
|
1880 |
-
# validation_messages.append("Field 'extensions' must be a dictionary.")
|
1881 |
-
#
|
1882 |
-
# # Validate 'character_book' if present
|
1883 |
-
# # (Assuming you have a validate_character_book function)
|
1884 |
-
# # if 'character_book' in data:
|
1885 |
-
# # is_valid_book, book_messages = validate_character_book(data['character_book'])
|
1886 |
-
# # if not is_valid_book:
|
1887 |
-
# # validation_messages.extend(book_messages)
|
1888 |
-
#
|
1889 |
-
# is_valid = len(validation_messages) == 0
|
1890 |
-
# return is_valid, validation_messages
|
1891 |
-
#
|
1892 |
-
# # Include the save_character_card, download_character_card, and download_character_card_as_image functions
|
1893 |
-
# def save_character_card(
|
1894 |
-
# name, description, personality, scenario, first_mes, mes_example,
|
1895 |
-
# creator_notes, system_prompt, post_history_instructions,
|
1896 |
-
# alternate_greetings_str, tags_str, creator, character_version,
|
1897 |
-
# extensions_str, image
|
1898 |
-
# ):
|
1899 |
-
# # Build the character card
|
1900 |
-
# character_card = build_character_card(
|
1901 |
-
# name, description, personality, scenario, first_mes, mes_example,
|
1902 |
-
# creator_notes, system_prompt, post_history_instructions,
|
1903 |
-
# alternate_greetings_str, tags_str, creator, character_version,
|
1904 |
-
# extensions_str
|
1905 |
-
# )
|
1906 |
-
#
|
1907 |
-
# # Validate the character card
|
1908 |
-
# is_valid, validation_messages = validate_v2_card(character_card)
|
1909 |
-
# if not is_valid:
|
1910 |
-
# # Return validation errors
|
1911 |
-
# validation_output = "Character card validation failed:\n"
|
1912 |
-
# validation_output += "\n".join(validation_messages)
|
1913 |
-
# return validation_output
|
1914 |
-
#
|
1915 |
-
# # If image is provided, encode it to base64
|
1916 |
-
# if image:
|
1917 |
-
# img_byte_arr = io.BytesIO()
|
1918 |
-
# image.save(img_byte_arr, format='PNG')
|
1919 |
-
# character_card['data']['image'] = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
1920 |
-
#
|
1921 |
-
# # Save character card to database
|
1922 |
-
# character_id = add_character_card(character_card['data'])
|
1923 |
-
# if character_id:
|
1924 |
-
# return f"Character card '{name}' saved successfully."
|
1925 |
-
# else:
|
1926 |
-
# return f"Failed to save character card '{name}'. It may already exist."
|
1927 |
-
#
|
1928 |
-
# def download_character_card(
|
1929 |
-
# name, description, personality, scenario, first_mes, mes_example,
|
1930 |
-
# creator_notes, system_prompt, post_history_instructions,
|
1931 |
-
# alternate_greetings_str, tags_str, creator, character_version,
|
1932 |
-
# extensions_str, image
|
1933 |
-
# ):
|
1934 |
-
# # Build the character card
|
1935 |
-
# character_card = build_character_card(
|
1936 |
-
# name, description, personality, scenario, first_mes, mes_example,
|
1937 |
-
# creator_notes, system_prompt, post_history_instructions,
|
1938 |
-
# alternate_greetings_str, tags_str, creator, character_version,
|
1939 |
-
# extensions_str
|
1940 |
-
# )
|
1941 |
-
#
|
1942 |
-
# # Validate the character card
|
1943 |
-
# is_valid, validation_messages = validate_v2_card(character_card)
|
1944 |
-
# if not is_valid:
|
1945 |
-
# # Return validation errors
|
1946 |
-
# validation_output = "Character card validation failed:\n"
|
1947 |
-
# validation_output += "\n".join(validation_messages)
|
1948 |
-
# return gr.update(value=None), validation_output # Return None for the file output
|
1949 |
-
#
|
1950 |
-
# # If image is provided, include it as base64
|
1951 |
-
# if image:
|
1952 |
-
# img_byte_arr = io.BytesIO()
|
1953 |
-
# image.save(img_byte_arr, format='PNG')
|
1954 |
-
# character_card['data']['image'] = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
1955 |
-
#
|
1956 |
-
# # Convert to JSON string
|
1957 |
-
# json_str = json.dumps(character_card, indent=2)
|
1958 |
-
#
|
1959 |
-
# # Write the JSON to a temporary file
|
1960 |
-
# with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', encoding='utf-8') as temp_file:
|
1961 |
-
# temp_file.write(json_str)
|
1962 |
-
# temp_file_path = temp_file.name
|
1963 |
-
#
|
1964 |
-
# # Return the file path and clear validation output
|
1965 |
-
# return temp_file_path, ""
|
1966 |
-
#
|
1967 |
-
# def download_character_card_as_image(
|
1968 |
-
# name, description, personality, scenario, first_mes, mes_example,
|
1969 |
-
# creator_notes, system_prompt, post_history_instructions,
|
1970 |
-
# alternate_greetings_str, tags_str, creator, character_version,
|
1971 |
-
# extensions_str, image
|
1972 |
-
# ):
|
1973 |
-
# # Build the character card
|
1974 |
-
# character_card = build_character_card(
|
1975 |
-
# name, description, personality, scenario, first_mes, mes_example,
|
1976 |
-
# creator_notes, system_prompt, post_history_instructions,
|
1977 |
-
# alternate_greetings_str, tags_str, creator, character_version,
|
1978 |
-
# extensions_str
|
1979 |
-
# )
|
1980 |
-
#
|
1981 |
-
# # Validate the character card
|
1982 |
-
# is_valid, validation_messages = validate_v2_card(character_card)
|
1983 |
-
# if not is_valid:
|
1984 |
-
# # Return validation errors
|
1985 |
-
# validation_output = "Character card validation failed:\n"
|
1986 |
-
# validation_output += "\n".join(validation_messages)
|
1987 |
-
# return gr.update(value=None), validation_output # Return None for the file output
|
1988 |
-
#
|
1989 |
-
# # Convert the character card JSON to a string
|
1990 |
-
# json_str = json.dumps(character_card, indent=2)
|
1991 |
-
#
|
1992 |
-
# # Encode the JSON string to base64
|
1993 |
-
# chara_content = base64.b64encode(json_str.encode('utf-8')).decode('utf-8')
|
1994 |
-
#
|
1995 |
-
# # Create PNGInfo object to hold metadata
|
1996 |
-
# png_info = PngInfo()
|
1997 |
-
# png_info.add_text('chara', chara_content)
|
1998 |
-
#
|
1999 |
-
# # If image is provided, use it; otherwise, create a blank image
|
2000 |
-
# if image:
|
2001 |
-
# img = image.copy()
|
2002 |
-
# else:
|
2003 |
-
# # Create a default blank image
|
2004 |
-
# img = Image.new('RGB', (512, 512), color='white')
|
2005 |
-
#
|
2006 |
-
# # Save the image to a temporary file with metadata
|
2007 |
-
# with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.png') as temp_file:
|
2008 |
-
# img.save(temp_file, format='PNG', pnginfo=png_info)
|
2009 |
-
# temp_file_path = temp_file.name
|
2010 |
-
#
|
2011 |
-
# # Return the file path and clear validation output
|
2012 |
-
# return temp_file_path, ""
|
2013 |
-
#
|
2014 |
-
# def build_character_card(
|
2015 |
-
# name, description, personality, scenario, first_mes, mes_example,
|
2016 |
-
# creator_notes, system_prompt, post_history_instructions,
|
2017 |
-
# alternate_greetings_str, tags_str, creator, character_version,
|
2018 |
-
# extensions_str
|
2019 |
-
# ):
|
2020 |
-
# # Parse alternate_greetings from multiline string
|
2021 |
-
# alternate_greetings = [line.strip() for line in alternate_greetings_str.strip().split('\n') if line.strip()]
|
2022 |
-
#
|
2023 |
-
# # Parse tags from comma-separated string
|
2024 |
-
# tags = [tag.strip() for tag in tags_str.strip().split(',') if tag.strip()]
|
2025 |
-
#
|
2026 |
-
# # Parse extensions from JSON string
|
2027 |
-
# try:
|
2028 |
-
# extensions = json.loads(extensions_str) if extensions_str.strip() else {}
|
2029 |
-
# except json.JSONDecodeError as e:
|
2030 |
-
# extensions = {}
|
2031 |
-
# logging.error(f"Error parsing extensions JSON: {e}")
|
2032 |
-
#
|
2033 |
-
# # Build the character card dictionary according to V2 spec
|
2034 |
-
# character_card = {
|
2035 |
-
# 'spec': 'chara_card_v2',
|
2036 |
-
# 'spec_version': '2.0',
|
2037 |
-
# 'data': {
|
2038 |
-
# 'name': name,
|
2039 |
-
# 'description': description,
|
2040 |
-
# 'personality': personality,
|
2041 |
-
# 'scenario': scenario,
|
2042 |
-
# 'first_mes': first_mes,
|
2043 |
-
# 'mes_example': mes_example,
|
2044 |
-
# 'creator_notes': creator_notes,
|
2045 |
-
# 'system_prompt': system_prompt,
|
2046 |
-
# 'post_history_instructions': post_history_instructions,
|
2047 |
-
# 'alternate_greetings': alternate_greetings,
|
2048 |
-
# 'tags': tags,
|
2049 |
-
# 'creator': creator,
|
2050 |
-
# 'character_version': character_version,
|
2051 |
-
# 'extensions': extensions,
|
2052 |
-
# }
|
2053 |
-
# }
|
2054 |
-
# return character_card
|
2055 |
-
#
|
2056 |
-
# def validate_and_load_character_card(file):
|
2057 |
-
# if file is None:
|
2058 |
-
# return ["No file provided for validation."] + [gr.update() for _ in range(15)]
|
2059 |
-
#
|
2060 |
-
# try:
|
2061 |
-
# if file.name.lower().endswith(('.png', '.webp')):
|
2062 |
-
# json_data = extract_json_from_image(file)
|
2063 |
-
# if not json_data:
|
2064 |
-
# return ["Failed to extract JSON data from the image."] + [gr.update() for _ in range(15)]
|
2065 |
-
# elif file.name.lower().endswith('.json'):
|
2066 |
-
# with open(file.name, 'r', encoding='utf-8') as f:
|
2067 |
-
# json_data = f.read()
|
2068 |
-
# else:
|
2069 |
-
# return ["Unsupported file type."] + [gr.update() for _ in range(15)]
|
2070 |
-
#
|
2071 |
-
# # Parse the JSON content
|
2072 |
-
# try:
|
2073 |
-
# card_data = json.loads(json_data)
|
2074 |
-
# except json.JSONDecodeError as e:
|
2075 |
-
# return [f"JSON decoding error: {e}"] + [gr.update() for _ in range(15)]
|
2076 |
-
#
|
2077 |
-
# # Validate the character card
|
2078 |
-
# is_valid, validation_messages = validate_v2_card(card_data)
|
2079 |
-
#
|
2080 |
-
# # Prepare the validation output
|
2081 |
-
# if is_valid:
|
2082 |
-
# validation_output_msg = "Character card is valid according to the V2 specification."
|
2083 |
-
# else:
|
2084 |
-
# validation_output_msg = "Character card validation failed:\n" + "\n".join(validation_messages)
|
2085 |
-
#
|
2086 |
-
# # Extract data to populate input fields
|
2087 |
-
# data = card_data.get('data', {})
|
2088 |
-
#
|
2089 |
-
# # Handle image data
|
2090 |
-
# if 'image' in data:
|
2091 |
-
# # Decode base64 image
|
2092 |
-
# image_data = base64.b64decode(data['image'])
|
2093 |
-
# image = Image.open(io.BytesIO(image_data))
|
2094 |
-
# else:
|
2095 |
-
# image = None
|
2096 |
-
#
|
2097 |
-
# # Prepare values for input fields
|
2098 |
-
# alternate_greetings_str = "\n".join(data.get('alternate_greetings', []))
|
2099 |
-
# tags_str = ", ".join(data.get('tags', []))
|
2100 |
-
# extensions_str = json.dumps(data.get('extensions', {}), indent=2) if data.get('extensions', {}) else ""
|
2101 |
-
#
|
2102 |
-
# outputs = [
|
2103 |
-
# validation_output_msg,
|
2104 |
-
# data.get('name', ''),
|
2105 |
-
# data.get('description', ''),
|
2106 |
-
# data.get('personality', ''),
|
2107 |
-
# data.get('scenario', ''),
|
2108 |
-
# data.get('first_mes', ''),
|
2109 |
-
# data.get('mes_example', ''),
|
2110 |
-
# data.get('creator_notes', ''),
|
2111 |
-
# data.get('system_prompt', ''),
|
2112 |
-
# data.get('post_history_instructions', ''),
|
2113 |
-
# alternate_greetings_str,
|
2114 |
-
# tags_str,
|
2115 |
-
# data.get('creator', ''),
|
2116 |
-
# data.get('character_version', ''),
|
2117 |
-
# extensions_str,
|
2118 |
-
# image
|
2119 |
-
# ]
|
2120 |
-
#
|
2121 |
-
# return outputs
|
2122 |
-
#
|
2123 |
-
# except Exception as e:
|
2124 |
-
# logging.error(f"Error validating character card: {e}")
|
2125 |
-
# return [f"An unexpected error occurred: {e}"] + [gr.update() for _ in range(15)]
|
2126 |
-
#
|
2127 |
-
# # Button Callback for validation
|
2128 |
-
# validate_button.click(
|
2129 |
-
# fn=validate_and_load_character_card,
|
2130 |
-
# inputs=[file_upload],
|
2131 |
-
# outputs=[
|
2132 |
-
# validation_output,
|
2133 |
-
# name_input, description_input, personality_input, scenario_input,
|
2134 |
-
# first_mes_input, mes_example_input, creator_notes_input, system_prompt_input,
|
2135 |
-
# post_history_instructions_input, alternate_greetings_input, tags_input,
|
2136 |
-
# creator_input, character_version_input, extensions_input, image_input
|
2137 |
-
# ]
|
2138 |
-
# )
|
2139 |
-
#
|
2140 |
-
# # Button Callbacks for save, download, etc.
|
2141 |
-
# save_button.click(
|
2142 |
-
# fn=save_character_card,
|
2143 |
-
# inputs=[
|
2144 |
-
# name_input, description_input, personality_input, scenario_input,
|
2145 |
-
# first_mes_input, mes_example_input, creator_notes_input, system_prompt_input,
|
2146 |
-
# post_history_instructions_input, alternate_greetings_input, tags_input,
|
2147 |
-
# creator_input, character_version_input, extensions_input, image_input
|
2148 |
-
# ],
|
2149 |
-
# outputs=[save_status]
|
2150 |
-
# )
|
2151 |
-
#
|
2152 |
-
# download_button.click(
|
2153 |
-
# fn=download_character_card,
|
2154 |
-
# inputs=[
|
2155 |
-
# name_input, description_input, personality_input, scenario_input,
|
2156 |
-
# first_mes_input, mes_example_input, creator_notes_input, system_prompt_input,
|
2157 |
-
# post_history_instructions_input, alternate_greetings_input, tags_input,
|
2158 |
-
# creator_input, character_version_input, extensions_input, image_input
|
2159 |
-
# ],
|
2160 |
-
# outputs=[download_output, save_status]
|
2161 |
-
# )
|
2162 |
-
#
|
2163 |
-
# download_image_button.click(
|
2164 |
-
# fn=download_character_card_as_image,
|
2165 |
-
# inputs=[
|
2166 |
-
# name_input, description_input, personality_input, scenario_input,
|
2167 |
-
# first_mes_input, mes_example_input, creator_notes_input, system_prompt_input,
|
2168 |
-
# post_history_instructions_input, alternate_greetings_input, tags_input,
|
2169 |
-
# creator_input, character_version_input, extensions_input, image_input
|
2170 |
-
# ],
|
2171 |
-
# outputs=[download_image_output, save_status]
|
2172 |
-
# )
|
2173 |
|
2174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2175 |
#
|
2176 |
# End of Character_Chat_tab.py
|
2177 |
#######################################################################################################################
|
|
|
12 |
import io
|
13 |
import base64
|
14 |
from typing import Dict, Any, Optional, List, Tuple, Union, cast
|
15 |
+
import zipfile
|
16 |
#
|
17 |
# External Imports
|
18 |
from PIL import Image
|
|
|
187 |
# End of Character card import functions
|
188 |
####################################################
|
189 |
|
190 |
+
####################################################
|
191 |
+
#
|
192 |
+
# Character card export functions
|
193 |
+
|
194 |
+
def export_character_as_json(character_id):
|
195 |
+
character = get_character_card_by_id(character_id)
|
196 |
+
if character:
|
197 |
+
# Remove the 'id' field from the character data
|
198 |
+
character_data = {k: v for k, v in character.items() if k != 'id'}
|
199 |
+
|
200 |
+
# Convert image to base64 if it exists
|
201 |
+
if 'image' in character_data and character_data['image']:
|
202 |
+
image_data = base64.b64decode(character_data['image'])
|
203 |
+
img = Image.open(io.BytesIO(image_data))
|
204 |
+
buffered = io.BytesIO()
|
205 |
+
img.save(buffered, format="PNG")
|
206 |
+
character_data['image'] = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
207 |
+
|
208 |
+
json_data = json.dumps(character_data, indent=2)
|
209 |
+
return json_data
|
210 |
+
return None
|
211 |
+
|
212 |
+
def export_all_characters_as_zip():
|
213 |
+
characters = get_character_cards()
|
214 |
+
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.zip') as temp_zip:
|
215 |
+
with zipfile.ZipFile(temp_zip, 'w') as zf:
|
216 |
+
for character in characters:
|
217 |
+
character_data = {k: v for k, v in character.items() if k != 'id'}
|
218 |
+
|
219 |
+
# Convert image to base64 if it exists
|
220 |
+
if 'image' in character_data and character_data['image']:
|
221 |
+
image_data = base64.b64decode(character_data['image'])
|
222 |
+
img = Image.open(io.BytesIO(image_data))
|
223 |
+
buffered = io.BytesIO()
|
224 |
+
img.save(buffered, format="PNG")
|
225 |
+
character_data['image'] = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
226 |
+
json_data = json.dumps(character_data, indent=2)
|
227 |
+
zf.writestr(f"{character['name']}.json", json_data)
|
228 |
+
return temp_zip.name
|
229 |
+
|
230 |
+
def export_single_character(character_selection):
|
231 |
+
if not character_selection:
|
232 |
+
return None, "No character selected."
|
233 |
+
|
234 |
+
character_id = int(character_selection.split('(ID: ')[1].rstrip(')'))
|
235 |
+
json_data = export_character_as_json(character_id)
|
236 |
+
|
237 |
+
if json_data:
|
238 |
+
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', encoding='utf-8') as temp_file:
|
239 |
+
temp_file.write(json_data)
|
240 |
+
return temp_file.name, f"Character '{character_selection.split(' (ID:')[0]}' exported successfully."
|
241 |
+
else:
|
242 |
+
return None, f"Failed to export character '{character_selection.split(' (ID:')[0]}'."
|
243 |
+
|
244 |
+
def export_all_characters():
|
245 |
+
zip_path = export_all_characters_as_zip()
|
246 |
+
return zip_path, "All characters exported successfully."
|
247 |
+
|
248 |
+
#
|
249 |
+
# End of Character card export functions
|
250 |
+
####################################################
|
251 |
|
252 |
####################################################
|
253 |
#
|
|
|
1783 |
inputs=[file_upload],
|
1784 |
outputs=[validation_output]
|
1785 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|

+def create_export_characters_tab():
+    with gr.TabItem("Export Characters"):
+        gr.Markdown("# Export Characters")
+        gr.Markdown("Export character cards individually as JSON files or all together as a ZIP file.")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                # Dropdown to select a character for individual export
+                characters = get_character_cards()
+                character_choices = [f"{char['name']} (ID: {char['id']})" for char in characters]
+                export_character_dropdown = gr.Dropdown(
+                    label="Select Character to Export",
+                    choices=character_choices
+                )
+                load_characters_button = gr.Button("Load Existing Characters")
+                export_single_button = gr.Button("Export Selected Character")
+                export_all_button = gr.Button("Export All Characters")
+
+            with gr.Column(scale=1):
+                # Output components
+                export_output = gr.File(label="Exported Character(s)", interactive=False)
+                export_status = gr.Markdown("")
+
+        def export_single_character_wrapper(character_selection):
+            file_path, status_message = export_single_character(character_selection)
+            if file_path:
+                return gr.File.update(value=file_path), status_message
+            else:
+                return gr.File.update(value=None), status_message
+
+        def export_all_characters_wrapper():
+            zip_path = export_all_characters_as_zip()
+            characters = get_character_cards()
+            exported_characters = [char['name'] for char in characters]
+            status_message = f"Exported {len(exported_characters)} characters successfully:\n" + "\n".join(exported_characters)
+            return gr.File.update(value=zip_path), status_message
+
+        # Event listeners
+        load_characters_button.click(
+            fn=lambda: gr.update(choices=[f"{char['name']} (ID: {char['id']})" for char in get_character_cards()]),
+            outputs=export_character_dropdown
+        )
+
+        export_single_button.click(
+            fn=export_single_character_wrapper,
+            inputs=[export_character_dropdown],
+            outputs=[export_output, export_status]
+        )
+
+        export_all_button.click(
+            fn=export_all_characters_wrapper,
+            inputs=[],
+            outputs=[export_output, export_status]
+        )
+
+        return export_character_dropdown, load_characters_button, export_single_button, export_all_button, export_output, export_status
+
 #
 # End of Character_Chat_tab.py
 #######################################################################################################################
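A note on the dropdown-label convention used above: export_single_character recovers the character ID by splitting the label "Name (ID: 123)" on the literal substring '(ID: '. A minimal sketch of the same parse done with a regex, assuming only the label format shown above (parse_character_id is an illustrative helper, not part of the commit):

import re

def parse_character_id(selection: str) -> int:
    # Labels produced above look like "Alice (ID: 42)"; capture the trailing integer.
    match = re.search(r'\(ID:\s*(\d+)\)\s*$', selection)
    if match is None:
        raise ValueError(f"Unrecognized selection format: {selection!r}")
    return int(match.group(1))

assert parse_character_id("Alice (ID: 42)") == 42

A regex fails loudly on a malformed label, where the split-based version raises a bare IndexError.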
App_Function_Libraries/Gradio_UI/Live_Recording.py
ADDED
@@ -0,0 +1,142 @@
+# Live_Recording.py
+# Description: Gradio UI for live audio recording and transcription.
+#
+# Import necessary modules and functions
+import logging
+import os
+import time
+
+# External Imports
+import gradio as gr
+# Local Imports
+from App_Function_Libraries.Audio.Audio_Transcription_Lib import (record_audio, speech_to_text, save_audio_temp,
+                                                                  stop_recording)
+from App_Function_Libraries.DB.DB_Manager import add_media_to_database
+from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
+#
+#######################################################################################################################
+#
+# Functions:
+
+whisper_models = ["small", "medium", "small.en", "medium.en", "medium", "large", "large-v1", "large-v2", "large-v3",
+                  "distil-large-v2", "distil-medium.en", "distil-small.en"]
+
+def create_live_recording_tab():
+    with gr.Tab("Live Recording and Transcription"):
+        gr.Markdown("# Live Audio Recording and Transcription")
+        with gr.Row():
+            with gr.Column():
+                duration = gr.Slider(minimum=1, maximum=8000, value=15, label="Recording Duration (seconds)")
+                whisper_models_input = gr.Dropdown(choices=whisper_models, value="medium", label="Whisper Model")
+                vad_filter = gr.Checkbox(label="Use VAD Filter")
+                save_recording = gr.Checkbox(label="Save Recording")
+                save_to_db = gr.Checkbox(label="Save Transcription to Database (must be checked to save - can be checked after transcription)", value=False)
+                custom_title = gr.Textbox(label="Custom Title (for database)", visible=False)
+                record_button = gr.Button("Start Recording")
+                stop_button = gr.Button("Stop Recording")
+            with gr.Column():
+                output = gr.Textbox(label="Transcription", lines=10)
+                audio_output = gr.Audio(label="Recorded Audio", visible=False)
+
+        recording_state = gr.State(value=None)
+
+        def start_recording(duration):
+            log_counter("live_recording_start_attempt", labels={"duration": duration})
+            p, stream, audio_queue, stop_event, audio_thread = record_audio(duration)
+            log_counter("live_recording_start_success", labels={"duration": duration})
+            return (p, stream, audio_queue, stop_event, audio_thread)
+
+        def end_recording_and_transcribe(recording_state, whisper_model, vad_filter, save_recording, save_to_db, custom_title):
+            log_counter("live_recording_end_attempt", labels={"model": whisper_model})
+            start_time = time.time()
+
+            if recording_state is None:
+                log_counter("live_recording_end_error", labels={"error": "Recording hasn't started yet"})
+                return "Recording hasn't started yet.", None
+
+            p, stream, audio_queue, stop_event, audio_thread = recording_state
+            audio_data = stop_recording(p, stream, audio_queue, stop_event, audio_thread)
+
+            temp_file = save_audio_temp(audio_data)
+            segments = speech_to_text(temp_file, whisper_model=whisper_model, vad_filter=vad_filter)
+            transcription = "\n".join([segment["Text"] for segment in segments])
+
+            if save_recording:
+                log_counter("live_recording_saved", labels={"model": whisper_model})
+            else:
+                os.remove(temp_file)
+
+            end_time = time.time() - start_time
+            log_histogram("live_recording_end_duration", end_time, labels={"model": whisper_model})
+            log_counter("live_recording_end_success", labels={"model": whisper_model})
+            return transcription, temp_file if save_recording else None
+
+        def save_transcription_to_db(transcription, custom_title):
+            log_counter("save_transcription_to_db_attempt")
+            start_time = time.time()
+            if custom_title.strip() == "":
+                custom_title = "Self-recorded Audio"
+
+            try:
+                url = "self_recorded"
+                info_dict = {
+                    "title": custom_title,
+                    "uploader": "self-recorded",
+                    "webpage_url": url
+                }
+                segments = [{"Text": transcription}]
+                summary = ""
+                keywords = ["self-recorded", "audio"]
+                custom_prompt_input = ""
+                whisper_model = "self-recorded"
+                media_type = "audio"
+
+                result = add_media_to_database(
+                    url=url,
+                    info_dict=info_dict,
+                    segments=segments,
+                    summary=summary,
+                    keywords=keywords,
+                    custom_prompt_input=custom_prompt_input,
+                    whisper_model=whisper_model,
+                    media_type=media_type
+                )
+                end_time = time.time() - start_time
+                log_histogram("save_transcription_to_db_duration", end_time)
+                log_counter("save_transcription_to_db_success")
+                return f"Transcription saved to database successfully. {result}"
+            except Exception as e:
+                logging.error(f"Error saving transcription to database: {str(e)}")
+                log_counter("save_transcription_to_db_error", labels={"error": str(e)})
+                return f"Error saving transcription to database: {str(e)}"
+
+        def update_custom_title_visibility(save_to_db):
+            return gr.update(visible=save_to_db)
+
+        record_button.click(
+            fn=start_recording,
+            inputs=[duration],
+            outputs=[recording_state]
+        )
+
+        stop_button.click(
+            fn=end_recording_and_transcribe,
+            inputs=[recording_state, whisper_models_input, vad_filter, save_recording, save_to_db, custom_title],
+            outputs=[output, audio_output]
+        )
+
+        save_to_db.change(
+            fn=update_custom_title_visibility,
+            inputs=[save_to_db],
+            outputs=[custom_title]
+        )
+
+        gr.Button("Save to Database").click(
+            fn=save_transcription_to_db,
+            inputs=[output, custom_title],
+            outputs=gr.Textbox(label="Database Save Status")
+        )
+
+#
+# End of Functions
+########################################################################################################################
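A minimal sketch of mounting the new tab in a Gradio app, assuming the usual gr.Blocks entry point (the demo app below is illustrative, not part of the commit):

import gradio as gr
from App_Function_Libraries.Gradio_UI.Live_Recording import create_live_recording_tab

with gr.Blocks() as demo:
    with gr.Tabs():
        create_live_recording_tab()  # builds the tab and wires its event handlers

demo.launch()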
App_Function_Libraries/Gradio_UI/Llamafile_tab.py
CHANGED
@@ -1,122 +1,312 @@
 # Llamafile_tab.py
-# Description:
-
 # Imports
 import os
-import
-
-# External Imports
 import gradio as gr
-
-
-from App_Function_Libraries.Local_LLM.
 #
 #######################################################################################################################
 #
 # Functions:

-
 def create_chat_with_llamafile_tab():
-
-
-
-
-
-
-
-
-
-    )

     with gr.TabItem("Local LLM with Llamafile"):
         gr.Markdown("# Settings for Llamafile")
         with gr.Row():
             with gr.Column():
-                am_noob = gr.Checkbox(label="
-
-
-

             with gr.Column():
-                #
-
-
-
-
-
-
-
-
-
         )
-
-
         with gr.Row():
-            with gr.Column():
-                ngl_checked = gr.Checkbox(label="Enable Setting GPU Layers", value=False, visible=True)
-                ngl_value = gr.Number(label="Number of GPU Layers", value=None, precision=0, visible=True)
-                advanced_inputs = create_llamafile_advanced_inputs()
             with gr.Column():
                 start_button = gr.Button("Start Llamafile")
                 stop_button = gr.Button("Stop Llamafile (doesn't work)")
                 output_display = gr.Markdown()


-
-
-
-
-
-
-

-
-
-
-
-
-
-
-
-
-

-
-            fn=
-            inputs=[],
-            outputs=[
         )

         start_button.click(
-            fn=
-            inputs=[
             outputs=output_display
         )

-
-
-
-
-
-
-        hf_repo_checked = gr.Checkbox(label="Use Huggingface Repo Model", value=False, visible=False)
-        hf_repo_value = gr.Textbox(label="Huggingface Repo Name", value="", visible=False)
-        hf_file_checked = gr.Checkbox(label="Set Huggingface Model File", value=False, visible=False)
-        hf_file_value = gr.Textbox(label="Huggingface Model File", value="", visible=False)
-        ctx_size_checked = gr.Checkbox(label="Set Prompt Context Size", value=False, visible=False)
-        ctx_size_value = gr.Number(label="Prompt Context Size", value=8124, precision=0, visible=False)
-        host_checked = gr.Checkbox(label="Set IP to Listen On", value=False, visible=False)
-        host_value = gr.Textbox(label="Host IP Address", value="", visible=False)
-        port_checked = gr.Checkbox(label="Set Server Port", value=False, visible=False)
-        port_value = gr.Number(label="Port Number", value=None, precision=0, visible=False)
-
-        return [verbose_checked, threads_checked, threads_value, http_threads_checked, http_threads_value,
-                hf_repo_checked, hf_repo_value, hf_file_checked, hf_file_value, ctx_size_checked, ctx_size_value,
-                host_checked, host_value, port_checked, port_value]

 #
-#
-
 # Llamafile_tab.py
+# Description: Gradio interface for configuring and launching Llamafile with Local LLMs
+
 # Imports
 import os
+import logging
+from typing import Tuple, Optional
 import gradio as gr
+
+
+from App_Function_Libraries.Local_LLM.Local_LLM_Inference_Engine_Lib import (
+    download_llm_model,
+    llm_models,
+    start_llamafile,
+    get_gguf_llamafile_files
+)
 #
 #######################################################################################################################
 #
 # Functions:

 def create_chat_with_llamafile_tab():
+    # Function to update model path based on selection
+    def on_local_model_change(selected_model: str, search_directory: str) -> str:
+        if selected_model and isinstance(search_directory, str):
+            model_path = os.path.abspath(os.path.join(search_directory, selected_model))
+            logging.debug(f"Selected model path: {model_path}")  # Debug print for selected model path
+            return model_path
+        return "Invalid selection or directory."
+
+    # Function to update the dropdown with available models
+    def update_dropdowns(search_directory: str) -> Tuple[dict, str]:
+        logging.debug(f"User-entered directory: {search_directory}")  # Debug print for directory
+        if not os.path.isdir(search_directory):
+            logging.debug(f"Directory does not exist: {search_directory}")  # Debug print for non-existing directory
+            return gr.update(choices=[], value=None), "Directory does not exist."
+
+        logging.debug(f"Directory exists: {search_directory}, scanning for files...")  # Confirm directory exists
+        model_files = get_gguf_llamafile_files(search_directory)
+
+        if not model_files:
+            logging.debug(f"No model files found in {search_directory}")  # Debug print for no files found
+            return gr.update(choices=[], value=None), "No model files found in the specified directory."
+
+        # Update the dropdown choices with the model files found
+        logging.debug(f"Models loaded from {search_directory}: {model_files}")  # Debug: Print model files loaded
+        return gr.update(choices=model_files, value=None), f"Models loaded from {search_directory}."
+
+
+    def download_preset_model(selected_model: str) -> Tuple[str, str]:
+        """
+        Downloads the selected preset model.
+
+        Args:
+            selected_model (str): The key of the selected preset model.
+
+        Returns:
+            Tuple[str, str]: Status message and the path to the downloaded model.
+        """
+        model_info = llm_models.get(selected_model)
+        if not model_info:
+            return "Invalid model selection.", ""
+
+        try:
+            model_path = download_llm_model(
+                model_name=model_info["name"],
+                model_url=model_info["url"],
+                model_filename=model_info["filename"],
+                model_hash=model_info["hash"]
+            )
+            return f"Model '{model_info['name']}' downloaded successfully.", model_path
+        except Exception as e:
+            logging.error(f"Error downloading model: {e}")
+            return f"Failed to download model: {e}", ""

     with gr.TabItem("Local LLM with Llamafile"):
         gr.Markdown("# Settings for Llamafile")
+
         with gr.Row():
             with gr.Column():
+                am_noob = gr.Checkbox(label="Enable Sane Defaults", value=False, visible=True)
+                advanced_mode_toggle = gr.Checkbox(label="Advanced Mode - Show All Settings", value=False)
+                # Advanced Inputs
+                verbose_checked = gr.Checkbox(label="Enable Verbose Output", value=False, visible=False)
+                threads_checked = gr.Checkbox(label="Set CPU Threads", value=False, visible=False)
+                threads_value = gr.Number(label="Number of CPU Threads", value=None, precision=0, visible=False)
+                threads_batched_checked = gr.Checkbox(label="Enable Batched Inference", value=False, visible=False)
+                threads_batched_value = gr.Number(label="Batch Size for Inference", value=None, precision=0, visible=False)
+                model_alias_checked = gr.Checkbox(label="Set Model Alias", value=False, visible=False)
+                model_alias_value = gr.Textbox(label="Model Alias", value="", visible=False)
+                ctx_size_checked = gr.Checkbox(label="Set Prompt Context Size", value=False, visible=False)
+                ctx_size_value = gr.Number(label="Prompt Context Size", value=8124, precision=0, visible=False)
+                ngl_checked = gr.Checkbox(label="Enable GPU Layers", value=False, visible=True)
+                ngl_value = gr.Number(label="Number of GPU Layers", value=None, precision=0, visible=True)
+                batch_size_checked = gr.Checkbox(label="Set Batch Size", value=False, visible=False)
+                batch_size_value = gr.Number(label="Batch Size", value=512, visible=False)
+                memory_f32_checked = gr.Checkbox(label="Use 32-bit Floating Point", value=False, visible=False)
+                numa_checked = gr.Checkbox(label="Enable NUMA", value=False, visible=False)
+                server_timeout_value = gr.Number(label="Server Timeout", value=600, precision=0, visible=False)
+                host_checked = gr.Checkbox(label="Set IP to Listen On", value=False, visible=False)
+                host_value = gr.Textbox(label="Host IP Address", value="", visible=False)
+                port_checked = gr.Checkbox(label="Set Server Port", value=False, visible=False)
+                port_value = gr.Number(label="Port Number", value=8080, precision=0, visible=False)
+                api_key_checked = gr.Checkbox(label="Set API Key", value=False, visible=False)
+                api_key_value = gr.Textbox(label="API Key", value="", visible=False)
+                http_threads_checked = gr.Checkbox(label="Set HTTP Server Threads", value=False, visible=False)
+                http_threads_value = gr.Number(label="Number of HTTP Server Threads", value=None, precision=0, visible=False)
+                hf_repo_checked = gr.Checkbox(label="Use Huggingface Repo Model", value=False, visible=False)
+                hf_repo_value = gr.Textbox(label="Huggingface Repo Name", value="", visible=False)
+                hf_file_checked = gr.Checkbox(label="Set Huggingface Model File", value=False, visible=False)
+                hf_file_value = gr.Textbox(label="Huggingface Model File", value="", visible=False)

             with gr.Column():
+                # Model Selection Section
+                gr.Markdown("## Model Selection")
+
+                # Option 1: Select from Local Filesystem
+                with gr.Row():
+                    search_directory = gr.Textbox(label="Model Directory",
+                                                  placeholder="Enter directory path (currently '.\Models')",
+                                                  value=".\Models",
+                                                  interactive=True)
+
+                # Initial population of local models
+                initial_dropdown_update, _ = update_dropdowns(".\Models")
+                refresh_button = gr.Button("Refresh Models")
+                local_model_dropdown = gr.Dropdown(label="Select Model from Directory", choices=[])
+                # Display selected model path
+                model_value = gr.Textbox(label="Selected Model File Path", value="", interactive=False)
+
+                # Option 2: Download Preset Models
+                gr.Markdown("## Download Preset Models")
+
+                preset_model_dropdown = gr.Dropdown(
+                    label="Select a Preset Model",
+                    choices=list(llm_models.keys()),
+                    value=None,
+                    interactive=True,
+                    info="Choose a preset model to download."
                 )
+                download_preset_button = gr.Button("Download Selected Preset")
+
         with gr.Row():
             with gr.Column():
                 start_button = gr.Button("Start Llamafile")
                 stop_button = gr.Button("Stop Llamafile (doesn't work)")
                 output_display = gr.Markdown()


+        # Show/hide advanced inputs based on toggle
+        def update_visibility(show_advanced: bool):
+            components = [
+                verbose_checked, threads_checked, threads_value,
+                http_threads_checked, http_threads_value,
+                hf_repo_checked, hf_repo_value,
+                hf_file_checked, hf_file_value,
+                ctx_size_checked, ctx_size_value,
+                ngl_checked, ngl_value,
+                host_checked, host_value,
+                port_checked, port_value
+            ]
+            return [gr.update(visible=show_advanced) for _ in components]

+        def on_start_button_click(
+            am_noob: bool,
+            verbose_checked: bool,
+            threads_checked: bool,
+            threads_value: Optional[int],
+            threads_batched_checked: bool,
+            threads_batched_value: Optional[int],
+            model_alias_checked: bool,
+            model_alias_value: str,
+            http_threads_checked: bool,
+            http_threads_value: Optional[int],
+            model_value: str,
+            hf_repo_checked: bool,
+            hf_repo_value: str,
+            hf_file_checked: bool,
+            hf_file_value: str,
+            ctx_size_checked: bool,
+            ctx_size_value: Optional[int],
+            ngl_checked: bool,
+            ngl_value: Optional[int],
+            batch_size_checked: bool,
+            batch_size_value: Optional[int],
+            memory_f32_checked: bool,
+            numa_checked: bool,
+            server_timeout_value: Optional[int],
+            host_checked: bool,
+            host_value: str,
+            port_checked: bool,
+            port_value: Optional[int],
+            api_key_checked: bool,
+            api_key_value: str
+        ) -> str:
+            """
+            Event handler for the Start Llamafile button.
+            """
+            try:
+                result = start_llamafile(
+                    am_noob,
+                    verbose_checked,
+                    threads_checked,
+                    threads_value,
+                    threads_batched_checked,
+                    threads_batched_value,
+                    model_alias_checked,
+                    model_alias_value,
+                    http_threads_checked,
+                    http_threads_value,
+                    model_value,
+                    hf_repo_checked,
+                    hf_repo_value,
+                    hf_file_checked,
+                    hf_file_value,
+                    ctx_size_checked,
+                    ctx_size_value,
+                    ngl_checked,
+                    ngl_value,
+                    batch_size_checked,
+                    batch_size_value,
+                    memory_f32_checked,
+                    numa_checked,
+                    server_timeout_value,
+                    host_checked,
+                    host_value,
+                    port_checked,
+                    port_value,
+                    api_key_checked,
+                    api_key_value
+                )
+                return result
+            except Exception as e:
+                logging.error(f"Error starting Llamafile: {e}")
+                return f"Failed to start Llamafile: {e}"

+        advanced_mode_toggle.change(
+            fn=update_visibility,
+            inputs=[advanced_mode_toggle],
+            outputs=[
+                verbose_checked, threads_checked, threads_value,
+                http_threads_checked, http_threads_value,
+                hf_repo_checked, hf_repo_value,
+                hf_file_checked, hf_file_value,
+                ctx_size_checked, ctx_size_value,
+                ngl_checked, ngl_value,
+                host_checked, host_value,
+                port_checked, port_value
+            ]
         )

         start_button.click(
+            fn=on_start_button_click,
+            inputs=[
+                am_noob,
+                verbose_checked,
+                threads_checked,
+                threads_value,
+                threads_batched_checked,
+                threads_batched_value,
+                model_alias_checked,
+                model_alias_value,
+                http_threads_checked,
+                http_threads_value,
+                model_value,
+                hf_repo_checked,
+                hf_repo_value,
+                hf_file_checked,
+                hf_file_value,
+                ctx_size_checked,
+                ctx_size_value,
+                ngl_checked,
+                ngl_value,
+                batch_size_checked,
+                batch_size_value,
+                memory_f32_checked,
+                numa_checked,
+                server_timeout_value,
+                host_checked,
+                host_value,
+                port_checked,
+                port_value,
+                api_key_checked,
+                api_key_value
+            ],
             outputs=output_display
         )

+        download_preset_button.click(
+            fn=download_preset_model,
+            inputs=[preset_model_dropdown],
+            outputs=[output_display, model_value]
+        )
+
+        # Click event for refreshing models
+        refresh_button.click(
+            fn=update_dropdowns,
+            inputs=[search_directory],  # Ensure that the directory path (string) is passed
+            outputs=[local_model_dropdown, output_display]  # Update dropdown and status
+        )

+        # Event to update model_value when a model is selected from the dropdown
+        local_model_dropdown.change(
+            fn=on_local_model_change,  # Function that calculates the model path
+            inputs=[local_model_dropdown, search_directory],  # Inputs: selected model and directory
+            outputs=[model_value]  # Output: Update the model_value textbox with the selected model path
+        )

 #
+#
+#######################################################################################################################
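The advanced-mode toggle above leans on a general Gradio pattern: an event handler may return one gr.update(...) per component named in outputs, matched positionally. A self-contained sketch of that pattern, reduced to two components (all names below are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    show = gr.Checkbox(label="Show advanced settings", value=False)
    threads = gr.Number(label="CPU Threads", visible=False)
    port = gr.Number(label="Port", visible=False)

    def toggle(show_advanced: bool):
        # One update per output component, in the same order as `outputs`.
        return [gr.update(visible=show_advanced) for _ in range(2)]

    show.change(fn=toggle, inputs=[show], outputs=[threads, port])

demo.launch()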
App_Function_Libraries/Gradio_UI/MMLU_Pro_tab.py
ADDED
@@ -0,0 +1,115 @@
+# MMLU_Pro_tab.py
+# Description: Gradio UI code for the MMLU-Pro benchmarking tool.
+#
+##############################################################################################################
+# Imports
+import os
+
+import gradio as gr
+import logging
+#
+# External Imports
+from tqdm import tqdm
+# Local Imports
+from App_Function_Libraries.Benchmarks_Evaluations.MMLU_Pro.MMLU_Pro_rewritten import (
+    load_mmlu_pro, run_mmlu_pro_benchmark, mmlu_pro_main, load_mmlu_pro_config
+)
+#
+##############################################################################################################
+#
+# Functions:
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+def get_categories():
+    """Fetch categories using the dataset loader from MMLU_Pro_rewritten.py"""
+    try:
+        test_data, _ = load_mmlu_pro()  # Use the function from MMLU_Pro_rewritten.py
+        return list(test_data.keys())  # Return the categories from the test dataset
+    except Exception as e:
+        logger.error(f"Failed to load categories: {e}")
+        return ["Error loading categories"]
+
+
+def load_categories():
+    """Helper function to return the categories for the Gradio dropdown."""
+    categories = get_categories()  # Fetch categories from the dataset
+    if categories:
+        return gr.update(choices=categories, value=categories[0])  # Update dropdown with categories
+    else:
+        return gr.update(choices=["Error loading categories"], value="Error loading categories")
+
+
+def run_benchmark_from_ui(url, api_key, model, timeout, category, parallel, verbosity, log_prompt):
+    """Function to run the benchmark with parameters from the UI."""
+
+    # Override config with UI parameters
+    config = load_mmlu_pro_config(
+        url=url,
+        api_key=api_key,
+        model=model,
+        timeout=timeout,
+        categories=[category] if category else None,
+        parallel=parallel,
+        verbosity=verbosity,
+        log_prompt=log_prompt
+    )
+
+    # Run the benchmarking process
+    try:
+        # Call the main benchmarking function
+        mmlu_pro_main()
+
+        # Assume the final report is generated in "eval_results" folder
+        report_path = os.path.join("eval_results", config["server"]["model"].replace("/", "-"), "final_report.txt")
+
+        # Read the final report
+        with open(report_path, "r") as f:
+            report = f.read()
+
+        return report
+    except Exception as e:
+        logger.error(f"An error occurred during benchmark execution: {e}")
+        return f"An error occurred during benchmark execution. Please check the logs for more information. Error: {str(e)}"
+
+
+def create_mmlu_pro_tab():
+    """Create the Gradio UI tab for MMLU-Pro Benchmark."""
+    with gr.Tab("MMLU-Pro Benchmark"):
+        gr.Markdown("## Run MMLU-Pro Benchmark")
+
+        with gr.Row():
+            with gr.Column():
+                # Inputs for the benchmark
+                url = gr.Textbox(label="Server URL")
+                api_key = gr.Textbox(label="API Key", type="password")
+                model = gr.Textbox(label="Model Name")
+                timeout = gr.Number(label="Timeout (seconds)", value=30)
+                category = gr.Dropdown(label="Category", choices=["Load categories..."])
+                load_categories_btn = gr.Button("Load Categories")
+                parallel = gr.Slider(label="Parallel Requests", minimum=1, maximum=10, step=1, value=1)
+                verbosity = gr.Slider(label="Verbosity Level", minimum=0, maximum=2, step=1, value=1)
+                log_prompt = gr.Checkbox(label="Log Prompt")
+
+            with gr.Column():
+                # Run button and output display
+                run_button = gr.Button("Run Benchmark")
+                output = gr.Textbox(label="Benchmark Results", lines=20)
+
+        # When "Load Categories" is clicked, load the categories into the dropdown
+        load_categories_btn.click(
+            load_categories,
+            outputs=category
+        )
+
+        # When "Run Benchmark" is clicked, trigger the run_benchmark_from_ui function
+        run_button.click(
+            run_benchmark_from_ui,  # Use the function defined to run the benchmark
+            inputs=[url, api_key, model, timeout, category, parallel, verbosity, log_prompt],
+            outputs=output
+        )
+
+        return [url, api_key, model, timeout, category, parallel, verbosity, log_prompt, run_button, output]
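run_benchmark_from_ui locates the final report by flattening the configured model name into a directory component. A short sketch of that path construction, assuming the eval_results layout used above (report_path_for is an illustrative helper):

import os

def report_path_for(model: str) -> str:
    # "mistralai/Mistral-7B" -> "eval_results/mistralai-Mistral-7B/final_report.txt"
    return os.path.join("eval_results", model.replace("/", "-"), "final_report.txt")

print(report_path_for("mistralai/Mistral-7B"))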
App_Function_Libraries/Gradio_UI/RAG_QA_Chat_Notes.py
ADDED
@@ -0,0 +1,243 @@
+# RAG_QA_Chat_Notes.py
+# Description: This file contains the code for the RAG QA Chat Notes tab in the RAG QA Chat application.
+#
+# Imports
+import logging
+# External Imports
+import gradio as gr
+#
+# Local Imports
+from App_Function_Libraries.DB.RAG_QA_Chat_DB import save_message, add_keywords_to_conversation, \
+    search_conversations_by_keywords, load_chat_history, save_notes, get_notes, clear_notes, \
+    add_keywords_to_note, execute_query, start_new_conversation
+from App_Function_Libraries.RAG.RAG_QA_Chat import rag_qa_chat
+#
+####################################################################################################
+#
+# Functions
+def create_rag_qa_chat_notes_tab():
+    with gr.TabItem("RAG QA Chat"):
+        gr.Markdown("# RAG QA Chat")
+
+        state = gr.State({
+            "conversation_id": None,
+            "page": 1,
+            "context_source": "Entire Media Database",
+        })
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                context_source = gr.Radio(
+                    ["Entire Media Database", "Search Database", "Upload File"],
+                    label="Context Source",
+                    value="Entire Media Database"
+                )
+                existing_file = gr.Dropdown(label="Select Existing File", choices=[], interactive=True)
+                file_page = gr.State(value=1)
+                with gr.Row():
+                    page_number = gr.Number(value=1, label="Page", precision=0)
+                    page_size = gr.Number(value=20, label="Items per page", precision=0)
+                    total_pages = gr.Number(label="Total Pages", interactive=False)
+                with gr.Row():
+                    prev_page_btn = gr.Button("Previous Page")
+                    next_page_btn = gr.Button("Next Page")
+                    page_info = gr.HTML("Page 1")
+
+                search_query = gr.Textbox(label="Search Query", visible=False)
+                search_button = gr.Button("Search", visible=False)
+                search_results = gr.Dropdown(label="Search Results", choices=[], visible=False)
+                file_upload = gr.File(
+                    label="Upload File",
+                    visible=False,
+                    file_types=["txt", "pdf", "epub", "md", "rtf", "json", "csv"]
+                )
+                convert_to_text = gr.Checkbox(label="Convert to plain text", visible=False)
+                keywords = gr.Textbox(label="Keywords (comma-separated)", visible=False)
+            with gr.Column(scale=1):
+                api_choice = gr.Dropdown(
+                    choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", "OpenRouter",
+                             "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace"],
+                    label="Select API for RAG",
+                    value="OpenAI"
+                )
+                use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True)
+
+                # FIXME - add load conversations button
+                load_conversation = gr.Dropdown(label="Load Conversation", choices=[])
+                new_conversation = gr.Button("New Conversation")
+                conversation_title = gr.Textbox(label="Conversation Title",
+                                                placeholder="Enter a title for the new conversation")
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                chatbot = gr.Chatbot(height=500)
+                msg = gr.Textbox(label="Enter your message")
+                submit = gr.Button("Submit")
+                clear_chat = gr.Button("Clear Chat History")
+
+            with gr.Column(scale=1):
+                notes = gr.TextArea(label="Notes", placeholder="Enter your notes here...", lines=20)
+                keywords_for_notes = gr.Textbox(label="Keywords for Notes (comma-separated)",
+                                                placeholder="Enter keywords for the note", visible=True)
+                save_notes_btn = gr.Button("Save Notes")  # Renamed to avoid conflict
+                clear_notes_btn = gr.Button("Clear Notes")  # Renamed to avoid conflict
+
+        loading_indicator = gr.HTML(visible=False)
+
+        def rag_qa_chat_wrapper(message, history, state, context_source, existing_file, search_results, file_upload,
+                                convert_to_text, keywords, api_choice, use_query_rewriting):
+            try:
+                conversation_id = state.value["conversation_id"]
+                if not conversation_id:
+                    conversation_id = start_new_conversation("Untitled Conversation")  # Provide a title or handle accordingly
+                    state = update_state(state, conversation_id=conversation_id)
+
+                save_message(conversation_id, 'human', message)
+
+                if keywords:
+                    add_keywords_to_conversation(conversation_id, [kw.strip() for kw in keywords.split(',')])
+
+                # Implement your actual RAG logic here
+                response = "response"  # rag_qa_chat(message, conversation_id, context_source, existing_file, search_results,
+                #                        file_upload, convert_to_text, api_choice, use_query_rewriting)
+
+                save_message(conversation_id, 'ai', response)
+
+                new_history = history + [(message, response)]
+
+                logging.info(f"Successfully processed message for conversation '{conversation_id}'")
+                return new_history, "", gr.update(visible=False), state
+
+            except Exception as e:
+                logging.error(f"Error in rag_qa_chat_wrapper: {e}")
+                gr.Error("An unexpected error occurred. Please try again later.")
+                return history, "", gr.update(visible=False), state
+
+        def load_conversation_history(selected_conversation_id, page, page_size, state):
+            if selected_conversation_id:
+                history, total_pages_val, _ = load_chat_history(selected_conversation_id, page, page_size)
+                notes_content = get_notes(selected_conversation_id)  # Retrieve notes here
+                updated_state = update_state(state, conversation_id=selected_conversation_id, page=page)
+                return history, total_pages_val, updated_state, "\n".join(notes_content)
+            return [], 1, state, ""
+
+        def start_new_conversation_wrapper(title, state):
+            new_conversation_id = start_new_conversation(title if title else "Untitled Conversation")
+            return [], update_state(state, conversation_id=new_conversation_id, page=1)
+
+        def update_state(state, **kwargs):
+            new_state = state.value.copy()
+            new_state.update(kwargs)
+            return new_state
+
+        def update_page(direction, current_page, total_pages_val):
+            new_page = max(1, min(current_page + direction, total_pages_val))
+            return new_page
+
+        def update_context_source(choice):
+            return {
+                existing_file: gr.update(visible=choice == "Select Existing File"),
+                prev_page_btn: gr.update(visible=choice == "Search Database"),
+                next_page_btn: gr.update(visible=choice == "Search Database"),
+                page_info: gr.update(visible=choice == "Search Database"),
+                search_query: gr.update(visible=choice == "Search Database"),
+                search_button: gr.update(visible=choice == "Search Database"),
+                search_results: gr.update(visible=choice == "Search Database"),
+                file_upload: gr.update(visible=choice == "Upload File"),
+                convert_to_text: gr.update(visible=choice == "Upload File"),
+                keywords: gr.update(visible=choice == "Upload File")
+            }
+
+        def perform_search(query):
+            try:
+                results = search_conversations_by_keywords([kw.strip() for kw in query.split()])
+                return gr.update(choices=[f"{title} (ID: {id})" for id, title in results[0]])
+            except Exception as e:
+                logging.error(f"Error performing search: {e}")
+                gr.Error(f"Error performing search: {str(e)}")
+                return gr.update(choices=[])
+
+        def clear_chat_history():
+            return [], ""
+
+        def save_notes_function(notes_content, keywords_content):
+            """Save the notes and associated keywords to the database."""
+            conversation_id = state.value["conversation_id"]
+            if conversation_id and notes_content:
+                # Save the note
+                save_notes(conversation_id, notes_content)
+
+                # Get the last inserted note ID
+                query = "SELECT id FROM rag_qa_notes WHERE conversation_id = ? ORDER BY timestamp DESC LIMIT 1"
+                note_id = execute_query(query, (conversation_id,))[0][0]
+
+                if keywords_content:
+                    add_keywords_to_note(note_id, [kw.strip() for kw in keywords_content.split(',')])
+
+                logging.info("Notes and keywords saved successfully!")
+                return notes_content
+            else:
+                logging.warning("No conversation ID or notes to save.")
+                return ""
+
+        def clear_notes_function():
+            """Clear notes for the current conversation."""
+            conversation_id = state.value["conversation_id"]
+            if conversation_id:
+                clear_notes(conversation_id)
+                logging.info("Notes cleared successfully!")
+            return ""
+
+        # Event handlers
+        submit.click(
+            rag_qa_chat_wrapper,
+            inputs=[msg, chatbot, state, context_source, existing_file, search_results, file_upload,
+                    convert_to_text, keywords, api_choice, use_query_rewriting],
+            outputs=[chatbot, msg, loading_indicator, state]
+        )
+
+        load_conversation.change(
+            load_conversation_history,
+            inputs=[load_conversation, page_number, page_size, state],
+            outputs=[chatbot, total_pages, state, notes]
+        )
+
+        new_conversation.click(
+            start_new_conversation_wrapper,
+            inputs=[conversation_title, state],
+            outputs=[chatbot, state]
+        )
+
+        # Pagination Event handlers
+        prev_page_btn.click(
+            lambda current_page, total_pages_val: update_page(-1, current_page, total_pages_val),
+            inputs=[page_number, total_pages],
+            outputs=[page_number]
+        )
+
+        next_page_btn.click(
+            lambda current_page, total_pages_val: update_page(1, current_page, total_pages_val),
+            inputs=[page_number, total_pages],
+            outputs=[page_number]
+        )
+
+        context_source.change(update_context_source, inputs=[context_source],
+                              outputs=[existing_file, prev_page_btn, next_page_btn, page_info,
+                                       search_query, search_button, search_results,
+                                       file_upload, convert_to_text, keywords])
+
+        search_button.click(perform_search, inputs=[search_query], outputs=[search_results])
+
+        clear_chat.click(clear_chat_history, outputs=[chatbot, msg])
+
+        save_notes_btn.click(save_notes_function, inputs=[notes, keywords_for_notes], outputs=[notes])
+        clear_notes_btn.click(clear_notes_function, outputs=[notes])
+
+        return (context_source, existing_file, search_query, search_button, search_results, file_upload,
+                convert_to_text, keywords, api_choice, use_query_rewriting, chatbot, msg, submit, clear_chat,
+                notes, save_notes_btn, clear_notes_btn, load_conversation, new_conversation, conversation_title,
+                prev_page_btn, next_page_btn, page_number, page_size, total_pages)
+
+#
+# End of RAG_QA_Chat_Notes.py
+####################################################################################################
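One caveat on the state handling above: the wrapper functions read state.value, which is an attribute of the gr.State component, while the same state is also passed through inputs, where Gradio hands the handler the raw value (here, a plain dict). A minimal sketch of the value-threading pattern this tab appears to be aiming for, assuming a dict-valued gr.State (the names and the "conv-123" ID are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    state = gr.State({"conversation_id": None, "page": 1})
    out = gr.Textbox(label="Conversation ID")
    btn = gr.Button("New Conversation")

    def new_conversation(current_state: dict):
        # `current_state` is the dict itself, not the gr.State component.
        new_state = dict(current_state, conversation_id="conv-123")
        return new_state, new_state["conversation_id"]

    btn.click(fn=new_conversation, inputs=[state], outputs=[state, out])

demo.launch()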
App_Function_Libraries/Gradio_UI/Utilities.py
CHANGED
@@ -10,7 +10,7 @@ from App_Function_Libraries.Utils.Utils import sanitize_filename, downloaded_fil


 def create_utilities_yt_video_tab():
-    with gr.Tab("YouTube Video Downloader"):
+    with gr.Tab("YouTube Video Downloader", id='youtube_dl'):
         with gr.Row():
             with gr.Column():
                 gr.Markdown(
@@ -28,7 +28,7 @@ def create_utilities_yt_video_tab():
     )

 def create_utilities_yt_audio_tab():
-    with gr.Tab("YouTube Audio Downloader"):
+    with gr.Tab("YouTube Audio Downloader", id="youtube audio downloader"):
         with gr.Row():
             with gr.Column():
                 gr.Markdown(
@@ -48,7 +48,7 @@ def create_utilities_yt_audio_tab():
     )

 def create_utilities_yt_timestamp_tab():
-    with gr.Tab("YouTube Timestamp URL Generator"):
+    with gr.Tab("YouTube Timestamp URL Generator", id="timestamp-gen"):
         gr.Markdown("## Generate YouTube URL with Timestamp")
         with gr.Row():
             with gr.Column():
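The id= arguments added to these tabs give each one a stable handle, which among other things lets a handler select a tab programmatically. A small sketch of that use, assuming the Gradio 3-era update API this codebase already uses elsewhere (gr.File.update); newer Gradio spells the same thing gr.Tabs(selected=...):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Tabs() as tabs:
        with gr.Tab("YouTube Video Downloader", id="youtube_dl"):
            gr.Markdown("video tools")
        with gr.Tab("YouTube Timestamp URL Generator", id="timestamp-gen"):
            gr.Markdown("timestamp tools")
    go = gr.Button("Jump to timestamp generator")
    # Select a tab by the id it was declared with.
    go.click(fn=lambda: gr.Tabs.update(selected="timestamp-gen"), outputs=tabs)

demo.launch()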
App_Function_Libraries/Local_LLM/Local_LLM_Inference_Engine_Lib.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
# Local_LLM_Inference_Engine_Lib.py
|
2 |
+
#########################################
|
3 |
+
# Local LLM Inference Engine Library
|
4 |
+
# This library is used to handle downloading, configuring, and launching the Local LLM Inference Engine
|
5 |
+
# via (llama.cpp via llamafile)
|
6 |
+
#
|
7 |
+
#
|
8 |
+
####
|
9 |
+
####################
|
10 |
+
# Function List
|
11 |
+
#
|
12 |
+
# 1.
|
13 |
+
#
|
14 |
+
####################
|
15 |
+
# Import necessary libraries
|
16 |
+
#import atexit
|
17 |
+
import glob
|
18 |
+
import logging
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
import signal
|
22 |
+
import subprocess
|
23 |
+
import sys
|
24 |
+
import time
|
25 |
+
from typing import List, Optional
|
26 |
+
#
|
27 |
+
# Import 3rd-pary Libraries
|
28 |
+
import requests
|
29 |
+
#
|
30 |
+
# Import Local
|
31 |
+
from App_Function_Libraries.Web_Scraping.Article_Summarization_Lib import *
|
32 |
+
from App_Function_Libraries.Utils.Utils import download_file
|
33 |
+
#
|
34 |
+
#######################################################################################################################
|
35 |
+
# Function Definitions:
|
36 |
+
|
37 |
+
|
38 |
+
###############################################################
|
39 |
+
# LLM models information
|
40 |
+
|
41 |
+
llm_models = {
|
42 |
+
"Mistral-7B-Instruct-v0.2-Q8.llamafile": {
|
43 |
+
"name": "Mistral-7B-Instruct-v0.2-Q8.llamafile",
|
44 |
+
"url": "https://huggingface.co/Mozilla/Mistral-7B-Instruct-v0.2-llamafile/resolve/main/mistral-7b-instruct-v0.2.Q8_0.llamafile?download=true",
|
45 |
+
"filename": "mistral-7b-instruct-v0.2.Q8_0.llamafile",
|
46 |
+
"hash": "1ee6114517d2f770425c880e5abc443da36b193c82abec8e2885dd7ce3b9bfa6"
|
47 |
+
},
|
48 |
+
"Samantha-Mistral-Instruct-7B-Bulleted-Notes-Q8.gguf": {
|
49 |
+
"name": "Samantha-Mistral-Instruct-7B-Bulleted-Notes-Q8.gguf",
|
50 |
+
"url": "https://huggingface.co/cognitivetech/samantha-mistral-instruct-7b-bulleted-notes-GGUF/resolve/main/samantha-mistral-instruct-7b-bulleted-notes.Q8_0.gguf?download=true",
|
51 |
+
"filename": "samantha-mistral-instruct-7b-bulleted-notes.Q8_0.gguf",
|
52 |
+
"hash": "6334c1ab56c565afd86535271fab52b03e67a5e31376946bce7bf5c144e847e4"
|
53 |
+
},
|
54 |
+
"Phi-3-mini-128k-instruct-Q8_0.gguf": {
|
55 |
+
"name": "Phi-3-mini-128k-instruct-Q8_0.gguf",
|
56 |
+
"url": "https://huggingface.co/gaianet/Phi-3-mini-128k-instruct-GGUF/resolve/main/Phi-3-mini-128k-instruct-Q8_0.gguf?download=true",
|
57 |
+
"filename": "Phi-3-mini-128k-instruct-Q8_0.gguf",
|
58 |
+
"hash": "6817b66d1c3c59ab06822e9732f0e594eea44e64cae2110906eac9d17f75d193"
|
59 |
+
},
|
60 |
+
"Meta-Llama-3-8B-Instruct.Q8_0.llamafile": {
|
61 |
+
"name": "Meta-Llama-3-8B-Instruct.Q8_0.llamafile",
|
62 |
+
"url": "https://huggingface.co/Mozilla/Meta-Llama-3-8B-Instruct-llamafile/resolve/main/Meta-Llama-3-8B-Instruct.Q8_0.llamafile?download=true",
|
63 |
+
"filename": "Meta-Llama-3-8B-Instruct.Q8_0.llamafile",
|
64 |
+
"hash": "406868a97f02f57183716c7e4441d427f223fdbc7fa42964ef10c4d60dd8ed37"
|
65 |
+
}
|
66 |
+
}
|
67 |
+
#
|
68 |
+
###############################################################
|
69 |
+
|
70 |
+
# Function to download the latest llamafile from the Mozilla-Ocho/llamafile repo
|
71 |
+
def download_latest_llamafile(output_filename: str) -> str:
|
72 |
+
"""
|
73 |
+
Downloads the latest llamafile binary from the Mozilla-Ocho/llamafile GitHub repository.
|
74 |
+
"""
|
75 |
+
logging.info("Checking for and downloading Llamafile if it doesn't already exist...")
|
76 |
+
if os.path.exists(output_filename):
|
77 |
+
logging.debug(f"{output_filename} already exists. Skipping download.")
|
78 |
+
return os.path.abspath(output_filename)
|
79 |
+
|
80 |
+
repo = "Mozilla-Ocho/llamafile"
|
81 |
+
asset_name_prefix = "llamafile-"
|
82 |
+
latest_release_url = f"https://api.github.com/repos/{repo}/releases/latest"
|
83 |
+
response = requests.get(latest_release_url)
|
84 |
+
if response.status_code != 200:
|
85 |
+
raise Exception(f"Failed to fetch latest release info: {response.status_code}")
|
86 |
+
|
87 |
+
latest_release_data = response.json()
|
88 |
+
tag_name = latest_release_data['tag_name']
|
89 |
+
|
90 |
+
release_details_url = f"https://api.github.com/repos/{repo}/releases/tags/{tag_name}"
|
91 |
+
response = requests.get(release_details_url)
|
92 |
+
if response.status_code != 200:
|
93 |
+
raise Exception(f"Failed to fetch release details for tag {tag_name}: {response.status_code}")
|
94 |
+
|
95 |
+
release_data = response.json()
|
96 |
+
assets = release_data.get('assets', [])
|
97 |
+
|
98 |
+
asset_url = None
|
99 |
+
for asset in assets:
|
100 |
+
if re.match(f"{asset_name_prefix}.*", asset['name']):
|
101 |
+
asset_url = asset['browser_download_url']
|
102 |
+
break
|
103 |
+
|
104 |
+
if not asset_url:
|
105 |
+
raise Exception(f"No asset found with prefix {asset_name_prefix}")
|
106 |
+
|
107 |
+
logging.info("Downloading Llamafile...")
|
108 |
+
download_file(asset_url, output_filename)
|
109 |
+
|
110 |
+
logging.debug(f"Downloaded {output_filename} from {asset_url}")
|
111 |
+
return os.path.abspath(output_filename)
|
112 |
+
|
113 |
+
def download_llm_model(model_name: str, model_url: str, model_filename: str, model_hash: str) -> str:
|
114 |
+
"""
|
115 |
+
Downloads the specified LLM model if not already present.
|
116 |
+
"""
|
117 |
+
logging.info(f"Checking availability of model: {model_name}")
|
118 |
+
if os.path.exists(model_filename):
|
119 |
+
logging.debug(f"Model '{model_name}' already exists. Skipping download.")
|
120 |
+
return os.path.abspath(model_filename)
|
121 |
+
|
122 |
+
logging.info(f"Downloading model: {model_name}")
|
123 |
+
download_file(model_url, model_filename, expected_checksum=model_hash)
|
124 |
+
logging.debug(f"Downloaded model '{model_name}' successfully.")
|
125 |
+
return os.path.abspath(model_filename)
|
126 |
+
|
127 |
+
def launch_in_new_terminal(executable: str, args: List[str]) -> subprocess.Popen:
|
128 |
+
"""
|
129 |
+
Launches the executable in a new terminal window based on the operating system.
|
130 |
+
Returns the subprocess.Popen object.
|
131 |
+
"""
|
132 |
+
useros = os.name
|
133 |
+
if useros == "nt":
|
134 |
+
# For Windows
|
135 |
+
args_str = ' '.join(args)
|
136 |
+
command = f'start cmd /k "{executable} {args_str}"'
|
137 |
+
elif useros == "posix":
|
138 |
+
# For Linux (assuming GNOME Terminal; adjust if necessary)
|
139 |
+
args_str = ' '.join(args)
|
140 |
+
command = f'gnome-terminal -- bash -c "{executable} {args_str}; exec bash"'
|
141 |
+
else:
|
142 |
+
# For macOS
|
143 |
+
args_str = ' '.join(args)
|
144 |
+
command = f'open -a Terminal.app "{executable}" --args {args_str}'
|
145 |
+
|
146 |
+
try:
|
147 |
+
process = subprocess.Popen(command, shell=True)
|
148 |
+
logging.info(f"Launched {executable} with arguments: {args}")
|
149 |
+
return process
|
150 |
+
except Exception as e:
|
151 |
+
logging.error(f"Failed to launch the process: {e}")
|
152 |
+
raise
|
153 |
+
|
154 |
+
# Function to scan the directory for .gguf and .llamafile files
|
155 |
+
def get_gguf_llamafile_files(directory: str) -> List[str]:
|
156 |
+
"""
|
157 |
+
Retrieves model files with extensions .gguf or .llamafile from the specified directory.
|
158 |
+
"""
|
159 |
+
logging.debug(f"Scanning directory: {directory}") # Debug print for directory
|
160 |
+
|
161 |
+
# Print all files in the directory for debugging
|
162 |
+
all_files = os.listdir(directory)
|
163 |
+
logging.debug(f"All files in directory: {all_files}")
|
164 |
+
|
165 |
+
pattern_gguf = os.path.join(directory, "*.gguf")
|
166 |
+
pattern_llamafile = os.path.join(directory, "*.llamafile")
|
167 |
+
|
168 |
+
gguf_files = glob.glob(pattern_gguf)
|
169 |
+
llamafile_files = glob.glob(pattern_llamafile)
|
170 |
+
|
171 |
+
# Debug: Print the files found
|
172 |
+
logging.debug(f"Found .gguf files: {gguf_files}")
|
173 |
+
logging.debug(f"Found .llamafile files: {llamafile_files}")
|
174 |
+
|
175 |
+
return [os.path.basename(f) for f in gguf_files + llamafile_files]
|
176 |
+
|
177 |
+
|
+
+# Initialize process with type annotation
+process: Optional[subprocess.Popen] = None
+
+# Function to close out llamafile process on script exit.
+def cleanup_process() -> None:
+    """
+    Terminates the external llamafile process if it is running.
+    """
+    global process
+    if process is not None:
+        process.kill()
+        logging.debug("Terminated the external process")
+        process = None  # Reset the process variable after killing
+
+def signal_handler(sig, frame):
+    """
+    Handles termination signals to ensure the subprocess is cleaned up.
+    """
+    logging.info('Signal handler called with signal: %s', sig)
+    cleanup_process()
+    sys.exit(0)
+
+# Register signal handlers
+def setup_signal_handlers():
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+setup_signal_handlers()
+
+def start_llamafile(
+        am_noob: bool,
+        verbose_checked: bool,
+        threads_checked: bool,
+        threads_value: Optional[int],
+        threads_batched_checked: bool,
+        threads_batched_value: Optional[int],
+        model_alias_checked: bool,
+        model_alias_value: str,
+        http_threads_checked: bool,
+        http_threads_value: Optional[int],
+        model_value: str,
+        hf_repo_checked: bool,
+        hf_repo_value: str,
+        hf_file_checked: bool,
+        hf_file_value: str,
+        ctx_size_checked: bool,
+        ctx_size_value: Optional[int],
+        ngl_checked: bool,
+        ngl_value: Optional[int],
+        batch_size_checked: bool,
+        batch_size_value: Optional[int],
+        memory_f32_checked: bool,
+        numa_checked: bool,
+        server_timeout_value: Optional[int],
+        host_checked: bool,
+        host_value: str,
+        port_checked: bool,
+        port_value: Optional[int],
+        api_key_checked: bool,
+        api_key_value: Optional[str],
+) -> str:
+    """
+    Starts the llamafile process based on provided configuration.
+    """
+    global process
+
+    # Construct command based on checked values
+    command = []
+    if am_noob:
+        # Define what 'am_noob' does, e.g., set default parameters
+        command.append('--sane-defaults')  # Replace with actual flag if needed
+
+    if verbose_checked:
+        command.append('-v')
+
+    if threads_checked and threads_value is not None:
+        command.extend(['-t', str(threads_value)])
+
+    if http_threads_checked and http_threads_value is not None:
+        command.extend(['--threads', str(http_threads_value)])
+
+    if threads_batched_checked and threads_batched_value is not None:
+        command.extend(['-tb', str(threads_batched_value)])
+
+    if model_alias_checked and model_alias_value:
+        command.extend(['-a', model_alias_value])
+
+    # Set model path
+    model_path = os.path.abspath(model_value)
+    command.extend(['-m', model_path])
+
+    if hf_repo_checked and hf_repo_value:
+        command.extend(['-hfr', hf_repo_value])
+
+    if hf_file_checked and hf_file_value:
+        command.extend(['-hff', hf_file_value])
+
+    if ctx_size_checked and ctx_size_value is not None:
+        command.extend(['-c', str(ctx_size_value)])
+
+    if ngl_checked and ngl_value is not None:
+        command.extend(['-ngl', str(ngl_value)])
+
+    if batch_size_checked and batch_size_value is not None:
+        command.extend(['-b', str(batch_size_value)])
+
+    if memory_f32_checked:
+        command.append('--memory-f32')
+
+    if numa_checked:
+        command.append('--numa')
+
+    if host_checked and host_value:
+        command.extend(['--host', host_value])
+
+    if port_checked and port_value is not None:
+        command.extend(['--port', str(port_value)])
+
+    if api_key_checked and api_key_value:
+        command.extend(['--api-key', api_key_value])
+
+    try:
+        useros = os.name
+        output_filename = "llamafile.exe" if useros == "nt" else "llamafile"
+
+        # Ensure llamafile is downloaded
+        llamafile_path = download_latest_llamafile(output_filename)
+
+        # Start llamafile process
+        process = launch_in_new_terminal(llamafile_path, command)
+
+        logging.info(f"Llamafile started with command: {' '.join(command)}")
+        return f"Command built and ran: {' '.join(command)} \n\nLlamafile started successfully."
+
+    except Exception as e:
+        logging.error(f"Failed to start llamafile: {e}")
+        return f"Failed to start llamafile: {e}"
+
+#
+# End of Local_LLM_Inference_Engine_Lib.py
+#######################################################################################################################
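One observation on the OS detection in launch_in_new_terminal(): os.name is "posix" on macOS as well as Linux, so the macOS else-branch is effectively unreachable as written. For reference, a minimal sketch of calling start_llamafile() outside the Gradio UI; every value below is an illustrative assumption, not a project default:

    # Illustrative only: start a llamafile server with a hypothetical model path.
    status = start_llamafile(
        am_noob=False,
        verbose_checked=True,
        threads_checked=True, threads_value=8,
        threads_batched_checked=False, threads_batched_value=None,
        model_alias_checked=False, model_alias_value="",
        http_threads_checked=False, http_threads_value=None,
        model_value="Models/mistral-7b-instruct.llamafile",  # hypothetical path
        hf_repo_checked=False, hf_repo_value="",
        hf_file_checked=False, hf_file_value="",
        ctx_size_checked=True, ctx_size_value=8192,
        ngl_checked=False, ngl_value=None,
        batch_size_checked=False, batch_size_value=None,
        memory_f32_checked=False,
        numa_checked=False,
        server_timeout_value=None,
        host_checked=True, host_value="127.0.0.1",
        port_checked=True, port_value=8080,
        api_key_checked=False, api_key_value=None,
    )
    print(status)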
App_Function_Libraries/Local_LLM/Local_LLM_huggingface.py
ADDED
@@ -0,0 +1,79 @@
+# import gradio as gr
+# from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+# import os
+# import torch
+#
+# # Assuming models are stored in a 'models' directory
+# MODELS_DIR = "models"
+#
+#
+# def get_local_models():
+#     if not os.path.exists(MODELS_DIR):
+#         os.makedirs(MODELS_DIR)
+#     return [d for d in os.listdir(MODELS_DIR) if os.path.isdir(os.path.join(MODELS_DIR, d))]
+#
+#
+# def download_model(model_name):
+#     try:
+#         tokenizer = AutoTokenizer.from_pretrained(model_name)
+#         model = AutoModelForCausalLM.from_pretrained(model_name)
+#
+#         # Save the model and tokenizer
+#         save_path = os.path.join(MODELS_DIR, model_name.split('/')[-1])
+#         tokenizer.save_pretrained(save_path)
+#         model.save_pretrained(save_path)
+#
+#         return f"Successfully downloaded model: {model_name}"
+#     except Exception as e:
+#         return f"Failed to download model: {str(e)}"
+#
+#
+# def run_inference(model_name, prompt):
+#     try:
+#         model_path = os.path.join(MODELS_DIR, model_name)
+#         tokenizer = AutoTokenizer.from_pretrained(model_path)
+#         model = AutoModelForCausalLM.from_pretrained(model_path)
+#
+#         # Use GPU if available
+#         device = "cuda" if torch.cuda.is_available() else "cpu"
+#         model.to(device)
+#
+#         # Create a text-generation pipeline
+#         text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
+#
+#         # Generate text
+#         result = text_generator(prompt, max_length=100, num_return_sequences=1)
+#
+#         return result[0]['generated_text']
+#     except Exception as e:
+#         return f"Error running inference: {str(e)}"
+#
+#
+# def create_huggingface_tab():
+#     with gr.Tab("Hugging Face Transformers"):
+#         gr.Markdown("# Hugging Face Transformers Model Management")
+#
+#         with gr.Row():
+#             model_list = gr.Dropdown(label="Available Models", choices=get_local_models())
+#             refresh_button = gr.Button("Refresh Model List")
+#
+#         with gr.Row():
+#             new_model_name = gr.Textbox(label="Model to Download (e.g., 'gpt2' or 'EleutherAI/gpt-neo-1.3B')")
+#             download_button = gr.Button("Download Model")
+#
+#         download_output = gr.Textbox(label="Download Status")
+#
+#         with gr.Row():
+#             run_model = gr.Dropdown(label="Model to Run", choices=get_local_models())
+#             prompt = gr.Textbox(label="Prompt")
+#             run_button = gr.Button("Run Inference")
+#
+#         run_output = gr.Textbox(label="Model Output")
+#
+#         def update_model_lists():
+#             models = get_local_models()
+#             return gr.update(choices=models), gr.update(choices=models)
+#
+#         refresh_button.click(update_model_lists, outputs=[model_list, run_model])
+#         download_button.click(download_model, inputs=[new_model_name], outputs=[download_output])
+#         run_button.click(run_inference, inputs=[run_model, prompt], outputs=[run_output])
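This module is committed fully commented out, so nothing here executes yet. Its run_inference() path is the standard transformers text-generation pipeline; a minimal standalone sketch of the same approach, using "gpt2" purely as an assumed example model:

    from transformers import pipeline

    # Illustrative only: mirrors the commented-out run_inference() above.
    generator = pipeline("text-generation", model="gpt2")
    result = generator("Once upon a time", max_new_tokens=40, num_return_sequences=1)
    print(result[0]["generated_text"])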
App_Function_Libraries/Local_LLM/Local_LLM_ollama.py
ADDED
@@ -0,0 +1,96 @@
+import platform
+
+import gradio as gr
+import subprocess
+import psutil
+import os
+import signal
+
+
+def get_ollama_models():
+    try:
+        result = subprocess.run(['ollama', 'list'], capture_output=True, text=True, check=True)
+        models = result.stdout.strip().split('\n')[1:]  # Skip header
+        return [model.split()[0] for model in models]
+    except subprocess.CalledProcessError:
+        return []
+
+
+def pull_ollama_model(model_name):
+    try:
+        subprocess.run(['ollama', 'pull', model_name], check=True)
+        return f"Successfully pulled model: {model_name}"
+    except subprocess.CalledProcessError as e:
+        return f"Failed to pull model: {e}"
+
+
+def serve_ollama_model(model_name, port):
+    try:
+        # Check if a server is already running on the specified port
+        for conn in psutil.net_connections():
+            if conn.laddr.port == int(port):
+                return f"Port {port} is already in use. Please choose a different port."
+
+        # Start the Ollama server
+        port = str(port)
+        os.environ["OLLAMA_HOST"] = port
+        cmd = f"ollama serve"
+        process = subprocess.Popen(cmd, shell=True)
+        return f"Started Ollama server for model {model_name} on port {port}. Process ID: {process.pid}"
+    except Exception as e:
+        return f"Error starting Ollama server: {e}"
+
+
+def stop_ollama_server(pid):
+    try:
+        if platform.system() == "Windows":
+            os.system(f"taskkill /F /PID {pid}")
+            return f"Stopped Ollama server with PID {pid}"
+        elif platform.system() == "Linux":
+            os.system(f"kill {pid}")
+            return f"Stopped Ollama server with PID {pid}"
+        elif platform.system() == "Darwin":
+            os.system("""osascript -e 'tell app "Ollama" to quit'""")
+            return f"(Hopefully) Stopped Ollama server using osascript..."
+    except ProcessLookupError:
+        return f"No process found with PID {pid}"
+    except Exception as e:
+        return f"Error stopping Ollama server: {e}"
+
+
+def create_ollama_tab():
+    with gr.Tab("Ollama Model Serving"):
+        gr.Markdown("# Ollama Model Serving")
+
+        with gr.Row():
+            model_list = gr.Dropdown(label="Available Models", choices=get_ollama_models())
+            refresh_button = gr.Button("Refresh Model List")
+
+        with gr.Row():
+            new_model_name = gr.Textbox(label="Model to Pull")
+            pull_button = gr.Button("Pull Model")
+
+        pull_output = gr.Textbox(label="Pull Status")
+
+        with gr.Row():
+            # FIXME - Update to update config.txt file
+            serve_model = gr.Dropdown(label="Model to Serve", choices=get_ollama_models())
+            port = gr.Number(label="Port", value=11434, precision=0)
+            serve_button = gr.Button("Start Server")
+
+        serve_output = gr.Textbox(label="Server Status")
+
+        with gr.Row():
+            pid = gr.Number(label="Server Process ID", precision=0)
+            stop_button = gr.Button("Stop Server")
+
+        stop_output = gr.Textbox(label="Stop Status")
+
+        def update_model_lists():
+            models = get_ollama_models()
+            return gr.update(choices=models), gr.update(choices=models)
+
+        refresh_button.click(update_model_lists, outputs=[model_list, serve_model])
+        pull_button.click(pull_ollama_model, inputs=[new_model_name], outputs=[pull_output])
+        serve_button.click(serve_ollama_model, inputs=[serve_model, port], outputs=[serve_output])
+        stop_button.click(stop_ollama_server, inputs=[pid], outputs=[stop_output])
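Two caveats worth noting: serve_ollama_model() assigns only the bare port number to OLLAMA_HOST, whereas Ollama generally expects a host:port string, so that value may need adjusting; and the `f"ollama serve"` string has no placeholders. A minimal sketch of driving these helpers from a Python shell, with the model name and port as illustrative assumptions:

    # Illustrative only: pull a model, start a server, then list what's installed.
    print(pull_ollama_model("llama3"))
    print(serve_ollama_model("llama3", 11434))
    print(get_ollama_models())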
App_Function_Libraries/Metrics/__init__.py
ADDED
File without changes
App_Function_Libraries/PDF/PDF_Ingestion_Lib.py
CHANGED
@@ -11,171 +11,35 @@
 #
 #
 ####################
-import re
-
 # Import necessary libraries
-
-
+import re
+import os
+import shutil
+import tempfile
+from datetime import datetime
+import pymupdf
+import logging
+#
 # Import Local
-
-
-# Function Definitions
+from App_Function_Libraries.DB.DB_Manager import add_media_with_keywords
+from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
 #
-
-# Ingest a text file into the database with Title/Author/Keywords
-
-
 # Constants
 MAX_FILE_SIZE_MB = 50
 CONVERSION_TIMEOUT_SECONDS = 300
-
-# Marker PDF solution
-# def convert_pdf_to_markdown(pdf_path):
-#     """
-#     Convert a PDF file to Markdown by calling a script in another virtual environment.
-#     """
-#
-#     logging.debug(f"Marker: Converting PDF file to Markdown: {pdf_path}")
-#     # Check if the file size exceeds the maximum allowed size
-#     file_size_mb = os.path.getsize(pdf_path) / (1024 * 1024)
-#     if file_size_mb > MAX_FILE_SIZE_MB:
-#         raise ValueError(f"File size ({file_size_mb:.2f} MB) exceeds the maximum allowed size of {MAX_FILE_SIZE_MB} MB")
-#
-#     logging.debug("Marker: Converting PDF file to Markdown using Marker virtual environment")
-#     # Path to the Python interpreter in the other virtual environment
-#     other_venv_python = "Helper_Scripts/marker_venv/bin/python"
-#
-#     # Path to the conversion script
-#     converter_script = "Helper_Scripts/PDF_Converter.py"
-#
-#     logging.debug("Marker: Attempting to convert PDF file to Markdown...")
-#     try:
-#         result = subprocess.run(
-#             [other_venv_python, converter_script, pdf_path],
-#             capture_output=True,
-#             text=True,
-#             timeout=CONVERSION_TIMEOUT_SECONDS
-#         )
-#         if result.returncode != 0:
-#             raise Exception(f"Conversion failed: {result.stderr}")
-#         return result.stdout
-#     except subprocess.TimeoutExpired:
-#         raise Exception(f"PDF conversion timed out after {CONVERSION_TIMEOUT_SECONDS} seconds")
-#
-#
-# def process_and_ingest_pdf(file, title, author, keywords):
-#     if file is None:
-#         return "Please select a PDF file to upload."
-#
-#     try:
-#         # Create a temporary directory
-#         with tempfile.TemporaryDirectory() as temp_dir:
-#             # Create a path for the temporary PDF file
-#             temp_path = os.path.join(temp_dir, "temp.pdf")
-#
-#             # Copy the contents of the uploaded file to the temporary file
-#             shutil.copy(file.name, temp_path)
-#
-#             # Call the ingest_pdf_file function with the temporary file path
-#             result = ingest_pdf_file(temp_path, title, author, keywords)
-#
-#             return result
-#     except Exception as e:
-#         return f"Error processing PDF: {str(e)}"
-#
-#
-# def ingest_pdf_file(file_path, title=None, author=None, keywords=None):
-#     try:
-#         # Convert PDF to Markdown
-#         markdown_content = convert_pdf_to_markdown(file_path)
-#
-#         # If title is not provided, use the filename without extension
-#         if not title:
-#             title = os.path.splitext(os.path.basename(file_path))[0]
-#
-#         # If author is not provided, set it to 'Unknown'
-#         if not author:
-#             author = 'Unknown'
-#
-#         # If keywords are not provided, use a default keyword
-#         if not keywords:
-#             keywords = 'pdf_file,markdown_converted'
-#         else:
-#             keywords = f'pdf_file,markdown_converted,{keywords}'
-#
-#         # Add the markdown content to the database
-#         add_media_with_keywords(
-#             url=file_path,
-#             title=title,
-#             media_type='document',
-#             content=markdown_content,
-#             keywords=keywords,
-#             prompt='No prompt for PDF files',
-#             summary='No summary for PDF files',
-#             transcription_model='None',
-#             author=author,
-#             ingestion_date=datetime.now().strftime('%Y-%m-%d')
-#         )
-#
-#         return f"PDF file '{title}' converted to Markdown and ingested successfully.", file_path
-#     except ValueError as e:
-#         logging.error(f"File size error: {str(e)}")
-#         return f"Error: {str(e)}", file_path
-#     except Exception as e:
-#         logging.error(f"Error ingesting PDF file: {str(e)}")
-#         return f"Error ingesting PDF file: {str(e)}", file_path
-#
-#
-# def process_and_cleanup_pdf(file, title, author, keywords):
-#     # FIXME - Update to validate file upload/filetype is pdf....
-#     if file is None:
-#         return "No file uploaded. Please upload a PDF file."
-#
-#     temp_dir = tempfile.mkdtemp()
-#     temp_file_path = os.path.join(temp_dir, "temp.pdf")
-#
-#     try:
-#         # Copy the uploaded file to a temporary location
-#         shutil.copy2(file.name, temp_file_path)
-#
-#         # Process the file
-#         result, _ = ingest_pdf_file(temp_file_path, title, author, keywords)
-#
-#         return result
-#     except Exception as e:
-#         logging.error(f"Error in processing and cleanup: {str(e)}")
-#         return f"Error: {str(e)}"
-#     finally:
-#         # Clean up the temporary directory and its contents
-#         try:
-#             shutil.rmtree(temp_dir)
-#             logging.info(f"Removed temporary directory: {temp_dir}")
-#         except Exception as cleanup_error:
-#             logging.error(f"Error during cleanup: {str(cleanup_error)}")
-#             result += f"\nWarning: Could not remove temporary files: {str(cleanup_error)}"
-
-
-import logging
-#
 #
 #######################################################################################################################
-#
-# Non-Marker implementation
-import os
-import shutil
-import tempfile
-from datetime import datetime
-
-import pymupdf
-
-from App_Function_Libraries.DB.DB_Manager import add_media_with_keywords
-
+# Function Definitions
+#
 
 def extract_text_and_format_from_pdf(pdf_path):
     """
     Extract text from a PDF file and convert it to Markdown, preserving formatting.
     """
     try:
+        log_counter("pdf_text_extraction_attempt", labels={"file_path": pdf_path})
+        start_time = datetime.now()
+
         markdown_text = ""
         with pymupdf.open(pdf_path) as doc:
             for page_num, page in enumerate(doc, 1):
@@ -228,9 +92,15 @@ def extract_text_and_format_from_pdf(pdf_path):
         # Clean up hyphenated words
         markdown_text = re.sub(r'(\w+)-\s*\n(\w+)', r'\1\2', markdown_text)
 
+        end_time = datetime.now()
+        processing_time = (end_time - start_time).total_seconds()
+        log_histogram("pdf_text_extraction_duration", processing_time, labels={"file_path": pdf_path})
+        log_counter("pdf_text_extraction_success", labels={"file_path": pdf_path})
+
         return markdown_text
     except Exception as e:
         logging.error(f"Error extracting text and formatting from PDF: {str(e)}")
+        log_counter("pdf_text_extraction_error", labels={"file_path": pdf_path, "error": str(e)})
        raise


@@ -239,19 +109,26 @@ def extract_metadata_from_pdf(pdf_path):
     Extract metadata from a PDF file using PyMuPDF.
     """
     try:
+        log_counter("pdf_metadata_extraction_attempt", labels={"file_path": pdf_path})
         with pymupdf.open(pdf_path) as doc:
             metadata = doc.metadata
+        log_counter("pdf_metadata_extraction_success", labels={"file_path": pdf_path})
         return metadata
     except Exception as e:
         logging.error(f"Error extracting metadata from PDF: {str(e)}")
+        log_counter("pdf_metadata_extraction_error", labels={"file_path": pdf_path, "error": str(e)})
         return {}


 def process_and_ingest_pdf(file, title, author, keywords):
     if file is None:
+        log_counter("pdf_ingestion_error", labels={"error": "No file uploaded"})
         return "Please select a PDF file to upload."

     try:
+        log_counter("pdf_ingestion_attempt", labels={"file_name": file.name})
+        start_time = datetime.now()
+
         # Create a temporary directory
         with tempfile.TemporaryDirectory() as temp_dir:
             # Create a path for the temporary PDF file
@@ -296,23 +173,40 @@ def process_and_ingest_pdf(file, title, author, keywords):
             ingestion_date=datetime.now().strftime('%Y-%m-%d')
         )

+        end_time = datetime.now()
+        processing_time = (end_time - start_time).total_seconds()
+        log_histogram("pdf_ingestion_duration", processing_time, labels={"file_name": file.name})
+        log_counter("pdf_ingestion_success", labels={"file_name": file.name})
+
         return f"PDF file '{title}' by {author} ingested successfully and converted to Markdown."
     except Exception as e:
         logging.error(f"Error ingesting PDF file: {str(e)}")
+        log_counter("pdf_ingestion_error", labels={"file_name": file.name, "error": str(e)})
         return f"Error ingesting PDF file: {str(e)}"


 def process_and_cleanup_pdf(file, title, author, keywords):
     if file is None:
+        log_counter("pdf_processing_error", labels={"error": "No file uploaded"})
         return "No file uploaded. Please upload a PDF file."

     try:
+        log_counter("pdf_processing_attempt", labels={"file_name": file.name})
+        start_time = datetime.now()
+
         result = process_and_ingest_pdf(file, title, author, keywords)
+
+        end_time = datetime.now()
+        processing_time = (end_time - start_time).total_seconds()
+        log_histogram("pdf_processing_duration", processing_time, labels={"file_name": file.name})
+        log_counter("pdf_processing_success", labels={"file_name": file.name})
+
         return result
     except Exception as e:
         logging.error(f"Error in processing and cleanup: {str(e)}")
+        log_counter("pdf_processing_error", labels={"file_name": file.name, "error": str(e)})
         return f"Error: {str(e)}"

 #
 # End of PDF_Ingestion_Lib.py
-#######################################################################################################################
+#######################################################################################################################
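For reference, the two extraction helpers in this file can be exercised directly. A minimal sketch, assuming a hypothetical PDF on disk:

    # Illustrative only: convert a local PDF to Markdown and inspect its metadata.
    md = extract_text_and_format_from_pdf("Docs/sample.pdf")  # hypothetical path
    print(md[:500])
    print(extract_metadata_from_pdf("Docs/sample.pdf"))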
App_Function_Libraries/RAG/ChromaDB_Library.py
CHANGED
@@ -8,14 +8,13 @@ from typing import List, Dict, Any
 import chromadb
 from chromadb import Settings
 from itertools import islice
+import numpy as np
 #
 # Local Imports:
 from App_Function_Libraries.Chunk_Lib import chunk_for_embedding, chunk_options
 from App_Function_Libraries.DB.DB_Manager import get_unprocessed_media, mark_media_as_processed
 from App_Function_Libraries.DB.SQLite_DB import process_chunks
-from App_Function_Libraries.RAG.Embeddings_Create import create_embeddings_batch
-# FIXME - related to Chunking
-from App_Function_Libraries.RAG.Embeddings_Create import create_embedding
+from App_Function_Libraries.RAG.Embeddings_Create import create_embedding, create_embeddings_batch
 from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize
 from App_Function_Libraries.Utils.Utils import get_database_path, ensure_directory_exists, \
     load_comprehensive_config
@@ -216,55 +215,119 @@ def reset_chroma_collection(collection_name: str):
         logging.error(f"Error resetting ChromaDB collection: {str(e)}")


-[… previous store_in_chroma signature and setup: removed lines not recoverable in this render …]
+#v2
+def store_in_chroma(collection_name: str, texts: List[str], embeddings: Any, ids: List[str],
+                    metadatas: List[Dict[str, Any]]):
+    # Convert embeddings to list if it's a numpy array
+    if isinstance(embeddings, np.ndarray):
+        embeddings = embeddings.tolist()
+    elif not isinstance(embeddings, list):
+        raise TypeError("Embeddings must be either a list or a numpy array")
+
+    if not embeddings:
+        raise ValueError("No embeddings provided")
+
+    embedding_dim = len(embeddings[0])
+
+    logging.info(f"Storing embeddings in ChromaDB - Collection: {collection_name}")
+    logging.info(f"Number of embeddings: {len(embeddings)}, Dimension: {embedding_dim}")
+
+    try:
+        # Attempt to get or create the collection
+        try:
+            collection = chroma_client.get_collection(name=collection_name)
+            logging.info(f"Existing collection '{collection_name}' found")
+
+            # Check dimension of existing embeddings
+            existing_embeddings = collection.get(limit=1, include=['embeddings'])['embeddings']
+            if existing_embeddings:
+                existing_dim = len(existing_embeddings[0])
+                if existing_dim != embedding_dim:
+                    logging.warning(f"Embedding dimension mismatch. Existing: {existing_dim}, New: {embedding_dim}")
+                    logging.warning("Deleting existing collection and creating a new one")
+                    chroma_client.delete_collection(name=collection_name)
+                    collection = chroma_client.create_collection(name=collection_name)
+            else:
+                logging.info("No existing embeddings in the collection")
+        except Exception as e:
+            logging.info(f"Collection '{collection_name}' not found. Creating new collection")
+            collection = chroma_client.create_collection(name=collection_name)
+
-        #
+        # Perform the upsert operation
         collection.upsert(
             documents=texts,
             embeddings=embeddings,
             ids=ids,
             metadatas=metadatas
         )
+        logging.info(f"Successfully upserted {len(embeddings)} embeddings")

-        # Verify
-[… removed verification lines not recoverable in this render …]
+        # Verify all stored embeddings
+        results = collection.get(ids=ids, include=["documents", "embeddings", "metadatas"])
+
+        for i, doc_id in enumerate(ids):
+            if results['embeddings'][i] is None:
+                raise ValueError(f"Failed to store embedding for {doc_id}")
             else:
-                logging.[…]
-                logging.debug(f"Stored document: {[…]
-                logging.debug(f"Stored metadata: {[…]
+                logging.debug(f"Embedding stored successfully for {doc_id}")
+                logging.debug(f"Stored document preview: {results['documents'][i][:100]}...")
+                logging.debug(f"Stored metadata: {results['metadatas'][i]}")
+
+        logging.info("Successfully stored and verified all embeddings in ChromaDB")

     except Exception as e:
-        logging.error(f"Error[…]
+        logging.error(f"Error in store_in_chroma: {str(e)}")
         raise

+    return collection
+

 # Function to perform vector search using ChromaDB + Keywords from the media_db
+#v2
 def vector_search(collection_name: str, query: str, k: int = 10) -> List[Dict[str, Any]]:
     try:
-        query_embedding = create_embedding(query, embedding_provider, embedding_model, embedding_api_url)
         collection = chroma_client.get_collection(name=collection_name)
+
+        # Fetch a sample of embeddings to check metadata
+        sample_results = collection.get(limit=10, include=["metadatas"])
+        if not sample_results['metadatas']:
+            raise ValueError("No metadata found in the collection")
+
+        # Check if all embeddings use the same model and provider
+        embedding_models = [metadata.get('embedding_model') for metadata in sample_results['metadatas'] if metadata.get('embedding_model')]
+        embedding_providers = [metadata.get('embedding_provider') for metadata in sample_results['metadatas'] if metadata.get('embedding_provider')]
+
+        if not embedding_models or not embedding_providers:
+            raise ValueError("Embedding model or provider information not found in metadata")
+
+        embedding_model = max(set(embedding_models), key=embedding_models.count)
+        embedding_provider = max(set(embedding_providers), key=embedding_providers.count)
+
+        logging.info(f"Using embedding model: {embedding_model} from provider: {embedding_provider}")
+
+        # Generate query embedding using the existing create_embedding function
+        query_embedding = create_embedding(query, embedding_provider, embedding_model, embedding_api_url)
+
+        # Ensure query_embedding is a list
+        if isinstance(query_embedding, np.ndarray):
+            query_embedding = query_embedding.tolist()
+
         results = collection.query(
             query_embeddings=[query_embedding],
             n_results=k,
             include=["documents", "metadatas"]
         )
+
+        if not results['documents'][0]:
+            logging.warning("No results found for the query")
+            return []
+
         return [{"content": doc, "metadata": meta} for doc, meta in zip(results['documents'][0], results['metadatas'][0])]
     except Exception as e:
-        logging.error(f"Error in vector_search: {str(e)}")
+        logging.error(f"Error in vector_search: {str(e)}", exc_info=True)
         raise

+
 def schedule_embedding(media_id: int, content: str, media_name: str):
     try:
         chunks = chunk_for_embedding(content, media_name, chunk_options)
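Note that the new vector_search() reads 'embedding_model' and 'embedding_provider' out of the stored metadata, so store_in_chroma() callers should include those keys. A minimal round-trip sketch; the collection name, ids, provider, and model below are assumptions for illustration:

    # Illustrative only: store two chunks, then query them back.
    texts = ["First chunk of text.", "Second chunk of text."]
    embs = create_embeddings_batch(texts, "openai", "text-embedding-3-small", api_url="")
    store_in_chroma(
        "demo_collection",
        texts,
        embs,
        ids=["doc1_chunk0", "doc1_chunk1"],
        metadatas=[{"embedding_model": "text-embedding-3-small",
                    "embedding_provider": "openai"}] * 2,
    )
    for hit in vector_search("demo_collection", "first chunk", k=1):
        print(hit["content"], hit["metadata"])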
@@ -312,4 +375,142 @@ def schedule_embedding(media_id: int, content: str, media_name: str):

 #
 # End of Functions for ChromaDB
-#######################################################################################################################
+#######################################################################################################################
+
+
+# FIXME - Suggestions from ChatGPT:
+# 2. Detailed Mapping and Assessment
+# a. preprocess_all_content
+#
+# Test: test_preprocess_all_content
+#
+# Coverage:
+#
+# Mocks the get_unprocessed_media function to return a predefined unprocessed media list.
+# Mocks process_and_store_content and mark_media_as_processed to verify their invocation with correct arguments.
+# Asserts that process_and_store_content and mark_media_as_processed are called exactly once with expected parameters.
+#
+# Assessment:
+#
+# Strengths: Ensures that preprocess_all_content correctly retrieves unprocessed media, processes each item, and marks it as processed.
+# Suggestions:
+# Multiple Media Items: Test with multiple media items to verify loop handling.
+# Exception Handling: Simulate exceptions within process_and_store_content to ensure proper logging and continuation or halting as intended.
+#
+# b. process_and_store_content
+#
+# Test: test_process_and_store_content
+#
+# Coverage:
+#
+# Mocks dependencies: chunk_for_embedding, process_chunks, situate_context, create_embeddings_batch, and chroma_client.
+# Simulates the scenario where the specified ChromaDB collection does not exist initially and needs to be created.
+# Verifies that chunks are processed, embeddings are created, stored in ChromaDB, and database queries are executed correctly.
+#
+# Assessment:
+#
+# Strengths: Thoroughly checks the workflow of processing content, including chunking, embedding creation, and storage.
+# Suggestions:
+# Existing Collection: Add a test case where the collection already exists to ensure that get_collection is used without attempting to create a new one.
+# Embedding Creation Disabled: Test with create_embeddings=False to verify alternative code paths.
+# Error Scenarios: Simulate failures in embedding creation or storage to ensure exceptions are handled gracefully.
+#
+# c. check_embedding_status
+#
+# Test: test_check_embedding_status
+#
+# Coverage:
+#
+# Mocks the ChromaDB client to return predefined embeddings and metadata.
+# Verifies that the function correctly identifies the existence of embeddings and retrieves relevant metadata.
+#
+# Assessment:
+#
+# Strengths: Confirms that the function accurately detects existing embeddings and handles metadata appropriately.
+# Suggestions:
+# No Embeddings Found: Test the scenario where no embeddings exist for the selected item.
+# Missing Metadata: Simulate missing or incomplete metadata to ensure robust error handling.
+#
+# d. reset_chroma_collection
+#
+# Test: test_reset_chroma_collection
+#
+# Coverage:
+#
+# Mocks the ChromaDB client’s delete_collection and create_collection methods.
+# Verifies that the specified collection is deleted and recreated.
+#
+# Assessment:
+#
+# Strengths: Ensures that the reset operation performs both deletion and creation as intended.
+# Suggestions:
+# Non-Existent Collection: Test resetting a collection that does not exist to verify behavior.
+# Exception Handling: Simulate failures during deletion or creation to check error logging and propagation.
+#
+# e. store_in_chroma
+#
+# Test: test_store_in_chroma
+#
+# Coverage:
+#
+# Mocks the ChromaDB client to return a mock collection.
+# Verifies that documents, embeddings, IDs, and metadata are upserted correctly into the collection.
+#
+# Assessment:
+#
+# Strengths: Confirms that embeddings and associated data are stored accurately in ChromaDB.
+# Suggestions:
+# Empty Embeddings: Test storing with empty embeddings to ensure proper error handling.
+# Embedding Dimension Mismatch: Simulate a dimension mismatch to verify that the function handles it as expected.
+#
+# f. vector_search
+#
+# Test: test_vector_search
+#
+# Coverage:
+#
+# Mocks the ChromaDB client’s get_collection, get, and query methods.
+# Mocks the create_embedding function to return a predefined embedding.
+# Verifies that the search retrieves the correct documents and metadata based on the query.
+#
+# Assessment:
+#
+# Strengths: Ensures that the vector search mechanism correctly interacts with ChromaDB and returns expected results.
+# Suggestions:
+# No Results Found: Test queries that return no results to verify handling.
+# Multiple Results: Ensure that multiple documents are retrieved and correctly formatted.
+# Metadata Variations: Test with diverse metadata to confirm accurate retrieval.
+#
+# g. batched
+#
+# Test: test_batched
+#
+# Coverage:
+#
+# Uses pytest.mark.parametrize to test multiple scenarios:
+# Regular batching.
+# Batch size larger than the iterable.
+# Empty iterable.
+#
+# Assessment:
+#
+# Strengths: Comprehensive coverage of typical and edge batching scenarios.
+# Suggestions:
+# Non-Integer Batch Sizes: Test with invalid batch sizes (e.g., zero, negative numbers) to ensure proper handling or error raising.
+#
+# h. situate_context and schedule_embedding
+#
+# Tests: Not directly tested
+#
+# Coverage:
+#
+# These functions are currently not directly tested in the test_chromadb.py suite.
+#
+# Assessment:
+#
+# Suggestions:
+# situate_context:
+# Unit Test: Since it's a pure function that interacts with the summarize function, create a separate test to mock summarize and verify the context generation.
+# Edge Cases: Test with empty strings, very long texts, or special characters to ensure robustness.
+# schedule_embedding:
+# Integration Test: Since it orchestrates multiple operations (chunking, embedding creation, storage), consider writing an integration test that mocks all dependent functions and verifies the complete workflow.
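In the spirit of the test_batched suggestion above, a parametrized sketch covering the three listed scenarios; the signature batched(iterable, n) and its tuple-yielding behavior are assumptions inferred from the islice import, not confirmed by this diff:

    import pytest
    from App_Function_Libraries.RAG.ChromaDB_Library import batched

    # Illustrative only: regular batching, oversize batch, and empty input.
    @pytest.mark.parametrize("iterable,n,expected", [
        ([1, 2, 3, 4, 5], 2, [[1, 2], [3, 4], [5]]),
        ([1, 2], 5, [[1, 2]]),
        ([], 3, []),
    ])
    def test_batched(iterable, n, expected):
        assert [list(b) for b in batched(iterable, n)] == expected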
App_Function_Libraries/RAG/Embeddings_Create.py
CHANGED
@@ -3,12 +3,15 @@
|
|
3 |
#
|
4 |
# Imports:
|
5 |
import logging
|
|
|
6 |
import time
|
7 |
from functools import wraps
|
8 |
from threading import Lock, Timer
|
9 |
from typing import List
|
10 |
#
|
11 |
# 3rd-Party Imports:
|
|
|
|
|
12 |
import requests
|
13 |
from transformers import AutoTokenizer, AutoModel
|
14 |
import torch
|
@@ -16,44 +19,75 @@ import torch
|
|
16 |
# Local Imports:
|
17 |
from App_Function_Libraries.LLM_API_Calls import get_openai_embeddings
|
18 |
from App_Function_Libraries.Utils.Utils import load_comprehensive_config
|
|
|
19 |
#
|
20 |
#######################################################################################################################
|
21 |
#
|
22 |
# Functions:
|
23 |
|
24 |
-
# FIXME -
|
|
|
|
|
25 |
loaded_config = load_comprehensive_config()
|
26 |
embedding_provider = loaded_config['Embeddings']['embedding_provider']
|
27 |
embedding_model = loaded_config['Embeddings']['embedding_model']
|
28 |
embedding_api_url = loaded_config['Embeddings']['embedding_api_url']
|
29 |
embedding_api_key = loaded_config['Embeddings']['embedding_api_key']
|
|
|
30 |
|
31 |
# Embedding Chunking Settings
|
32 |
chunk_size = loaded_config['Embeddings']['chunk_size']
|
33 |
overlap = loaded_config['Embeddings']['overlap']
|
34 |
|
|
|
|
|
35 |
|
36 |
-
#
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
class HuggingFaceEmbedder:
|
39 |
-
def __init__(self, model_name, timeout_seconds=
|
40 |
self.model_name = model_name
|
|
|
41 |
self.tokenizer = None
|
42 |
self.model = None
|
43 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
44 |
self.timeout_seconds = timeout_seconds
|
45 |
self.last_used_time = 0
|
46 |
self.unload_timer = None
|
|
|
47 |
|
48 |
def load_model(self):
|
|
|
|
|
|
|
49 |
if self.model is None:
|
50 |
-
|
51 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
self.model.to(self.device)
|
53 |
self.last_used_time = time.time()
|
54 |
self.reset_timer()
|
|
|
|
|
|
|
55 |
|
56 |
def unload_model(self):
|
|
|
57 |
if self.model is not None:
|
58 |
del self.model
|
59 |
del self.tokenizer
|
@@ -71,17 +105,119 @@ class HuggingFaceEmbedder:
|
|
71 |
self.unload_timer.start()
|
72 |
|
73 |
def create_embeddings(self, texts):
|
|
|
|
|
74 |
self.load_model()
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
class RateLimiter:
|
87 |
def __init__(self, max_calls, period):
|
@@ -102,7 +238,6 @@ class RateLimiter:
|
|
102 |
return func(*args, **kwargs)
|
103 |
return wrapper
|
104 |
|
105 |
-
|
106 |
def exponential_backoff(max_retries=5, base_delay=1):
|
107 |
def decorator(func):
|
108 |
@wraps(func)
|
@@ -119,72 +254,353 @@ def exponential_backoff(max_retries=5, base_delay=1):
|
|
119 |
return wrapper
|
120 |
return decorator
|
121 |
|
122 |
-
|
123 |
-
# FIXME - refactor/setup to use config file & perform chunking
|
124 |
@exponential_backoff()
|
125 |
-
@RateLimiter(max_calls=50, period=60)
|
126 |
-
def create_embeddings_batch(texts: List[str],
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
embeddings = huggingface_embedder.create_embeddings(texts).tolist()
|
137 |
-
return embeddings
|
138 |
-
|
139 |
-
elif provider.lower() == 'openai':
|
140 |
-
logging.debug(f"Creating embeddings for {len(texts)} texts using OpenAI API")
|
141 |
-
return [create_openai_embedding(text, model) for text in texts]
|
142 |
-
|
143 |
-
elif provider.lower() == 'local':
|
144 |
-
response = requests.post(
|
145 |
-
api_url,
|
146 |
-
json={"texts": texts, "model": model},
|
147 |
-
headers={"Authorization": f"Bearer {embedding_api_key}"}
|
148 |
-
)
|
149 |
-
if response.status_code == 200:
|
150 |
-
return response.json()['embeddings']
|
151 |
-
else:
|
152 |
-
raise Exception(f"Error from local API: {response.text}")
|
153 |
-
else:
|
154 |
-
raise ValueError(f"Unsupported embedding provider: {provider}")
|
155 |
-
|
156 |
-
|
157 |
-
def create_embedding(text: str, provider: str, model: str, api_url: str) -> List[float]:
|
158 |
-
return create_embeddings_batch([text], provider, model, api_url)[0]
|
159 |
-
|
160 |
-
# FIXME
|
161 |
-
def create_stella_embeddings(text: str) -> List[float]:
|
162 |
-
if embedding_provider == 'local':
|
163 |
-
# Load the model and tokenizer
|
164 |
-
tokenizer = AutoTokenizer.from_pretrained("dunzhang/stella_en_400M_v5")
|
165 |
-
model = AutoModel.from_pretrained("dunzhang/stella_en_400M_v5")
|
166 |
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
def create_openai_embedding(text: str, model: str) -> List[float]:
|
|
|
|
|
185 |
embedding = get_openai_embeddings(text, model)
|
|
|
|
|
|
|
186 |
return embedding
|
187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
#
|
189 |
# End of File.
|
190 |
#######################################################################################################################
|
|
|
3 |
#
|
4 |
# Imports:
|
5 |
import logging
|
6 |
+
import os
|
7 |
import time
|
8 |
from functools import wraps
|
9 |
from threading import Lock, Timer
|
10 |
from typing import List
|
11 |
#
|
12 |
# 3rd-Party Imports:
|
13 |
+
import numpy as np
|
14 |
+
import onnxruntime as ort
|
15 |
import requests
|
16 |
from transformers import AutoTokenizer, AutoModel
|
17 |
import torch
|
|
|
19 |
# Local Imports:
|
20 |
from App_Function_Libraries.LLM_API_Calls import get_openai_embeddings
|
21 |
from App_Function_Libraries.Utils.Utils import load_comprehensive_config
|
22 |
+
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
|
23 |
#
|
24 |
#######################################################################################################################
|
25 |
#
|
26 |
# Functions:
|
27 |
|
28 |
+
# FIXME - Version 2
|
29 |
+
|
30 |
+
# Load configuration
|
31 |
loaded_config = load_comprehensive_config()
|
32 |
embedding_provider = loaded_config['Embeddings']['embedding_provider']
|
33 |
embedding_model = loaded_config['Embeddings']['embedding_model']
|
34 |
embedding_api_url = loaded_config['Embeddings']['embedding_api_url']
|
35 |
embedding_api_key = loaded_config['Embeddings']['embedding_api_key']
|
36 |
+
model_dir = loaded_config['Embeddings'].get('model_dir', './App_Function_Libraries/models/embedding_models/')
|
37 |
|
38 |
# Embedding Chunking Settings
|
39 |
chunk_size = loaded_config['Embeddings']['chunk_size']
|
40 |
overlap = loaded_config['Embeddings']['overlap']
|
41 |
|
42 |
+
# Global cache for embedding models
|
43 |
+
embedding_models = {}
|
44 |
|
45 |
+
# Commit hashes
|
46 |
+
commit_hashes = {
|
47 |
+
"jinaai/jina-embeddings-v3": "4be32c2f5d65b95e4bcce473545b7883ec8d2edd",
|
48 |
+
"Alibaba-NLP/gte-large-en-v1.5": "104333d6af6f97649377c2afbde10a7704870c7b",
|
49 |
+
"dunzhang/setll_en_400M_v5": "2aa5579fcae1c579de199a3866b6e514bbbf5d10"
|
50 |
+
}
|
51 |
|
52 |
class HuggingFaceEmbedder:
|
53 |
+
def __init__(self, model_name, cache_dir, timeout_seconds=30):
|
54 |
self.model_name = model_name
|
55 |
+
self.cache_dir = cache_dir # Store cache_dir
|
56 |
self.tokenizer = None
|
57 |
self.model = None
|
58 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
59 |
self.timeout_seconds = timeout_seconds
|
60 |
self.last_used_time = 0
|
61 |
self.unload_timer = None
|
62 |
+
log_counter("huggingface_embedder_init", labels={"model_name": model_name})
|
63 |
|
64 |
def load_model(self):
|
65 |
+
log_counter("huggingface_model_load_attempt", labels={"model_name": self.model_name})
|
66 |
+
start_time = time.time()
|
67 |
+
# https://huggingface.co/docs/transformers/custom_models
|
68 |
if self.model is None:
|
69 |
+
# Pass cache_dir to from_pretrained to specify download directory
|
70 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
71 |
+
self.model_name,
|
72 |
+
trust_remote_code=True,
|
73 |
+
cache_dir=self.cache_dir, # Specify cache directory
|
74 |
+
revision=commit_hashes.get(self.model_name, None) # Pass commit hash
|
75 |
+
)
|
76 |
+
self.model = AutoModel.from_pretrained(
|
77 |
+
self.model_name,
|
78 |
+
trust_remote_code=True,
|
79 |
+
cache_dir=self.cache_dir, # Specify cache directory
|
80 |
+
revision=commit_hashes.get(self.model_name, None) # Pass commit hash
|
81 |
+
)
|
82 |
self.model.to(self.device)
|
83 |
self.last_used_time = time.time()
|
84 |
self.reset_timer()
|
85 |
+
load_time = time.time() - start_time
|
86 |
+
log_histogram("huggingface_model_load_duration", load_time, labels={"model_name": self.model_name})
|
87 |
+
log_counter("huggingface_model_load_success", labels={"model_name": self.model_name})
|
88 |
|
89 |
def unload_model(self):
|
90 |
+
log_counter("huggingface_model_unload", labels={"model_name": self.model_name})
|
91 |
if self.model is not None:
|
92 |
del self.model
|
93 |
del self.tokenizer
|
|
|
105 |
self.unload_timer.start()
|
106 |
|
107 |
def create_embeddings(self, texts):
|
108 |
+
log_counter("huggingface_create_embeddings_attempt", labels={"model_name": self.model_name})
|
109 |
+
start_time = time.time()
|
110 |
self.load_model()
|
111 |
+
# https://huggingface.co/docs/transformers/custom_models
|
112 |
+
inputs = self.tokenizer(
|
113 |
+
texts,
|
114 |
+
return_tensors="pt",
|
115 |
+
padding=True,
|
116 |
+
truncation=True,
|
117 |
+
max_length=512
|
118 |
+
)
|
119 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
120 |
+
try:
|
121 |
+
with torch.no_grad():
|
122 |
+
outputs = self.model(**inputs)
|
123 |
+
embeddings = outputs.last_hidden_state.mean(dim=1)
|
124 |
+
return embeddings.cpu().float().numpy() # Convert to float32 before returning
|
125 |
+
except RuntimeError as e:
|
126 |
+
if "Got unsupported ScalarType BFloat16" in str(e):
|
127 |
+
logging.warning("BFloat16 not supported. Falling back to float32.")
|
128 |
+
# Convert model to float32
|
129 |
+
self.model = self.model.float()
|
130 |
+
with torch.no_grad():
|
131 |
+
outputs = self.model(**inputs)
|
132 |
+
embeddings = outputs.last_hidden_state.mean(dim=1)
|
133 |
+
embedding_time = time.time() - start_time
|
134 |
+
log_histogram("huggingface_create_embeddings_duration", embedding_time,
|
135 |
+
labels={"model_name": self.model_name})
|
136 |
+
log_counter("huggingface_create_embeddings_success", labels={"model_name": self.model_name})
|
137 |
+
return embeddings.cpu().float().numpy()
|
138 |
+
else:
|
139 |
+
log_counter("huggingface_create_embeddings_failure", labels={"model_name": self.model_name})
|
140 |
+
raise
|
141 |
+
|
+class ONNXEmbedder:
+    def __init__(self, model_name, onnx_model_dir, timeout_seconds=30):
+        self.model_name = model_name
+        self.model_path = os.path.join(onnx_model_dir, f"{model_name}.onnx")
+        # https://huggingface.co/docs/transformers/custom_models
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            cache_dir=onnx_model_dir,  # Ensure tokenizer uses the same directory
+            revision=commit_hashes.get(model_name, None)  # Pass commit hash
+        )
+        self.session = None
+        self.timeout_seconds = timeout_seconds
+        self.last_used_time = 0
+        self.unload_timer = None
+        self.device = "cpu"  # ONNX Runtime will default to CPU unless GPU is configured
+        log_counter("onnx_embedder_init", labels={"model_name": model_name})
+
+    def load_model(self):
+        log_counter("onnx_model_load_attempt", labels={"model_name": self.model_name})
+        start_time = time.time()
+        if self.session is None:
+            if not os.path.exists(self.model_path):
+                raise FileNotFoundError(f"ONNX model not found at {self.model_path}")
+            logging.info(f"Loading ONNX model from {self.model_path}")
+            self.session = ort.InferenceSession(self.model_path)
+        self.last_used_time = time.time()
+        self.reset_timer()
+        load_time = time.time() - start_time
+        log_histogram("onnx_model_load_duration", load_time, labels={"model_name": self.model_name})
+        log_counter("onnx_model_load_success", labels={"model_name": self.model_name})

+    def unload_model(self):
+        log_counter("onnx_model_unload", labels={"model_name": self.model_name})
+        if self.session is not None:
+            logging.info("Unloading ONNX model to free resources.")
+            self.session = None
+        if self.unload_timer:
+            self.unload_timer.cancel()
+
+    def reset_timer(self):
+        if self.unload_timer:
+            self.unload_timer.cancel()
+        self.unload_timer = Timer(self.timeout_seconds, self.unload_model)
+        self.unload_timer.start()
+
+    def create_embeddings(self, texts: List[str]) -> List[List[float]]:
+        log_counter("onnx_create_embeddings_attempt", labels={"model_name": self.model_name})
+        start_time = time.time()
+        self.load_model()
+        try:
+            inputs = self.tokenizer(
+                texts,
+                return_tensors="np",
+                padding=True,
+                truncation=True,
+                max_length=512
+            )
+            input_ids = inputs["input_ids"].astype(np.int64)
+            attention_mask = inputs["attention_mask"].astype(np.int64)

+            ort_inputs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask
+            }
+
+            ort_outputs = self.session.run(None, ort_inputs)
+
+            last_hidden_state = ort_outputs[0]
+            embeddings = np.mean(last_hidden_state, axis=1)
+
+            embedding_time = time.time() - start_time
+            log_histogram("onnx_create_embeddings_duration", embedding_time, labels={"model_name": self.model_name})
+            log_counter("onnx_create_embeddings_success", labels={"model_name": self.model_name})
+            return embeddings.tolist()
+        except Exception as e:
+            log_counter("onnx_create_embeddings_failure", labels={"model_name": self.model_name})
+            logging.error(f"Error creating embeddings with ONNX model: {str(e)}")
+            raise

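For reference, a minimal usage sketch of the `ONNXEmbedder` above. The directory and model name are illustrative assumptions; `load_model()` only succeeds if the serialized session already exists at `os.path.join(onnx_model_dir, f"{model_name}.onnx")`:

# Hypothetical paths for illustration; real values come from this project's config.
onnx_dir = "./models/embedding_models"
embedder = ONNXEmbedder("dunzhang/stella_en_400M_v5", onnx_dir, timeout_seconds=30)

vectors = embedder.create_embeddings(["first passage", "second passage"])
print(len(vectors), len(vectors[0]))  # number of texts, embedding dimensionality
# After timeout_seconds of inactivity the Timer unloads the session;
# the next create_embeddings() call transparently reloads it.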
 class RateLimiter:
     def __init__(self, max_calls, period):
@@ ... @@
             return func(*args, **kwargs)
         return wrapper

 def exponential_backoff(max_retries=5, base_delay=1):
     def decorator(func):
         @wraps(func)
@@ ... @@
         return wrapper
     return decorator

 @exponential_backoff()
+@RateLimiter(max_calls=50, period=60)
+def create_embeddings_batch(texts: List[str],
+                            provider: str,
+                            model: str,
+                            api_url: str,
+                            timeout_seconds: int = 300
+                            ) -> List[List[float]]:
+    global embedding_models
+    log_counter("create_embeddings_batch_attempt", labels={"provider": provider, "model": model})
+    start_time = time.time()
+
+    try:
+        if provider.lower() == 'huggingface':
+            if model not in embedding_models:
+                if model == "dunzhang/stella_en_400M_v5":
+                    embedding_models[model] = ONNXEmbedder(model, model_dir, timeout_seconds)
+                else:
+                    # Pass model_dir to HuggingFaceEmbedder
+                    embedding_models[model] = HuggingFaceEmbedder(model, model_dir, timeout_seconds)
+            embedder = embedding_models[model]
+            embedding_time = time.time() - start_time
+            log_histogram("create_embeddings_batch_duration", embedding_time,
+                          labels={"provider": provider, "model": model})
+            log_counter("create_embeddings_batch_success", labels={"provider": provider, "model": model})
+            return embedder.create_embeddings(texts)

+        elif provider.lower() == 'openai':
+            logging.debug(f"Creating embeddings for {len(texts)} texts using OpenAI API")
+            embedding_time = time.time() - start_time
+            log_histogram("create_embeddings_batch_duration", embedding_time,
+                          labels={"provider": provider, "model": model})
+            log_counter("create_embeddings_batch_success", labels={"provider": provider, "model": model})
+            return [create_openai_embedding(text, model) for text in texts]

+        elif provider.lower() == 'local':
+            response = requests.post(
+                api_url,
+                json={"texts": texts, "model": model},
+                headers={"Authorization": f"Bearer {embedding_api_key}"}
+            )
+            if response.status_code == 200:
+                embedding_time = time.time() - start_time
+                log_histogram("create_embeddings_batch_duration", embedding_time,
+                              labels={"provider": provider, "model": model})
+                log_counter("create_embeddings_batch_success", labels={"provider": provider, "model": model})
+                return response.json()['embeddings']
+            else:
+                raise Exception(f"Error from local API: {response.text}")
+        else:
+            raise ValueError(f"Unsupported embedding provider: {provider}")
+    except Exception as e:
+        log_counter("create_embeddings_batch_error", labels={"provider": provider, "model": model, "error": str(e)})
+        logging.error(f"Error in create_embeddings_batch: {str(e)}")
+        raise

+def create_embedding(text: str, provider: str, model: str, api_url: str) -> List[float]:
+    log_counter("create_embedding_attempt", labels={"provider": provider, "model": model})
+    start_time = time.time()
+    embedding = create_embeddings_batch([text], provider, model, api_url)[0]
+    if isinstance(embedding, np.ndarray):
+        embedding = embedding.tolist()
+    embedding_time = time.time() - start_time
+    log_histogram("create_embedding_duration", embedding_time, labels={"provider": provider, "model": model})
+    log_counter("create_embedding_success", labels={"provider": provider, "model": model})
+    return embedding

 def create_openai_embedding(text: str, model: str) -> List[float]:
+    log_counter("create_openai_embedding_attempt", labels={"model": model})
+    start_time = time.time()
     embedding = get_openai_embeddings(text, model)
+    embedding_time = time.time() - start_time
+    log_histogram("create_openai_embedding_duration", embedding_time, labels={"model": model})
+    log_counter("create_openai_embedding_success", labels={"model": model})
     return embedding
+
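A sketch of how these entry points compose, assuming the module-level `model_dir`, `embedding_models` cache, and `embedding_api_key` have been populated from config.txt as elsewhere in this file (the argument values are illustrative):

texts = ["What is RAG?", "Retrieval-augmented generation pairs search with an LLM."]

# Batch path: wrapped by @exponential_backoff() and @RateLimiter(50 calls / 60 s),
# so transient failures are retried and call volume is throttled.
vecs = create_embeddings_batch(texts, "huggingface", "dunzhang/stella_en_400M_v5", api_url="")

# Single-text wrapper: delegates to the batch path and normalizes the
# result to a plain Python list.
vec = create_embedding(texts[0], "huggingface", "dunzhang/stella_en_400M_v5", "")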
+# FIXME - Version 1
+# # FIXME - Add all globals to summarize.py
+# loaded_config = load_comprehensive_config()
+# embedding_provider = loaded_config['Embeddings']['embedding_provider']
+# embedding_model = loaded_config['Embeddings']['embedding_model']
+# embedding_api_url = loaded_config['Embeddings']['embedding_api_url']
+# embedding_api_key = loaded_config['Embeddings']['embedding_api_key']
+#
+# # Embedding Chunking Settings
+# chunk_size = loaded_config['Embeddings']['chunk_size']
+# overlap = loaded_config['Embeddings']['overlap']
+#
+#
+# # FIXME - Add logging
+#
+# class HuggingFaceEmbedder:
+#     def __init__(self, model_name, timeout_seconds=120):  # Default timeout of 2 minutes
+#         self.model_name = model_name
+#         self.tokenizer = None
+#         self.model = None
+#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         self.timeout_seconds = timeout_seconds
+#         self.last_used_time = 0
+#         self.unload_timer = None
+#
+#     def load_model(self):
+#         if self.model is None:
+#             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+#             self.model = AutoModel.from_pretrained(self.model_name)
+#             self.model.to(self.device)
+#         self.last_used_time = time.time()
+#         self.reset_timer()
+#
+#     def unload_model(self):
+#         if self.model is not None:
+#             del self.model
+#             del self.tokenizer
+#             if torch.cuda.is_available():
+#                 torch.cuda.empty_cache()
+#             self.model = None
+#             self.tokenizer = None
+#         if self.unload_timer:
+#             self.unload_timer.cancel()
+#
+#     def reset_timer(self):
+#         if self.unload_timer:
+#             self.unload_timer.cancel()
+#         self.unload_timer = Timer(self.timeout_seconds, self.unload_model)
+#         self.unload_timer.start()
+#
+#     def create_embeddings(self, texts):
+#         self.load_model()
+#         inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
+#         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+#         with torch.no_grad():
+#             outputs = self.model(**inputs)
+#         embeddings = outputs.last_hidden_state.mean(dim=1)
+#         return embeddings.cpu().numpy()
+#
+# # Global variable to hold the embedder
+# huggingface_embedder = None
+#
+#
+# class RateLimiter:
+#     def __init__(self, max_calls, period):
+#         self.max_calls = max_calls
+#         self.period = period
+#         self.calls = []
+#         self.lock = Lock()
+#
+#     def __call__(self, func):
+#         def wrapper(*args, **kwargs):
+#             with self.lock:
+#                 now = time.time()
+#                 self.calls = [call for call in self.calls if call > now - self.period]
+#                 if len(self.calls) >= self.max_calls:
+#                     sleep_time = self.calls[0] - (now - self.period)
+#                     time.sleep(sleep_time)
+#                 self.calls.append(time.time())
+#             return func(*args, **kwargs)
+#         return wrapper
+#
+#
+# def exponential_backoff(max_retries=5, base_delay=1):
+#     def decorator(func):
+#         @wraps(func)
+#         def wrapper(*args, **kwargs):
+#             for attempt in range(max_retries):
+#                 try:
+#                     return func(*args, **kwargs)
+#                 except Exception as e:
+#                     if attempt == max_retries - 1:
+#                         raise
+#                     delay = base_delay * (2 ** attempt)
+#                     logging.warning(f"Attempt {attempt + 1} failed. Retrying in {delay} seconds. Error: {str(e)}")
+#                     time.sleep(delay)
+#         return wrapper
+#     return decorator
+#
+#
+# # FIXME - refactor/setup to use config file & perform chunking
+# @exponential_backoff()
+# @RateLimiter(max_calls=50, period=60)
+# def create_embeddings_batch(texts: List[str], provider: str, model: str, api_url: str, timeout_seconds: int = 300) -> List[List[float]]:
+#     global embedding_models
+#
+#     try:
+#         if provider.lower() == 'huggingface':
+#             if model not in embedding_models:
+#                 if model == "dunzhang/stella_en_400M_v5":
+#                     embedding_models[model] = ONNXEmbedder(model, model_dir, timeout_seconds)
+#                 else:
+#                     embedding_models[model] = HuggingFaceEmbedder(model, timeout_seconds)
+#             embedder = embedding_models[model]
+#             return embedder.create_embeddings(texts)
+#
+#         elif provider.lower() == 'openai':
+#             logging.debug(f"Creating embeddings for {len(texts)} texts using OpenAI API")
+#             return [create_openai_embedding(text, model) for text in texts]
+#
+#         elif provider.lower() == 'local':
+#             response = requests.post(
+#                 api_url,
+#                 json={"texts": texts, "model": model},
+#                 headers={"Authorization": f"Bearer {embedding_api_key}"}
+#             )
+#             if response.status_code == 200:
+#                 return response.json()['embeddings']
+#             else:
+#                 raise Exception(f"Error from local API: {response.text}")
+#         else:
+#             raise ValueError(f"Unsupported embedding provider: {provider}")
+#     except Exception as e:
+#         logging.error(f"Error in create_embeddings_batch: {str(e)}")
+#         raise
+#
+# def create_embedding(text: str, provider: str, model: str, api_url: str) -> List[float]:
+#     return create_embeddings_batch([text], provider, model, api_url)[0]
+#
+#
+# def create_openai_embedding(text: str, model: str) -> List[float]:
+#     embedding = get_openai_embeddings(text, model)
+#     return embedding
+#
+#
+# # FIXME - refactor to use onnx embeddings callout
+# def create_stella_embeddings(text: str) -> List[float]:
+#     if embedding_provider == 'local':
+#         # Load the model and tokenizer
+#         tokenizer = AutoTokenizer.from_pretrained("dunzhang/stella_en_400M_v5")
+#         model = AutoModel.from_pretrained("dunzhang/stella_en_400M_v5")
+#
+#         # Tokenize and encode the text
+#         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+#
+#         # Generate embeddings
+#         with torch.no_grad():
+#             outputs = model(**inputs)
+#
+#         # Use the mean of the last hidden state as the sentence embedding
+#         embeddings = outputs.last_hidden_state.mean(dim=1)
+#
+#         return embeddings[0].tolist()  # Convert to list for consistency
+#     elif embedding_provider == 'openai':
+#         return get_openai_embeddings(text, embedding_model)
+#     else:
+#         raise ValueError(f"Unsupported embedding provider: {embedding_provider}")
+# #
+# # End of F
+# ##############################################################
+#
+#
+# ##############################################################
+# #
+# # ONNX Embeddings Functions
+#
+# # FIXME - UPDATE
+# # Define the model path
+# model_dir = "/tldw/App_Function_Libraries/models/embedding_models/"
+# model_name = "your-huggingface-model-name"
+# onnx_model_path = os.path.join(model_dir, model_name, "model.onnx")
+#
+# # Tokenizer download (if applicable)
+# #tokenizer = AutoTokenizer.from_pretrained(model_name)
+#
+# # Ensure the model directory exists
+# #if not os.path.exists(onnx_model_path):
+# # You can add logic to download the ONNX model from a remote source
+# # if it's not already available in the folder.
+# # Example: huggingface_hub.download (if model is hosted on Hugging Face Hub)
+# # raise Exception(f"ONNX model not found at {onnx_model_path}")
+#
+# class ONNXEmbedder:
+#     def __init__(self, model_name, model_dir, timeout_seconds=120):
+#         self.model_name = model_name
+#         self.model_path = os.path.join(model_dir, f"{model_name}.onnx")
+#         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+#         self.session = None
+#         self.timeout_seconds = timeout_seconds
+#         self.last_used_time = 0
+#         self.unload_timer = None
+#         self.device = "cpu"  # ONNX Runtime will default to CPU unless GPU is configured
+#
+#     def load_model(self):
+#         if self.session is None:
+#             if not os.path.exists(self.model_path):
+#                 raise FileNotFoundError(f"ONNX model not found at {self.model_path}")
+#             logging.info(f"Loading ONNX model from {self.model_path}")
+#             self.session = ort.InferenceSession(self.model_path)
+#         self.last_used_time = time.time()
+#         self.reset_timer()
+#
+#     def unload_model(self):
+#         if self.session is not None:
+#             logging.info("Unloading ONNX model to free resources.")
+#             self.session = None
+#         if self.unload_timer:
+#             self.unload_timer.cancel()
+#
+#     def reset_timer(self):
+#         if self.unload_timer:
+#             self.unload_timer.cancel()
+#         self.unload_timer = Timer(self.timeout_seconds, self.unload_model)
+#         self.unload_timer.start()
+#
+#     def create_embeddings(self, texts: List[str]) -> List[List[float]]:
+#         self.load_model()
+#
+#         try:
+#             inputs = self.tokenizer(texts, return_tensors="np", padding=True, truncation=True, max_length=512)
+#             input_ids = inputs["input_ids"].astype(np.int64)
+#             attention_mask = inputs["attention_mask"].astype(np.int64)
+#
+#             ort_inputs = {
+#                 "input_ids": input_ids,
+#                 "attention_mask": attention_mask
+#             }
+#
+#             ort_outputs = self.session.run(None, ort_inputs)
+#
+#             last_hidden_state = ort_outputs[0]
+#             embeddings = np.mean(last_hidden_state, axis=1)
+#
+#             return embeddings.tolist()
+#         except Exception as e:
+#             logging.error(f"Error creating embeddings with ONNX model: {str(e)}")
+#             raise
+#
+# # Global cache for the ONNX embedder instance
+# onnx_embedder = None
+#
+# # Global cache for embedding models
+# embedding_models = {}
+#
+# def create_onnx_embeddings(texts: List[str]) -> List[List[float]]:
+#     global onnx_embedder
+#     model_dir = "/tldw/App_Function_Libraries/models/embedding_models/"
+#     model_name = "your-huggingface-model-name"  # This can be pulled from config
+#
+#     if onnx_embedder is None:
+#         onnx_embedder = ONNXEmbedder(model_name=model_name, model_dir=model_dir)
+#
+#     # Generate embeddings
+#     embeddings = onnx_embedder.create_embeddings(texts)
+#     return embeddings
+#
+# #
+# # End of ONNX Embeddings Functions
+# ##############################################################
+
 #
 # End of File.
 #######################################################################################################################
App_Function_Libraries/RAG/RAG_Library_2.py
CHANGED
@@ -5,15 +5,24 @@
 import configparser
 import logging
 import os
+import time
 from typing import Dict, Any, List, Optional
+
+from App_Function_Libraries.DB.Character_Chat_DB import get_character_chats, perform_full_text_search_chat, \
+    fetch_keywords_for_chats
+#
 # Local Imports
 from App_Function_Libraries.RAG.ChromaDB_Library import process_and_store_content, vector_search, chroma_client
+from App_Function_Libraries.RAG.RAG_Persona_Chat import perform_vector_search_chat
+from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_custom_openai
 from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_article
 from App_Function_Libraries.DB.DB_Manager import search_db, fetch_keywords_for_media
 from App_Function_Libraries.Utils.Utils import load_comprehensive_config
+from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
 #
 # 3rd-Party Imports
 import openai
+from flashrank import Ranker, RerankRequest
 #
 ########################################################################################################################
 #
@@ -109,6 +118,8 @@ config.read('config.txt')

 # RAG Search with keyword filtering
 def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None) -> Dict[str, Any]:
+    log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
+    start_time = time.time()
     try:
         # Load embedding provider from config, or fallback to 'openai'
         embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
@@ -118,32 +129,57 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None) ->

         # Process keywords if provided
         keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
+        logging.debug(f"\n\nenhanced_rag_pipeline - Keywords: {keyword_list}")

         # Fetch relevant media IDs based on keywords if keywords are provided
         relevant_media_ids = fetch_relevant_media_ids(keyword_list) if keyword_list else None
+        logging.debug(f"\n\nenhanced_rag_pipeline - relevant media IDs: {relevant_media_ids}")

         # Perform vector search
         vector_results = perform_vector_search(query, relevant_media_ids)
+        logging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}")

         # Perform full-text search
         fts_results = perform_full_text_search(query, relevant_media_ids)
+        logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
+        logging.debug(
+            "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join(
+                [str(item) for item in fts_results]) + "\n"
+        )

         # Combine results
         all_results = vector_results + fts_results

-        apply_re_ranking = False
+        apply_re_ranking = True
         if apply_re_ranking:
+            logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
+            # FIXME - add option to use re-ranking at call time
+            # FIXME - specify model + add param to modify at call time
+            # FIXME - add option to set a custom top X results
+            # You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
+            ranker = Ranker()
+
+            # Prepare passages for re-ranking
+            passages = [{"id": i, "text": result['content']} for i, result in enumerate(all_results)]
+            rerank_request = RerankRequest(query=query, passages=passages)
+
+            # Rerank the results
+            reranked_results = ranker.rerank(rerank_request)
+
+            # Sort results based on the re-ranking score
+            reranked_results = sorted(reranked_results, key=lambda x: x['score'], reverse=True)
+
+            # Log reranked results
+            logging.debug(f"\n\nenhanced_rag_pipeline - Reranked results: {reranked_results}")
+
+            # Update all_results based on reranking
+            all_results = [all_results[result['id']] for result in reranked_results]
+
+        # Extract content from results (top 10)
         context = "\n".join([result['content'] for result in all_results[:10]])  # Limit to top 10 results
         logging.debug(f"Context length: {len(context)}")
         logging.debug(f"Context: {context[:200]}")
+
         # Generate answer using the selected API
         answer = generate_answer(api_choice, context, query)
@@ -153,111 +189,220 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None) ->

             "answer": "No relevant information based on your query and keywords were found in the database. Your query has been directly passed to the LLM, and here is its answer: \n\n" + answer,
             "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
         }
+        # Metrics
+        pipeline_duration = time.time() - start_time
+        log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
+        log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
         return {
             "answer": answer,
             "context": context
         }

     except Exception as e:
+        # Metrics
+        log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
+        logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
         logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
         return {
             "answer": "An error occurred while processing your request.",
             "context": ""
         }

+# Need to write a test for this function FIXME
 def generate_answer(api_choice: str, context: str, query: str) -> str:
+    # Metrics
+    log_counter("generate_answer_attempt", labels={"api_choice": api_choice})
+    start_time = time.time()
     logging.debug("Entering generate_answer function")
     config = load_comprehensive_config()
     logging.debug(f"Config sections: {config.sections()}")
     prompt = f"Context: {context}\n\nQuestion: {query}"
+    try:
+        if api_choice == "OpenAI":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_openai
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_openai(config['API']['openai_api_key'], prompt, "")
+
+        elif api_choice == "Anthropic":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_anthropic
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_anthropic(config['API']['anthropic_api_key'], prompt, "")
+
+        elif api_choice == "Cohere":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_cohere
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_cohere(config['API']['cohere_api_key'], prompt, "")
+
+        elif api_choice == "Groq":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_groq
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_groq(config['API']['groq_api_key'], prompt, "")
+
+        elif api_choice == "OpenRouter":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_openrouter
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_openrouter(config['API']['openrouter_api_key'], prompt, "")
+
+        elif api_choice == "HuggingFace":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_huggingface
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_huggingface(config['API']['huggingface_api_key'], prompt, "")
+
+        elif api_choice == "DeepSeek":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_deepseek
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_deepseek(config['API']['deepseek_api_key'], prompt, "")
+
+        elif api_choice == "Mistral":
+            from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_mistral
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_mistral(config['API']['mistral_api_key'], prompt, "")
+
+        # Local LLM APIs
+        elif api_choice == "Local-LLM":
+            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_local_llm
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            # FIXME
+            return summarize_with_local_llm(config['Local-API']['local_llm_path'], prompt, "")
+
+        elif api_choice == "Llama.cpp":
+            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_llama
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_llama(prompt, "", config['Local-API']['llama_api_key'], None, None)
+        elif api_choice == "Kobold":
+            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_kobold
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_kobold(prompt, config['Local-API']['kobold_api_key'], "", system_message=None, temp=None)
+
+        elif api_choice == "Ooba":
+            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_oobabooga
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_oobabooga(prompt, config['Local-API']['ooba_api_key'], custom_prompt="", system_message=None, temp=None)
+
+        elif api_choice == "TabbyAPI":
+            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_tabbyapi
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_tabbyapi(prompt, None, None, None, None, )
+
+        elif api_choice == "vLLM":
+            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_vllm
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_vllm(prompt, "", config['Local-API']['vllm_api_key'], None, None)
+
+        elif api_choice.lower() == "ollama":
+            from App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_ollama
+            answer_generation_duration = time.time() - start_time
+            log_histogram("generate_answer_duration", answer_generation_duration, labels={"api_choice": api_choice})
+            log_counter("generate_answer_success", labels={"api_choice": api_choice})
+            return summarize_with_ollama(prompt, "", config['Local-API']['ollama_api_IP'], config['Local-API']['ollama_api_key'], None, None, None)
+
+        elif api_choice.lower() == "custom_openai_api":
+            logging.debug(f"RAG Answer Gen: Trying with Custom_OpenAI API")
+            summary = summarize_with_custom_openai(prompt, "", config['API']['custom_openai_api_key'], None,
+                                                   None)
+        else:
+            log_counter("generate_answer_error", labels={"api_choice": api_choice, "error": str()})
+            raise ValueError(f"Unsupported API choice: {api_choice}")
+    except Exception as e:
+        log_counter("generate_answer_error", labels={"api_choice": api_choice, "error": str(e)})
+        logging.error(f"Error in generate_answer: {str(e)}")
+        return "An error occurred while generating the answer."

 def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
+    log_counter("perform_vector_search_attempt")
+    start_time = time.time()
     all_collections = chroma_client.list_collections()
     vector_results = []
+    try:
+        for collection in all_collections:
+            collection_results = vector_search(collection.name, query, k=5)
+            filtered_results = [
+                result for result in collection_results
+                if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
+            ]
+            vector_results.extend(filtered_results)
+        search_duration = time.time() - start_time
+        log_histogram("perform_vector_search_duration", search_duration)
+        log_counter("perform_vector_search_success", labels={"result_count": len(vector_results)})
+        return vector_results
+    except Exception as e:
+        log_counter("perform_vector_search_error", labels={"error": str(e)})
+        logging.error(f"Error in perform_vector_search: {str(e)}")
+        raise

 def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
+    log_counter("perform_full_text_search_attempt")
+    start_time = time.time()
+    try:
+        fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
+        filtered_fts_results = [
+            {
+                "content": result['content'],
+                "metadata": {"media_id": result['id']}
+            }
+            for result in fts_results
+            if relevant_media_ids is None or result['id'] in relevant_media_ids
+        ]
+        search_duration = time.time() - start_time
+        log_histogram("perform_full_text_search_duration", search_duration)
+        log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)})
+        return filtered_fts_results
+    except Exception as e:
+        log_counter("perform_full_text_search_error", labels={"error": str(e)})
+        logging.error(f"Error in perform_full_text_search: {str(e)}")
+        raise


 def fetch_relevant_media_ids(keywords: List[str]) -> List[int]:
+    log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)})
+    start_time = time.time()
     relevant_ids = set()
+    for keyword in keywords:
+        try:
             media_ids = fetch_keywords_for_media(keyword)
             relevant_ids.update(media_ids)
+        except Exception as e:
+            log_counter("fetch_relevant_media_ids_error", labels={"error": str(e)})
+            logging.error(f"Error fetching relevant media IDs for keyword '{keyword}': {str(e)}")
+            # Continue processing other keywords
+
+    fetch_duration = time.time() - start_time
+    log_histogram("fetch_relevant_media_ids_duration", fetch_duration)
+    log_counter("fetch_relevant_media_ids_success", labels={"result_count": len(relevant_ids)})
     return list(relevant_ids)


 def filter_results_by_keywords(results: List[Dict[str, Any]], keywords: List[str]) -> List[Dict[str, Any]]:
+    log_counter("filter_results_by_keywords_attempt", labels={"result_count": len(results), "keyword_count": len(keywords)})
+    start_time = time.time()
     if not keywords:
         return results
@@ -283,6 +428,9 @@ def filter_results_by_keywords(results: List[Dict[str, Any]], keywords: List[str

         except Exception as e:
             logging.error(f"Error processing result: {result}. Error: {str(e)}")

+    filter_duration = time.time() - start_time
+    log_histogram("filter_results_by_keywords_duration", filter_duration)
+    log_counter("filter_results_by_keywords_success", labels={"filtered_count": len(filtered_results)})
     return filtered_results

 # FIXME: to be implemented
@@ -300,6 +448,173 @@ def extract_media_id_from_result(result: str) -> Optional[int]:

 ########################################################################################################################


+############################################################################################################
+#
+# Chat RAG
+
+def enhanced_rag_pipeline_chat(query: str, api_choice: str, character_id: int, keywords: Optional[str] = None) -> Dict[str, Any]:
+    """
+    Enhanced RAG pipeline tailored for the Character Chat tab.
+
+    Args:
+        query (str): The user's input query.
+        api_choice (str): The API to use for generating the response.
+        character_id (int): The ID of the character being interacted with.
+        keywords (Optional[str]): Comma-separated keywords to filter search results.
+
+    Returns:
+        Dict[str, Any]: Contains the generated answer and the context used.
+    """
+    log_counter("enhanced_rag_pipeline_chat_attempt", labels={"api_choice": api_choice, "character_id": character_id})
+    start_time = time.time()
+    try:
+        # Load embedding provider from config, or fallback to 'openai'
+        embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
+        logging.debug(f"Using embedding provider: {embedding_provider}")
+
+        # Process keywords if provided
+        keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
+        logging.debug(f"enhanced_rag_pipeline_chat - Keywords: {keyword_list}")
+
+        # Fetch relevant chat IDs based on character_id and keywords
+        if keyword_list:
+            relevant_chat_ids = fetch_keywords_for_chats(keyword_list)
+        else:
+            relevant_chat_ids = fetch_all_chat_ids(character_id)
+        logging.debug(f"enhanced_rag_pipeline_chat - Relevant chat IDs: {relevant_chat_ids}")
+
+        if not relevant_chat_ids:
+            logging.info(f"No chats found for the given keywords and character ID: {character_id}")
+            # Fallback to generating answer without context
+            answer = generate_answer(api_choice, "", query)
+            # Metrics
+            pipeline_duration = time.time() - start_time
+            log_histogram("enhanced_rag_pipeline_chat_duration", pipeline_duration, labels={"api_choice": api_choice})
+            log_counter("enhanced_rag_pipeline_chat_success",
+                        labels={"api_choice": api_choice, "character_id": character_id})
+            return {
+                "answer": answer,
+                "context": ""
+            }
+
+        # Perform vector search within the relevant chats
+        vector_results = perform_vector_search_chat(query, relevant_chat_ids)
+        logging.debug(f"enhanced_rag_pipeline_chat - Vector search results: {vector_results}")
+
+        # Perform full-text search within the relevant chats
+        fts_results = perform_full_text_search_chat(query, relevant_chat_ids)
+        logging.debug("enhanced_rag_pipeline_chat - Full-text search results:")
+        logging.debug("\n".join([str(item) for item in fts_results]))
+
+        # Combine results
+        all_results = vector_results + fts_results
+
+        apply_re_ranking = True
+        if apply_re_ranking:
+            logging.debug("enhanced_rag_pipeline_chat - Applying Re-Ranking")
+            ranker = Ranker()
+
+            # Prepare passages for re-ranking
+            passages = [{"id": i, "text": result['content']} for i, result in enumerate(all_results)]
+            rerank_request = RerankRequest(query=query, passages=passages)
+
+            # Rerank the results
+            reranked_results = ranker.rerank(rerank_request)
+
+            # Sort results based on the re-ranking score
+            reranked_results = sorted(reranked_results, key=lambda x: x['score'], reverse=True)
+
+            # Log reranked results
+            logging.debug(f"enhanced_rag_pipeline_chat - Reranked results: {reranked_results}")
+
+            # Update all_results based on reranking
+            all_results = [all_results[result['id']] for result in reranked_results]
+
+        # Extract context from top results (limit to top 10)
+        context = "\n".join([result['content'] for result in all_results[:10]])
+        logging.debug(f"Context length: {len(context)}")
+        logging.debug(f"Context: {context[:200]}")  # Log only the first 200 characters for brevity
+
+        # Generate answer using the selected API
+        answer = generate_answer(api_choice, context, query)
+
+        if not all_results:
+            logging.info(f"No results found. Query: {query}, Keywords: {keywords}")
+            return {
+                "answer": "No relevant information based on your query and keywords were found in the database. Your query has been directly passed to the LLM, and here is its answer: \n\n" + answer,
+                "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
+            }
+
+        return {
+            "answer": answer,
+            "context": context
+        }
+
+    except Exception as e:
+        log_counter("enhanced_rag_pipeline_chat_error", labels={"api_choice": api_choice, "character_id": character_id, "error": str(e)})
+        logging.error(f"Error in enhanced_rag_pipeline_chat: {str(e)}")
+        return {
+            "answer": "An error occurred while processing your request.",
+            "context": ""
+        }
+
+
+def fetch_relevant_chat_ids(character_id: int, keywords: List[str]) -> List[int]:
+    """
+    Fetch chat IDs associated with a character and filtered by keywords.
+
+    Args:
+        character_id (int): The ID of the character.
+        keywords (List[str]): List of keywords to filter chats.
+
+    Returns:
+        List[int]: List of relevant chat IDs.
+    """
+    log_counter("fetch_relevant_chat_ids_attempt", labels={"character_id": character_id, "keyword_count": len(keywords)})
+    start_time = time.time()
+    relevant_ids = set()
+    try:
+        media_ids = fetch_keywords_for_chats(keywords)
+        fetch_duration = time.time() - start_time
+        log_histogram("fetch_relevant_chat_ids_duration", fetch_duration)
+        log_counter("fetch_relevant_chat_ids_success",
+                    labels={"character_id": character_id, "result_count": len(relevant_ids)})
+        relevant_ids.update(media_ids)
+        return list(relevant_ids)
+    except Exception as e:
+        log_counter("fetch_relevant_chat_ids_error", labels={"character_id": character_id, "error": str(e)})
+        logging.error(f"Error fetching relevant chat IDs: {str(e)}")
+        return []
+
+
+def fetch_all_chat_ids(character_id: int) -> List[int]:
+    """
+    Fetch all chat IDs associated with a specific character.
+
+    Args:
+        character_id (int): The ID of the character.
+
+    Returns:
+        List[int]: List of all chat IDs for the character.
+    """
+    log_counter("fetch_all_chat_ids_attempt", labels={"character_id": character_id})
+    start_time = time.time()
+    try:
+        chats = get_character_chats(character_id=character_id)
+        chat_ids = [chat['id'] for chat in chats]
+        fetch_duration = time.time() - start_time
+        log_histogram("fetch_all_chat_ids_duration", fetch_duration)
+        log_counter("fetch_all_chat_ids_success", labels={"character_id": character_id, "chat_count": len(chat_ids)})
+        return chat_ids
+    except Exception as e:
+        log_counter("fetch_all_chat_ids_error", labels={"character_id": character_id, "error": str(e)})
+        logging.error(f"Error fetching all chat IDs for character {character_id}: {str(e)}")
+        return []
+
+#
+# End of Chat RAG
+############################################################################################################
+
 # Function to preprocess and store all existing content in the database
 # def preprocess_all_content(database, create_contextualized=True, api_name="gpt-3.5-turbo"):
 #     unprocessed_media = get_unprocessed_media()
App_Function_Libraries/RAG/RAG_Persona_Chat.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# RAG_Persona_Chat.py
# Description: Functions for RAG Persona Chat
#
# Imports
import logging
from typing import List, Dict, Any, Tuple
#
# External Imports
#
# Local Imports
from App_Function_Libraries.RAG.Embeddings_Create import create_embedding, embedding_provider, embedding_model, \
    embedding_api_url
from App_Function_Libraries.RAG.ChromaDB_Library import chroma_client, store_in_chroma
#
#######################################################################################################################
#
# RAG Chat Embeddings

def perform_vector_search_chat(query: str, relevant_chat_ids: List[int], k: int = 10) -> List[Dict[str, Any]]:
    """
    Perform a vector search within the specified chat IDs.

    Args:
        query (str): The user's query.
        relevant_chat_ids (List[int]): List of chat IDs to search within.
        k (int): Number of top results to retrieve.

    Returns:
        List[Dict[str, Any]]: List of search results with content and metadata.
    """
    try:
        # Convert chat IDs to unique identifiers used in ChromaDB
        chat_ids = [f"chat_{chat_id}" for chat_id in relevant_chat_ids]

        # Define the collection name for chat embeddings
        collection_name = "all_chat_embeddings"  # Ensure this collection exists and contains chat embeddings

        # Generate the query embedding
        query_embedding = create_embedding(query, embedding_provider, embedding_model, embedding_api_url)

        # Get the collection
        collection = chroma_client.get_collection(name=collection_name)

        # Perform the vector search
        results = collection.query(
            query_embeddings=[query_embedding],
            where={"id": {"$in": chat_ids}},  # Assuming 'id' is stored as document IDs
            n_results=k,
            include=["documents", "metadatas"]
        )

        # Process results
        search_results = []
        for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
            search_results.append({
                "content": doc,
                "metadata": meta
            })

        return search_results
    except Exception as e:
        logging.error(f"Error in perform_vector_search_chat: {e}")
        return []


def embed_and_store_chat(chat_id: int, chat_history: List[Tuple[str, str]], conversation_name: str):
    """
    Embed and store chat messages in ChromaDB.

    Args:
        chat_id (int): The ID of the chat.
        chat_history (List[Tuple[str, str]]): List of (user_message, bot_response) tuples.
        conversation_name (str): The name of the conversation.
    """
    try:
        for idx, (user_msg, bot_msg) in enumerate(chat_history, 1):
            # Combine user and bot messages for context
            combined_content = f"User: {user_msg}\nBot: {bot_msg}"

            # Create embedding
            embedding = create_embedding(combined_content, embedding_provider, embedding_model, embedding_api_url)

            # Unique identifier for ChromaDB
            document_id = f"chat_{chat_id}_msg_{idx}"

            # Metadata with chat_id
            metadata = {"chat_id": chat_id, "message_index": idx, "conversation_name": conversation_name}

            # Store in ChromaDB
            store_in_chroma(
                collection_name="all_chat_embeddings",
                texts=[combined_content],
                embeddings=[embedding],
                ids=[document_id],
                metadatas=[metadata]
            )
            logging.debug(f"Stored chat message {idx} of chat ID {chat_id} in ChromaDB.")
    except Exception as e:
        logging.error(f"Error embedding and storing chat ID {chat_id}: {e}")

#
# End of RAG_Persona_Chat.py
#######################################################################################################################
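
A minimal usage sketch of the two functions above; the chat ID and messages are made up. Note that perform_vector_search_chat filters on document IDs of the form chat_<id> while embed_and_store_chat writes IDs of the form chat_<id>_msg_<n>, so in practice the filter may need to target the chat_id metadata field instead.

# Hypothetical usage sketch (IDs and messages are illustrative).
from App_Function_Libraries.RAG.RAG_Persona_Chat import (
    embed_and_store_chat,
    perform_vector_search_chat,
)

# Store a short conversation so its turns become searchable.
chat_history = [
    ("What is RAG?", "Retrieval-Augmented Generation pairs a retriever with an LLM."),
    ("Why embeddings?", "They match questions to semantically similar stored text."),
]
embed_and_store_chat(chat_id=42, chat_history=chat_history, conversation_name="RAG basics")

# Later, search within that conversation only.
for hit in perform_vector_search_chat("how does retrieval work?", relevant_chat_ids=[42], k=5):
    print(hit["metadata"], "->", hit["content"][:80])
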
App_Function_Libraries/RAG/RAG_QA_Chat.py
CHANGED
@@ -1,84 +1,137 @@
# RAG_QA_Chat.py
# Description: Functions supporting the RAG QA Chat functionality
#
# Imports
#
#
# External Imports
import json
import logging
import tempfile
import time
from typing import List, Tuple, IO, Union
#
# Local Imports
from App_Function_Libraries.DB.DB_Manager import db, search_db, DatabaseError, get_media_content
from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
#
########################################################################################################################
#
# Functions:

def rag_qa_chat(message: str, history: List[Tuple[str, str]], context: Union[str, IO[str]], api_choice: str) -> Tuple[List[Tuple[str, str]], str]:
    log_counter("rag_qa_chat_attempt", labels={"api_choice": api_choice})
    start_time = time.time()
    try:
        # Prepare the context based on the selected source
        if hasattr(context, 'read'):
            # Handle uploaded file
            context_text = context.read()
            if isinstance(context_text, bytes):
                context_text = context_text.decode('utf-8')
            log_counter("rag_qa_chat_uploaded_file")
        elif isinstance(context, str) and context.startswith("media_id:"):
            # Handle existing file or search result
            media_id = int(context.split(":")[1])
            context_text = get_media_content(media_id)
            log_counter("rag_qa_chat_existing_media", labels={"media_id": media_id})
        else:
            context_text = str(context)
            log_counter("rag_qa_chat_string_context")

        # Prepare the full context including chat history
        full_context = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
        full_context += f"\n\nContext: {context_text}\n\nHuman: {message}\nAI:"

        # Generate response using the selected API
        response = generate_answer(api_choice, full_context, message)

        # Update history
        history.append((message, response))

        chat_duration = time.time() - start_time
        log_histogram("rag_qa_chat_duration", chat_duration, labels={"api_choice": api_choice})
        log_counter("rag_qa_chat_success", labels={"api_choice": api_choice})

        return history, ""
    except DatabaseError as e:
        log_counter("rag_qa_chat_database_error", labels={"error": str(e)})
        logging.error(f"Database error in rag_qa_chat: {str(e)}")
        return history, f"An error occurred while accessing the database: {str(e)}"
    except Exception as e:
        log_counter("rag_qa_chat_unexpected_error", labels={"error": str(e)})
        logging.error(f"Unexpected error in rag_qa_chat: {str(e)}")
        return history, f"An unexpected error occurred: {str(e)}"


def save_chat_history(history: List[Tuple[str, str]]) -> str:
    # Save chat history to a file
    log_counter("save_chat_history_attempt")
    start_time = time.time()
    try:
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as temp_file:
            json.dump(history, temp_file)
        save_duration = time.time() - start_time
        log_histogram("save_chat_history_duration", save_duration)
        log_counter("save_chat_history_success")
        return temp_file.name
    except Exception as e:
        log_counter("save_chat_history_error", labels={"error": str(e)})
        logging.error(f"Error saving chat history: {str(e)}")
        raise


def load_chat_history(file: IO[str]) -> List[Tuple[str, str]]:
    log_counter("load_chat_history_attempt")
    start_time = time.time()
    try:
        # Load chat history from a file
        history = json.load(file)
        load_duration = time.time() - start_time
        log_histogram("load_chat_history_duration", load_duration)
        log_counter("load_chat_history_success")
        return history
    except Exception as e:
        log_counter("load_chat_history_error", labels={"error": str(e)})
        logging.error(f"Error loading chat history: {str(e)}")
        raise


def search_database(query: str) -> List[Tuple[int, str]]:
    try:
        log_counter("search_database_attempt")
        start_time = time.time()
        # Implement database search functionality
        results = search_db(query, ["title", "content"], "", page=1, results_per_page=10)
        search_duration = time.time() - start_time
        log_histogram("search_database_duration", search_duration)
        log_counter("search_database_success", labels={"result_count": len(results)})
        return [(result['id'], result['title']) for result in results]
    except Exception as e:
        log_counter("search_database_error", labels={"error": str(e)})
        logging.error(f"Error searching database: {str(e)}")
        raise


def get_existing_files() -> List[Tuple[int, str]]:
    log_counter("get_existing_files_attempt")
    start_time = time.time()
    try:
        # Fetch list of existing files from the database
        with db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT id, title FROM Media ORDER BY title")
            results = cursor.fetchall()
        fetch_duration = time.time() - start_time
        log_histogram("get_existing_files_duration", fetch_duration)
        log_counter("get_existing_files_success", labels={"file_count": len(results)})
        return results
    except Exception as e:
        log_counter("get_existing_files_error", labels={"error": str(e)})
        logging.error(f"Error fetching existing files: {str(e)}")
        raise

#
# End of RAG_QA_Chat.py
########################################################################################################################
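
A sketch of driving the chat loop above and round-tripping its history through the save/load helpers; the API choice and media ID are illustrative.

# Hypothetical usage sketch (api_choice and media_id are made up).
from App_Function_Libraries.RAG.RAG_QA_Chat import (
    rag_qa_chat, save_chat_history, load_chat_history
)

history = []
history, error = rag_qa_chat("Summarize the key points.", history, "media_id:7", api_choice="openai")
if error:
    print("Chat failed:", error)

# Persist and restore the conversation.
path = save_chat_history(history)
with open(path, 'r', encoding='utf-8') as f:
    restored = load_chat_history(f)
# json round-trips tuples as lists, so normalize before comparing
assert [tuple(t) for t in restored] == history
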
App_Function_Libraries/RAG/eval_Chroma_Embeddings.py
ADDED
@@ -0,0 +1,133 @@
# eval_Chroma_Embeddings.py
# Description: This script is used to evaluate the embeddings and chunking process for the ChromaDB model.
#
# Imports
import io
from typing import List
#
# External Imports
import chardet
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions
from chunking_evaluation import BaseChunker, GeneralEvaluation, rigorous_document_search
from chunking_evaluation.evaluation_framework.base_evaluation import BaseEvaluation
#
# Local Imports
from App_Function_Libraries.Chunk_Lib import improved_chunking_process
from App_Function_Libraries.RAG.ChromaDB_Library import embedding_model, embedding_api_url
from App_Function_Libraries.RAG.Embeddings_Create import create_embeddings_batch, embedding_provider
from App_Function_Libraries.Utils.Utils import load_comprehensive_config
#
########################################################################################################################
#
# Functions:
# FIXME


def detect_file_encoding(file_path):
    with open(file_path, 'rb') as file:
        raw_data = file.read()
    # Detect once and reuse the result instead of calling chardet twice
    encoding = chardet.detect(raw_data)['encoding']
    print(encoding)
    return encoding


class CustomEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        # Load config here
        config = load_comprehensive_config()
        embedding_provider = config.get('Embeddings', 'embedding_provider', fallback='openai')
        embedding_model = config.get('Embeddings', 'embedding_model', fallback='text-embedding-3-small')
        embedding_api_url = config.get('Embeddings', 'api_url', fallback='')

        # Use the existing create_embeddings_batch function
        embeddings = create_embeddings_batch(input, embedding_provider, embedding_model, embedding_api_url)
        return embeddings


class CustomChunker(BaseChunker):
    def __init__(self, chunk_options):
        self.chunk_options = chunk_options

    def split_text(self, text: str) -> List[str]:
        # Use the existing improved_chunking_process function
        chunks = improved_chunking_process(text, self.chunk_options)
        return [chunk['text'] for chunk in chunks]

    def read_file(self, file_path: str) -> str:
        encoding = detect_file_encoding(file_path)
        with open(file_path, 'r', encoding=encoding) as file:
            return file.read()


def utf8_file_reader(file_path):
    with io.open(file_path, 'r', encoding='utf-8') as file:
        return file.read()


class CustomEvaluation(BaseEvaluation):
    def _get_chunks_and_metadata(self, splitter):
        documents = []
        metadatas = []
        for corpus_id in self.corpus_list:
            corpus_path = corpus_id
            if self.corpora_id_paths is not None:
                corpus_path = self.corpora_id_paths[corpus_id]

            corpus = splitter.read_file(corpus_path)

            current_documents = splitter.split_text(corpus)
            current_metadatas = []
            for document in current_documents:
                try:
                    _, start_index, end_index = rigorous_document_search(corpus, document)
                except Exception:
                    print(f"Error in finding {document} in {corpus_id}")
                    raise Exception(f"Error in finding {document} in {corpus_id}")
                current_metadatas.append({"start_index": start_index, "end_index": end_index, "corpus_id": corpus_id})
            documents.extend(current_documents)
            metadatas.extend(current_metadatas)
        return documents, metadatas


# Instantiate the custom chunker
chunk_options = {
    'method': 'words',
    'max_size': 400,
    'overlap': 200,
    'adaptive': False,
    'multi_level': False,
    'language': 'english'
}
custom_chunker = CustomChunker(chunk_options)

# Instantiate the custom embedding function
custom_ef = CustomEmbeddingFunction()


# Evaluate the chunker and the embedding function
evaluation = GeneralEvaluation()

def smart_file_reader(file_path):
    encoding = detect_file_encoding(file_path)
    with io.open(file_path, 'r', encoding=encoding) as file:
        return file.read()

# Set the custom file reader
#evaluation._file_reader = smart_file_reader


# Generate Embedding results
embedding_results = evaluation.run(custom_chunker, custom_ef)
print(f"Embedding Results:\n\t{embedding_results}")

# Generate Chunking results
chunk_results = evaluation.run(custom_chunker, custom_ef)
print(f"Chunking Results:\n\t{chunk_results}")

#
# End of File
########################################################################################################################
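
A quick standalone check of the chunker wiring, assuming improved_chunking_process respects the word-based max_size/overlap options; note the script above runs its evaluation at import time, so this sketch presumes that module-level run is guarded or tolerable.

# Illustrative check that CustomChunker yields word-bounded chunks (sample text is made up).
from App_Function_Libraries.RAG.eval_Chroma_Embeddings import CustomChunker

chunker = CustomChunker({'method': 'words', 'max_size': 50, 'overlap': 10,
                         'adaptive': False, 'multi_level': False, 'language': 'english'})
sample = " ".join(f"word{i}" for i in range(200))
chunks = chunker.split_text(sample)
print(f"{len(chunks)} chunks; first has {len(chunks[0].split())} words")
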
App_Function_Libraries/Summarization/Local_Summarization_Lib.py
CHANGED
@@ -21,6 +21,7 @@
 import json
 import logging
 import os
+import time
 from typing import Union

 import requests

@@ -640,10 +641,19 @@ def summarize_with_vllm(
     return f"Error: Unexpected error during vLLM summarization - {str(e)}"


+def summarize_with_ollama(
+        input_data,
+        custom_prompt,
+        api_url="http://127.0.0.1:11434/v1/chat/completions",
+        api_key=None,
+        temp=None,
+        system_message=None,
+        model=None,
+        max_retries=5,
+        retry_delay=20
+):
     try:
+        logging.debug("Ollama: Loading and validating configurations")
         loaded_config_data = load_and_log_configs()
         if loaded_config_data is None:
             logging.error("Failed to load configuration data")

@@ -661,7 +671,19 @@ def summarize_with_ollama(
         else:
             logging.warning("Ollama: No API key found in config file")

+        # Set model from parameter or config
+        if model is None:
+            model = loaded_config_data['models'].get('ollama')
+            if model is None:
+                logging.error("Ollama: Model not found in config file")
+                return "Ollama: Model not found in config file"
+
+        # Set api_url from parameter or config
+        if api_url is None:
+            api_url = loaded_config_data['local_api_ip'].get('ollama')
+            if api_url is None:
+                logging.error("Ollama: API URL not found in config file")
+                return "Ollama: API URL not found in config file"

         # Load transcript
         logging.debug("Ollama: Loading JSON data")

@@ -690,57 +712,92 @@ def summarize_with_ollama(
         else:
             raise ValueError("Ollama: Invalid input data format")

-        if custom_prompt is None:
-            custom_prompt = f"{summarizer_prompt}\n\n\n\n{text}"
-        else:
-            custom_prompt = f"{custom_prompt}\n\n\n\n{text}"
-
         headers = {
             'accept': 'application/json',
             'content-type': 'application/json',
         }
-        if len(ollama_api_key) > 5:
+        if ollama_api_key and len(ollama_api_key) > 5:
             headers['Authorization'] = f'Bearer {ollama_api_key}'

-        if system_message is None:
-            system_message = "You are a helpful AI assistant."
-        logging.debug(f"llama: Prompt being sent is {ollama_prompt}")
+        ollama_prompt = f"{custom_prompt}\n\n{text}"
         if system_message is None:
             system_message = "You are a helpful AI assistant."
+        logging.debug(f"Ollama: Prompt being sent is: {ollama_prompt}")

+        data_payload = {
             "model": model,
             "messages": [
+                {
+                    "role": "system",
+                    "content": system_message
+                },
+                {
+                    "role": "user",
+                    "content": ollama_prompt
+                }
             ],
+            'temperature': temp
         }

+        for attempt in range(1, max_retries + 1):
+            logging.debug("Ollama: Submitting request to API endpoint")
+            print("Ollama: Submitting request to API endpoint")
+            try:
+                response = requests.post(api_url, headers=headers, json=data_payload, timeout=30)
+                response.raise_for_status()  # Raises HTTPError for bad responses
+                response_data = response.json()
+            except requests.exceptions.Timeout:
+                logging.error("Ollama: Request timed out.")
+                return "Ollama: Request timed out."
+            except requests.exceptions.HTTPError as http_err:
+                logging.error(f"Ollama: HTTP error occurred: {http_err}")
+                return f"Ollama: HTTP error occurred: {http_err}"
+            except requests.exceptions.RequestException as req_err:
+                logging.error(f"Ollama: Request exception: {req_err}")
+                return f"Ollama: Request exception: {req_err}"
+            except json.JSONDecodeError:
+                logging.error("Ollama: Failed to decode JSON response")
+                return "Ollama: Failed to decode JSON response."
+            except Exception as e:
+                logging.error(f"Ollama: An unexpected error occurred: {str(e)}")
+                return f"Ollama: An unexpected error occurred: {str(e)}"
+
+            logging.debug(f"API Response Data: {response_data}")
+
+            if response.status_code == 200:
+                # Inspect available keys
+                available_keys = list(response_data.keys())
+                logging.debug(f"Ollama: Available keys in response: {available_keys}")
+
+                # Attempt to retrieve 'response'
+                summary = None
+                if 'response' in response_data and response_data['response']:
+                    summary = response_data['response'].strip()
+                elif 'choices' in response_data and len(response_data['choices']) > 0:
+                    choice = response_data['choices'][0]
+                    if 'message' in choice and 'content' in choice['message']:
+                        summary = choice['message']['content'].strip()
+
+                if summary:
+                    logging.debug("Ollama: Chat request successful")
+                    print("\n\nChat request successful.")
+                    return summary
+                elif response_data.get('done_reason') == 'load':
+                    logging.warning(f"Ollama: Model is loading. Attempt {attempt} of {max_retries}. Retrying in {retry_delay} seconds...")
+                    time.sleep(retry_delay)
+                else:
+                    logging.error("Ollama: API response does not contain 'response' or 'choices'.")
+                    return "Ollama: API response does not contain 'response' or 'choices'."
+            else:
+                logging.error(f"Ollama: API request failed with status code {response.status_code}: {response.text}")
+                return f"Ollama: API request failed: {response.text}"
+
+        logging.error("Ollama: Maximum retry attempts reached. Model is still loading.")
+        return "Ollama: Maximum retry attempts reached. Model is still loading."

     except Exception as e:
+        logging.error("\n\nOllama: Error in processing: %s", str(e))
+        return f"Ollama: Error occurred while processing summary with Ollama: {str(e)}"


 # FIXME - update to be a summarize request
|
App_Function_Libraries/Summarization/Summarization_General_Lib.py
CHANGED
@@ -73,7 +73,7 @@ def summarize(
|
|
73 |
elif api_name.lower() == "mistral":
|
74 |
return summarize_with_mistral(api_key, input_data, custom_prompt_arg, temp, system_message)
|
75 |
elif api_name.lower() == "llama.cpp":
|
76 |
-
return summarize_with_llama(input_data, custom_prompt_arg, temp, system_message)
|
77 |
elif api_name.lower() == "kobold":
|
78 |
return summarize_with_kobold(input_data, api_key, custom_prompt_arg, temp, system_message)
|
79 |
elif api_name.lower() == "ooba":
|
@@ -86,6 +86,10 @@ def summarize(
|
|
86 |
return summarize_with_local_llm(input_data, custom_prompt_arg, temp, system_message)
|
87 |
elif api_name.lower() == "huggingface":
|
88 |
return summarize_with_huggingface(api_key, input_data, custom_prompt_arg, temp, )#system_message)
|
|
|
|
|
|
|
|
|
89 |
else:
|
90 |
return f"Error: Invalid API Name {api_name}"
|
91 |
|
@@ -1147,7 +1151,7 @@ def perform_transcription(video_path, offset, whisper_model, vad_filter, diarize
|
|
1147 |
|
1148 |
return audio_file_path, diarized_segments
|
1149 |
|
1150 |
-
# Non-diarized transcription
|
1151 |
if os.path.exists(segments_json_path):
|
1152 |
logging.info(f"Segments file already exists: {segments_json_path}")
|
1153 |
try:
|
|
|
73 |
elif api_name.lower() == "mistral":
|
74 |
return summarize_with_mistral(api_key, input_data, custom_prompt_arg, temp, system_message)
|
75 |
elif api_name.lower() == "llama.cpp":
|
76 |
+
return summarize_with_llama(input_data, custom_prompt_arg, api_key, temp, system_message)
|
77 |
elif api_name.lower() == "kobold":
|
78 |
return summarize_with_kobold(input_data, api_key, custom_prompt_arg, temp, system_message)
|
79 |
elif api_name.lower() == "ooba":
|
|
|
86 |
return summarize_with_local_llm(input_data, custom_prompt_arg, temp, system_message)
|
87 |
elif api_name.lower() == "huggingface":
|
88 |
return summarize_with_huggingface(api_key, input_data, custom_prompt_arg, temp, )#system_message)
|
89 |
+
elif api_name.lower() == "custom-openai":
|
90 |
+
return summarize_with_custom_openai(api_key, input_data, custom_prompt_arg, temp, system_message)
|
91 |
+
elif api_name.lower() == "ollama":
|
92 |
+
return summarize_with_ollama(input_data, custom_prompt_arg, None, api_key, temp, system_message)
|
93 |
else:
|
94 |
return f"Error: Invalid API Name {api_name}"
|
95 |
|
|
|
1151 |
|
1152 |
return audio_file_path, diarized_segments
|
1153 |
|
1154 |
+
# Non-diarized transcription
|
1155 |
if os.path.exists(segments_json_path):
|
1156 |
logging.info(f"Segments file already exists: {segments_json_path}")
|
1157 |
try:
|
App_Function_Libraries/Web_Scraping/Article_Extractor_Lib.py
CHANGED
@@ -13,13 +13,14 @@
 ####################
 #
 # Import necessary libraries
+import json
 import logging
 # 3rd-Party Imports
 import asyncio
 import os
 import tempfile
 from datetime import datetime
-from typing import List, Dict
+from typing import List, Dict, Union
 from urllib.parse import urljoin, urlparse
 from xml.dom import minidom
 from playwright.async_api import async_playwright

@@ -376,6 +377,152 @@ def scrape_and_convert_with_filter(
 
     logging.info(f"Scraped and filtered content saved to {output_file}")
 
+
+###################################################
+#
+# Bookmark Parsing Functions
+
+def parse_chromium_bookmarks(json_data: dict) -> Dict[str, Union[str, List[str]]]:
+    """
+    Parse Chromium-based browser bookmarks from JSON data.
+
+    :param json_data: The JSON data from the bookmarks file
+    :return: A dictionary with bookmark names as keys and URLs as values or lists of URLs if duplicates exist
+    """
+    bookmarks = {}
+
+    def recurse_bookmarks(nodes):
+        for node in nodes:
+            if node.get('type') == 'url':
+                name = node.get('name')
+                url = node.get('url')
+                if name and url:
+                    if name in bookmarks:
+                        if isinstance(bookmarks[name], list):
+                            bookmarks[name].append(url)
+                        else:
+                            bookmarks[name] = [bookmarks[name], url]
+                    else:
+                        bookmarks[name] = url
+            elif node.get('type') == 'folder' and 'children' in node:
+                recurse_bookmarks(node['children'])
+
+    # Chromium bookmarks have a 'roots' key
+    if 'roots' in json_data:
+        for root in json_data['roots'].values():
+            if 'children' in root:
+                recurse_bookmarks(root['children'])
+    else:
+        recurse_bookmarks(json_data.get('children', []))
+
+    return bookmarks
+
+
+def parse_firefox_bookmarks(html_content: str) -> Dict[str, Union[str, List[str]]]:
+    """
+    Parse Firefox bookmarks from HTML content.
+
+    :param html_content: The HTML content from the bookmarks file
+    :return: A dictionary with bookmark names as keys and URLs as values or lists of URLs if duplicates exist
+    """
+    bookmarks = {}
+    soup = BeautifulSoup(html_content, 'html.parser')
+
+    # Firefox stores bookmarks within <a> tags inside <dt>
+    for a in soup.find_all('a'):
+        name = a.get_text()
+        url = a.get('href')
+        if name and url:
+            if name in bookmarks:
+                if isinstance(bookmarks[name], list):
+                    bookmarks[name].append(url)
+                else:
+                    bookmarks[name] = [bookmarks[name], url]
+            else:
+                bookmarks[name] = url
+
+    return bookmarks
+
+
+def load_bookmarks(file_path: str) -> Dict[str, Union[str, List[str]]]:
+    """
+    Load bookmarks from a file (JSON for Chrome/Edge or HTML for Firefox).
+
+    :param file_path: Path to the bookmarks file
+    :return: A dictionary with bookmark names as keys and URLs as values or lists of URLs if duplicates exist
+    :raises ValueError: If the file format is unsupported or parsing fails
+    """
+    if not os.path.isfile(file_path):
+        logging.error(f"File '{file_path}' does not exist.")
+        raise FileNotFoundError(f"File '{file_path}' does not exist.")
+
+    _, ext = os.path.splitext(file_path)
+    ext = ext.lower()
+
+    if ext == '.json' or ext == '':
+        # Attempt to parse as JSON (Chrome/Edge)
+        try:
+            with open(file_path, 'r', encoding='utf-8') as f:
+                json_data = json.load(f)
+            return parse_chromium_bookmarks(json_data)
+        except json.JSONDecodeError:
+            logging.error("Failed to parse JSON. Ensure the file is a valid Chromium bookmarks JSON file.")
+            raise ValueError("Invalid JSON format for Chromium bookmarks.")
+    elif ext in ['.html', '.htm']:
+        # Parse as HTML (Firefox)
+        try:
+            with open(file_path, 'r', encoding='utf-8') as f:
+                html_content = f.read()
+            return parse_firefox_bookmarks(html_content)
+        except Exception as e:
+            logging.error(f"Failed to parse HTML bookmarks: {e}")
+            raise ValueError(f"Failed to parse HTML bookmarks: {e}")
+    else:
+        logging.error("Unsupported file format. Please provide a JSON (Chrome/Edge) or HTML (Firefox) bookmarks file.")
+        raise ValueError("Unsupported file format for bookmarks.")
+
+
+def collect_bookmarks(file_path: str) -> Dict[str, Union[str, List[str]]]:
+    """
+    Collect bookmarks from the provided bookmarks file and return a dictionary.
+
+    :param file_path: Path to the bookmarks file
+    :return: Dictionary with bookmark names as keys and URLs as values or lists of URLs if duplicates exist
+    """
+    try:
+        bookmarks = load_bookmarks(file_path)
+        logging.info(f"Successfully loaded {len(bookmarks)} bookmarks from '{file_path}'.")
+        return bookmarks
+    except (FileNotFoundError, ValueError) as e:
+        logging.error(f"Error loading bookmarks: {e}")
+        return {}
+
+# Usage:
+# from Article_Extractor_Lib import collect_bookmarks
+#
+# # Path to your bookmarks file
+# # For Chrome or Edge (JSON format)
+# chromium_bookmarks_path = "/path/to/Bookmarks"
+#
+# # For Firefox (HTML format)
+# firefox_bookmarks_path = "/path/to/bookmarks.html"
+#
+# # Collect bookmarks from Chromium-based browser
+# chromium_bookmarks = collect_bookmarks(chromium_bookmarks_path)
+# print("Chromium Bookmarks:")
+# for name, url in chromium_bookmarks.items():
+#     print(f"{name}: {url}")
+#
+# # Collect bookmarks from Firefox
+# firefox_bookmarks = collect_bookmarks(firefox_bookmarks_path)
+# print("\nFirefox Bookmarks:")
+# for name, url in firefox_bookmarks.items():
+#     print(f"{name}: {url}")
+
+#
+# End of Bookmarking Parsing Functions
+#####################################################################
+
 #
 #
 #######################################################################################################################
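
For reference, a minimal sketch of the Chromium Bookmarks JSON shape the parser above walks; the names and URLs are made up, and the duplicate name shows the list-merging behavior.

# Made-up Chromium 'Bookmarks' structure exercising parse_chromium_bookmarks.
from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import parse_chromium_bookmarks

sample = {
    "roots": {
        "bookmark_bar": {
            "children": [
                {"type": "url", "name": "Docs", "url": "https://example.com/docs"},
                {"type": "folder", "name": "Reading", "children": [
                    {"type": "url", "name": "Docs", "url": "https://example.com/other-docs"}
                ]}
            ]
        }
    }
}
print(parse_chromium_bookmarks(sample))
# {'Docs': ['https://example.com/docs', 'https://example.com/other-docs']}
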
App_Function_Libraries/Web_Scraping/Article_Summarization_Lib.py
CHANGED
@@ -165,7 +165,7 @@ def scrape_and_summarize(url, custom_prompt_arg, api_name, api_key, keywords, cu
     elif api_name.lower() == "ollama":
         logging.debug(f"MAIN: Trying to summarize with OLLAMA")
         # def summarize_with_ollama(input_data, api_key, custom_prompt, api_url):
-        summary = summarize_with_ollama(json_file_path, article_custom_prompt, api_key, None, system_message, None)
+        summary = summarize_with_ollama(json_file_path, article_custom_prompt, None, api_key, None, system_message, None)
 
     elif api_name == "custom_openai_api":
         logging.debug(f"MAIN: Trying to summarize with Custom_OpenAI API")
App_Function_Libraries/html_to_markdown/__init__.py
ADDED
File without changes
App_Function_Libraries/html_to_markdown/ast_utils.py
ADDED
@@ -0,0 +1,59 @@
# html_to_markdown/ast_utils.py

from typing import Callable, Optional, List, Union
from s_types import SemanticMarkdownAST

def find_in_ast(ast: Union[SemanticMarkdownAST, List[SemanticMarkdownAST]], predicate: Callable[[SemanticMarkdownAST], bool]) -> Optional[SemanticMarkdownAST]:
    if isinstance(ast, list):
        for node in ast:
            result = find_in_ast(node, predicate)
            if result:
                return result
    else:
        if predicate(ast):
            return ast
        # Recursively search based on node type
        if hasattr(ast, 'content'):
            content = ast.content
            if isinstance(content, list):
                result = find_in_ast(content, predicate)
                if result:
                    return result
            elif isinstance(content, SemanticMarkdownAST):
                result = find_in_ast(content, predicate)
                if result:
                    return result
        if hasattr(ast, 'items'):
            for item in ast.items:
                result = find_in_ast(item, predicate)
                if result:
                    return result
        if hasattr(ast, 'rows'):
            for row in ast.rows:
                result = find_in_ast(row, predicate)
                if result:
                    return result
    return None

def find_all_in_ast(ast: Union[SemanticMarkdownAST, List[SemanticMarkdownAST]], predicate: Callable[[SemanticMarkdownAST], bool]) -> List[SemanticMarkdownAST]:
    results = []
    if isinstance(ast, list):
        for node in ast:
            results.extend(find_all_in_ast(node, predicate))
    else:
        if predicate(ast):
            results.append(ast)
        # Recursively search based on node type
        if hasattr(ast, 'content'):
            content = ast.content
            if isinstance(content, list):
                results.extend(find_all_in_ast(content, predicate))
            elif isinstance(content, SemanticMarkdownAST):
                results.extend(find_all_in_ast(content, predicate))
        if hasattr(ast, 'items'):
            for item in ast.items:
                results.extend(find_all_in_ast(item, predicate))
        if hasattr(ast, 'rows'):
            for row in ast.rows:
                results.extend(find_all_in_ast(row, predicate))
    return results
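
A small sketch of predicate search over these helpers, assuming the s_types dataclasses accept the keyword fields used elsewhere in this commit (level, content).

# Illustrative search over a tiny AST (node constructors assumed from s_types usage).
from s_types import HeadingNode, TextNode
from ast_utils import find_in_ast, find_all_in_ast

tree = [
    HeadingNode(level=1, content="Title"),
    TextNode(content="Intro paragraph"),
    HeadingNode(level=2, content="Details"),
]
first = find_in_ast(tree, lambda n: isinstance(n, HeadingNode))
all_h = find_all_in_ast(tree, lambda n: isinstance(n, HeadingNode))
print(first.content, len(all_h))  # Title 2
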
App_Function_Libraries/html_to_markdown/conversion_options.py
ADDED
@@ -0,0 +1,21 @@
# html_to_markdown/conversion_options.py

from typing import Callable, Optional, Union, Dict, Any, List
from dataclasses import dataclass, field

from s_types import SemanticMarkdownAST, CustomNode

@dataclass
class ConversionOptions:
    website_domain: Optional[str] = None
    extract_main_content: bool = False
    refify_urls: bool = False
    url_map: Dict[str, str] = field(default_factory=dict)
    debug: bool = False
    override_dom_parser: Optional[Callable[[str], Any]] = None  # Placeholder for DOMParser override
    enable_table_column_tracking: bool = False
    override_element_processing: Optional[Callable[[Any, 'ConversionOptions', int], Optional[List[SemanticMarkdownAST]]]] = None
    process_unhandled_element: Optional[Callable[[Any, 'ConversionOptions', int], Optional[List[SemanticMarkdownAST]]]] = None
    override_node_renderer: Optional[Callable[[SemanticMarkdownAST, 'ConversionOptions', int], Optional[str]]] = None
    render_custom_node: Optional[Callable[[CustomNode, 'ConversionOptions', int], Optional[str]]] = None
    include_meta_data: Union[str, bool] = False  # 'basic', 'extended', or False
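
A sketch of wiring a custom hook through these options; the hr-to-rule substitution is invented here to show the callback shape (element, options, indent_level) -> Optional[List[SemanticMarkdownAST]].

# Hypothetical override: render <hr> tags as a thematic break (invented behavior).
from conversion_options import ConversionOptions
from s_types import TextNode

def custom_hr(element, options, indent_level):
    if getattr(element, 'name', '') and element.name.lower() == 'hr':
        return [TextNode(content='\n---\n')]  # a non-empty list short-circuits default handling
    return None  # fall through to the default element processing

opts = ConversionOptions(
    website_domain="https://example.com",  # made-up domain
    extract_main_content=True,
    override_element_processing=custom_hr,
    include_meta_data='basic',
)
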
App_Function_Libraries/html_to_markdown/dom_utils.py
ADDED
@@ -0,0 +1,140 @@
# html_to_markdown/dom_utils.py

from bs4 import BeautifulSoup, Tag
from typing import Optional
import logging

from conversion_options import ConversionOptions

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def find_main_content(soup: BeautifulSoup, options: ConversionOptions) -> Tag:
    logger.debug("Entering find_main_content function")

    main_element = soup.find('main')
    if main_element:
        logger.debug("Existing <main> element found")
        return main_element

    logger.debug("No <main> element found. Detecting main content.")
    if not soup.body:
        logger.debug("No body element found, returning the entire document")
        return soup

    return detect_main_content(soup.body, options)

def wrap_main_content(main_content: Tag, soup: BeautifulSoup):
    if main_content.name.lower() != 'main':
        logger.debug("Wrapping main content in <main> element")
        main_element = soup.new_tag('main')
        main_content.wrap(main_element)
        main_element['id'] = 'detected-main-content'
        logger.debug("Main content wrapped successfully")
    else:
        logger.debug("Main content already wrapped")

def detect_main_content(element: Tag, options: ConversionOptions) -> Tag:
    candidates = []
    min_score = 20
    logger.debug(f"Collecting candidates with minimum score: {min_score}")
    collect_candidates(element, candidates, min_score, options)

    logger.debug(f"Total candidates found: {len(candidates)}")

    if not candidates:
        logger.debug("No suitable candidates found, returning root element")
        return element

    # Sort candidates by score descending
    candidates.sort(key=lambda x: calculate_score(x, options), reverse=True)
    logger.debug("Candidates sorted by score")

    best_candidate = candidates[0]
    for candidate in candidates[1:]:
        # bs4 Tags have no JS-style .contains(); check descendant containment instead
        if not any(candidate in other.descendants for other in candidates if other is not candidate):
            if calculate_score(candidate, options) > calculate_score(best_candidate, options):
                best_candidate = candidate
                logger.debug(f"New best independent candidate found: {element_to_string(best_candidate)}")

    logger.debug(f"Final main content candidate: {element_to_string(best_candidate)}")
    return best_candidate

def element_to_string(element: Optional[Tag]) -> str:
    if not element:
        return 'No element'
    classes = '.'.join(element.get('class', []))
    return f"{element.name}#{element.get('id', 'no-id')}.{classes}"

def collect_candidates(element: Tag, candidates: list, min_score: int, options: ConversionOptions):
    score = calculate_score(element, options)
    if score >= min_score:
        candidates.append(element)
        logger.debug(f"Candidate found: {element_to_string(element)}, score: {score}")

    for child in element.find_all(recursive=False):
        collect_candidates(child, candidates, min_score, options)

def calculate_score(element: Tag, options: ConversionOptions) -> int:
    score = 0
    score_log = []

    # High impact attributes
    high_impact_attributes = ['article', 'content', 'main-container', 'main', 'main-content']
    for attr in high_impact_attributes:
        if 'class' in element.attrs and attr in element['class']:
            score += 10
            score_log.append(f"High impact attribute found: {attr}, score increased by 10")
        if 'id' in element.attrs and attr in element['id']:
            score += 10
            score_log.append(f"High impact ID found: {attr}, score increased by 10")

    # High impact tags
    high_impact_tags = ['article', 'main', 'section']
    if element.name.lower() in high_impact_tags:
        score += 5
        score_log.append(f"High impact tag found: {element.name}, score increased by 5")

    # Paragraph count
    paragraph_count = len(element.find_all('p'))
    paragraph_score = min(paragraph_count, 5)
    if paragraph_score > 0:
        score += paragraph_score
        score_log.append(f"Paragraph count: {paragraph_count}, score increased by {paragraph_score}")

    # Text content length
    text_content_length = len(element.get_text(strip=True))
    if text_content_length > 200:
        text_score = min(text_content_length // 200, 5)
        score += text_score
        score_log.append(f"Text content length: {text_content_length}, score increased by {text_score}")

    # Link density
    link_density = calculate_link_density(element)
    if link_density < 0.3:
        score += 5
        score_log.append(f"Link density: {link_density:.2f}, score increased by 5")

    # Data attributes
    if element.has_attr('data-main') or element.has_attr('data-content'):
        score += 10
        score_log.append("Data attribute for main content found, score increased by 10")

    # Role attribute
    if element.get('role') and 'main' in element.get('role'):
        score += 10
        score_log.append("Role attribute indicating main content found, score increased by 10")

    if options.debug and score_log:
        logger.debug(f"Scoring for {element_to_string(element)}:")
        for log in score_log:
            logger.debug(f"  {log}")
        logger.debug(f"  Final score: {score}")

    return score

def calculate_link_density(element: Tag) -> float:
    links = element.find_all('a')
    link_length = sum(len(link.get_text(strip=True)) for link in links)
    text_length = len(element.get_text(strip=True)) or 1  # Avoid division by zero
    return link_length / text_length
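
A worked example of the link-density heuristic above: total anchor-text length divided by total text length, with 0.3 as the content threshold (the snippet is illustrative).

# Link density for a made-up snippet: 4 anchor chars / 27 total chars ≈ 0.15.
from bs4 import BeautifulSoup
from dom_utils import calculate_link_density

div = BeautifulSoup('<div><p>Some article text here.</p><a href="#">more</a></div>',
                    'html.parser').div
print(round(calculate_link_density(div), 2))  # 0.15, comfortably below the 0.3 threshold
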
App_Function_Libraries/html_to_markdown/html_to_markdown.py
ADDED
@@ -0,0 +1,46 @@
# html_to_markdown/html_to_markdown.py

from bs4 import BeautifulSoup
from typing import Optional

from conversion_options import ConversionOptions
from dom_utils import find_main_content, wrap_main_content
from html_to_markdown_ast import html_to_markdown_ast
from markdown_ast_to_string import markdown_ast_to_string
from url_utils import refify_urls

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def convert_html_to_markdown(html: str, options: Optional[ConversionOptions] = None) -> str:
    if options is None:
        options = ConversionOptions()

    if options.debug:
        logger.setLevel(logging.DEBUG)

    soup = BeautifulSoup(html, 'html.parser')

    if options.extract_main_content:
        main_content = find_main_content(soup, options)
        if options.include_meta_data and soup.head and not main_content.find('head'):
            # Reattach head for metadata extraction
            new_html = f"<html>{soup.head}{main_content}</html>"
            soup = BeautifulSoup(new_html, 'html.parser')
            main_content = soup.html
    else:
        if options.include_meta_data and soup.head:
            main_content = soup
        else:
            main_content = soup.body if soup.body else soup

    markdown_ast = html_to_markdown_ast(main_content, options)

    if options.refify_urls:
        options.url_map = refify_urls(markdown_ast, options.url_map)

    markdown_string = markdown_ast_to_string(markdown_ast, options)

    return markdown_string
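
An end-to-end sketch of the converter on a tiny page; the HTML is illustrative.

# Minimal end-to-end run (input HTML is made up).
from html_to_markdown import convert_html_to_markdown
from conversion_options import ConversionOptions

html = ("<html><head><title>Demo</title></head>"
        "<body><main><h1>Hello</h1><p>A <strong>bold</strong> claim.</p></main></body></html>")
print(convert_html_to_markdown(html, ConversionOptions(extract_main_content=True)))
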
App_Function_Libraries/html_to_markdown/html_to_markdown_ast.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# html_to_markdown/html_to_markdown_ast.py
|
2 |
+
|
3 |
+
from bs4 import BeautifulSoup, Tag, NavigableString
|
4 |
+
from typing import List, Optional, Union
|
5 |
+
|
6 |
+
from s_types import (
|
7 |
+
SemanticMarkdownAST, TextNode, BoldNode, ItalicNode, StrikethroughNode,
|
8 |
+
HeadingNode, LinkNode, ImageNode, VideoNode, ListNode, ListItemNode,
|
9 |
+
TableNode, TableRowNode, TableCellNode, CodeNode, BlockquoteNode,
|
10 |
+
SemanticHtmlNode, CustomNode, MetaDataNode
|
11 |
+
)
|
12 |
+
from conversion_options import ConversionOptions
|
13 |
+
import logging
|
14 |
+
|
15 |
+
logging.basicConfig(level=logging.INFO)
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
def escape_markdown_characters(text: str, is_inline_code: bool = False) -> str:
|
19 |
+
if is_inline_code or not text.strip():
|
20 |
+
return text
|
21 |
+
# Replace special characters
|
22 |
+
replacements = {
|
23 |
+
'\\': '\\\\',
|
24 |
+
'`': '\\`',
|
25 |
+
'*': '\\*',
|
26 |
+
'_': '\\_',
|
27 |
+
'{': '\\{',
|
28 |
+
'}': '\\}',
|
29 |
+
'[': '\\[',
|
30 |
+
']': '\\]',
|
31 |
+
'(': '\\(',
|
32 |
+
')': '\\)',
|
33 |
+
'#': '\\#',
|
34 |
+
'+': '\\+',
|
35 |
+
'-': '\\-',
|
36 |
+
'.': '\\.',
|
37 |
+
'!': '\\!',
|
38 |
+
'|': '\\|',
|
39 |
+
}
|
40 |
+
for char, escaped in replacements.items():
|
41 |
+
text = text.replace(char, escaped)
|
42 |
+
return text
|
43 |
+
|
44 |
+
def html_to_markdown_ast(element: Tag, options: Optional[ConversionOptions] = None, indent_level: int = 0) -> List[SemanticMarkdownAST]:
|
45 |
+
if options is None:
|
46 |
+
options = ConversionOptions()
|
47 |
+
|
48 |
+
result: List[SemanticMarkdownAST] = []
|
49 |
+
|
50 |
+
for child in element.children:
|
51 |
+
if isinstance(child, NavigableString):
|
52 |
+
text_content = escape_markdown_characters(child.strip())
|
53 |
+
if text_content:
|
54 |
+
logger.debug(f"Text Node: '{text_content}'")
|
55 |
+
result.append(TextNode(content=child.strip()))
|
56 |
+
elif isinstance(child, Tag):
|
57 |
+
# Check for overridden element processing
|
58 |
+
if options.override_element_processing:
|
59 |
+
overridden = options.override_element_processing(child, options, indent_level)
|
60 |
+
if overridden:
|
61 |
+
logger.debug(f"Element Processing Overridden: '{child.name}'")
|
62 |
+
result.extend(overridden)
|
63 |
+
continue
|
64 |
+
|
65 |
+
tag_name = child.name.lower()
|
66 |
+
|
67 |
+
if tag_name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
|
68 |
+
level = int(tag_name[1])
|
69 |
+
content = escape_markdown_characters(child.get_text(strip=True))
|
70 |
+
if content:
|
71 |
+
logger.debug(f"Heading {level}: '{content}'")
|
72 |
+
result.append(HeadingNode(level=level, content=content))
|
73 |
+
elif tag_name == 'p':
|
74 |
+
logger.debug("Paragraph")
|
75 |
+
result.extend(html_to_markdown_ast(child, options, indent_level))
|
76 |
+
# Add a new line after the paragraph
|
77 |
+
result.append(TextNode(content='\n\n'))
|
78 |
+
elif tag_name == 'a':
|
79 |
+
href = child.get('href', '#')
|
80 |
+
if href.startswith("data:image"):
|
81 |
+
# Skip data URLs for images
|
82 |
+
result.append(LinkNode(href='-', content=html_to_markdown_ast(child, options, indent_level)))
|
83 |
+
else:
|
84 |
+
href = href
|
85 |
+
if options.website_domain and href.startswith(options.website_domain):
|
86 |
+
href = href[len(options.website_domain):]
|
87 |
+
# Check if all children are text
|
88 |
+
if all(isinstance(c, NavigableString) for c in child.children):
|
89 |
+
content = [TextNode(content=child.get_text(strip=True))]
|
90 |
+
result.append(LinkNode(href=href, content=content))
|
91 |
+
else:
|
92 |
+
content = html_to_markdown_ast(child, options, indent_level)
|
93 |
+
result.append(LinkNode(href=href, content=content))
|
94 |
+
elif tag_name == 'img':
|
95 |
+
src = child.get('src', '')
|
96 |
+
alt = child.get('alt', '')
|
97 |
+
if src.startswith("data:image"):
|
98 |
+
src = '-'
|
99 |
+
else:
|
100 |
+
if options.website_domain and src.startswith(options.website_domain):
|
101 |
+
src = src[len(options.website_domain):]
|
102 |
+
logger.debug(f"Image: src='{src}', alt='{alt}'")
|
103 |
+
result.append(ImageNode(src=src, alt=alt))
|
104 |
+
elif tag_name == 'video':
|
105 |
+
src = child.get('src', '')
|
106 |
+
poster = child.get('poster', '')
|
107 |
+
controls = child.has_attr('controls')
|
108 |
+
logger.debug(f"Video: src='{src}', poster='{poster}', controls='{controls}'")
|
109 |
+
result.append(VideoNode(src=src, poster=poster, controls=controls))
|
110 |
+
elif tag_name in ['ul', 'ol']:
|
111 |
+
logger.debug(f"{'Unordered' if tag_name == 'ul' else 'Ordered'} List")
|
112 |
+
ordered = tag_name == 'ol'
|
113 |
+
items = []
|
114 |
+
for li in child.find_all('li', recursive=False):
|
115 |
+
item_content = html_to_markdown_ast(li, options, indent_level + 1)
|
116 |
+
items.append(ListItemNode(content=item_content))
|
117 |
+
result.append(ListNode(ordered=ordered, items=items))
|
118 |
+
elif tag_name == 'br':
|
119 |
+
logger.debug("Line Break")
|
120 |
+
result.append(TextNode(content='\n'))
|
121 |
+
elif tag_name == 'table':
|
122 |
+
logger.debug("Table")
|
123 |
+
table_node = TableNode()
|
124 |
+
rows = child.find_all('tr')
|
125 |
+
for row in rows:
|
126 |
+
                table_row = TableRowNode()
                cells = row.find_all(['th', 'td'])
                for cell in cells:
                    colspan = int(cell.get('colspan', 1))
                    rowspan = int(cell.get('rowspan', 1))
                    cell_content = cell.get_text(strip=True)
                    table_row.cells.append(TableCellNode(content=cell_content,
                                                         colspan=colspan if colspan > 1 else None,
                                                         rowspan=rowspan if rowspan > 1 else None))
                table_node.rows.append(table_row)
            result.append(table_node)
        elif tag_name == 'head' and options.include_meta_data:
            meta_node = MetaDataNode(content={
                'standard': {},
                'openGraph': {},
                'twitter': {},
                'jsonLd': []
            })
            title = child.find('title')
            if title:
                meta_node.content['standard']['title'] = title.get_text(strip=True)
            meta_tags = child.find_all('meta')
            non_semantic_tags = ["viewport", "referrer", "Content-Security-Policy"]
            for meta in meta_tags:
                name = meta.get('name')
                prop = meta.get('property')
                content = meta.get('content', '')
                if prop and prop.startswith('og:') and content:
                    if options.include_meta_data == 'extended':
                        meta_node.content['openGraph'][prop[3:]] = content
                elif name and name.startswith('twitter:') and content:
                    if options.include_meta_data == 'extended':
                        meta_node.content['twitter'][name[8:]] = content
                elif name and name not in non_semantic_tags and content:
                    meta_node.content['standard'][name] = content
            # Extract JSON-LD data
            if options.include_meta_data == 'extended':
                import json
                json_ld_scripts = child.find_all('script', type='application/ld+json')
                for script in json_ld_scripts:
                    try:
                        parsed_data = json.loads(script.string)
                        meta_node.content['jsonLd'].append(parsed_data)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse JSON-LD: {e}")
            result.append(meta_node)
        elif tag_name in ['strong', 'b']:
            content = html_to_markdown_ast(child, options, indent_level + 1)
            result.append(BoldNode(content=content if content else ""))
        elif tag_name in ['em', 'i']:
            content = html_to_markdown_ast(child, options, indent_level + 1)
            result.append(ItalicNode(content=content if content else ""))
        elif tag_name in ['s', 'strike']:
            content = html_to_markdown_ast(child, options, indent_level + 1)
            result.append(StrikethroughNode(content=content if content else ""))
        elif tag_name == 'code':
            is_code_block = child.parent.name == 'pre'
            content = child.get_text(strip=True)
            language = ""
            if not is_code_block:
                classes = child.get('class', [])
                for cls in classes:
                    if cls.startswith("language-"):
                        language = cls.replace("language-", "")
                        break
            result.append(CodeNode(content=content, language=language, inline=not is_code_block))
        elif tag_name == 'blockquote':
            content = html_to_markdown_ast(child, options, indent_level + 1)
            result.append(BlockquoteNode(content=content))
        elif tag_name in [
            'article', 'aside', 'details', 'figcaption', 'figure', 'footer',
            'header', 'main', 'mark', 'nav', 'section', 'summary', 'time'
        ]:
            logger.debug(f"Semantic HTML Element: '{tag_name}'")
            content = html_to_markdown_ast(child, options, indent_level + 1)
            result.append(SemanticHtmlNode(htmlType=tag_name, content=content))
        else:
            # Handle unhandled elements
            if options.process_unhandled_element:
                processed = options.process_unhandled_element(child, options, indent_level)
                if processed:
                    logger.debug(f"Processing Unhandled Element: '{tag_name}'")
                    result.extend(processed)
                    continue
            # Generic HTML elements
            logger.debug(f"Generic HTMLElement: '{tag_name}'")
            result.extend(html_to_markdown_ast(child, options, indent_level + 1))
    return result
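The table branch above normalizes colspan/rowspan to None when they equal 1, so the renderer only sees spans that actually matter. A minimal sketch of that normalization in isolation, assuming BeautifulSoup 4 (which this module already depends on):

from bs4 import BeautifulSoup

html = "<table><tr><th colspan='2'>Name</th></tr><tr><td>a</td><td>b</td></tr></table>"
soup = BeautifulSoup(html, "html.parser")

for row in soup.find_all('tr'):
    for cell in row.find_all(['th', 'td']):
        colspan = int(cell.get('colspan', 1))  # attribute comes back as a string
        # Mirrors the branch above: spans of 1 are dropped to None
        print(cell.get_text(strip=True), colspan if colspan > 1 else None)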
App_Function_Libraries/html_to_markdown/main.py
ADDED
@@ -0,0 +1,45 @@
# html_to_markdown/main.py
# Usage: python -m html_to_markdown.main input.html output.md --extract-main --refify-urls --include-meta extended --debug
# Arguments:
#   input.html: Path to your input HTML file.
#   output.md: Desired path for the output Markdown file.
#   --extract-main: (Optional) Extracts the main content from the HTML.
#   --refify-urls: (Optional) Refactors URLs to reference-style.
#   --include-meta: (Optional) Includes metadata. Choose between basic or extended.
#   --debug: (Optional) Enables debug logging for detailed trace.

import argparse

try:
    # Package-qualified imports, matching the documented `python -m html_to_markdown.main` usage.
    from html_to_markdown.html_to_markdown import convert_html_to_markdown
    from html_to_markdown.conversion_options import ConversionOptions
except ImportError:
    # Fallback for running the script directly from inside the package directory.
    from html_to_markdown import convert_html_to_markdown
    from conversion_options import ConversionOptions


def main():
    parser = argparse.ArgumentParser(description="Convert HTML to Markdown.")
    parser.add_argument('input_file', help="Path to the input HTML file.")
    parser.add_argument('output_file', help="Path to the output Markdown file.")
    parser.add_argument('--extract-main', action='store_true', help="Extract main content.")
    parser.add_argument('--refify-urls', action='store_true', help="Refify URLs.")
    parser.add_argument('--include-meta', choices=['basic', 'extended'], default=None, help="Include metadata.")
    parser.add_argument('--debug', action='store_true', help="Enable debug logging.")

    args = parser.parse_args()

    with open(args.input_file, 'r', encoding='utf-8') as f:
        html_content = f.read()

    options = ConversionOptions(
        extract_main_content=args.extract_main,
        refify_urls=args.refify_urls,
        include_meta_data=args.include_meta if args.include_meta else False,
        debug=args.debug
    )

    markdown = convert_html_to_markdown(html_content, options)

    with open(args.output_file, 'w', encoding='utf-8') as f:
        f.write(markdown)

    print(f"Conversion complete. Markdown saved to {args.output_file}")

if __name__ == "__main__":
    main()
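For callers that already have the HTML in memory, the same conversion can be driven without the CLI. A sketch, assuming the package imports resolve as in main() above and that ConversionOptions accepts the same keyword arguments main() passes it:

from html_to_markdown.html_to_markdown import convert_html_to_markdown
from html_to_markdown.conversion_options import ConversionOptions

options = ConversionOptions(extract_main_content=False, refify_urls=False,
                            include_meta_data='basic', debug=False)
print(convert_html_to_markdown("<h1>Hello</h1><p>A <b>bold</b> claim.</p>", options))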
App_Function_Libraries/html_to_markdown/markdown_ast_to_string.py
ADDED
@@ -0,0 +1,163 @@
# html_to_markdown/markdown_ast_to_string.py
import json
import logging
from typing import List, Optional, Union

from ast_utils import find_in_ast
from conversion_options import ConversionOptions
from s_types import (
    SemanticMarkdownAST, TextNode, BoldNode, ItalicNode, StrikethroughNode,
    HeadingNode, LinkNode, ImageNode, VideoNode, ListNode, ListItemNode,
    TableNode, TableRowNode, TableCellNode, CodeNode, BlockquoteNode,
    SemanticHtmlNode, CustomNode, MetaDataNode
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def markdown_ast_to_string(nodes: List[SemanticMarkdownAST], options: Optional[ConversionOptions] = None, indent_level: int = 0) -> str:
    if options is None:
        options = ConversionOptions()

    markdown_string = ""
    markdown_string += markdown_meta_ast_to_string(nodes, options, indent_level)
    markdown_string += markdown_content_ast_to_string(nodes, options, indent_level)
    return markdown_string

def markdown_meta_ast_to_string(nodes: List[SemanticMarkdownAST], options: ConversionOptions, indent_level: int) -> str:
    markdown_string = ""
    if options.include_meta_data:
        markdown_string += "---\n"
        node = find_in_ast(nodes, lambda x: isinstance(x, MetaDataNode))
        if node and isinstance(node, MetaDataNode):
            standard = node.content.get('standard', {})
            for key, value in standard.items():
                markdown_string += f'{key}: "{value}"\n'
            if options.include_meta_data == 'extended':
                open_graph = node.content.get('openGraph', {})
                twitter = node.content.get('twitter', {})
                json_ld = node.content.get('jsonLd', [])

                if open_graph:
                    markdown_string += "openGraph:\n"
                    for key, value in open_graph.items():
                        markdown_string += f"  {key}: \"{value}\"\n"

                if twitter:
                    markdown_string += "twitter:\n"
                    for key, value in twitter.items():
                        markdown_string += f"  {key}: \"{value}\"\n"

                if json_ld:
                    markdown_string += "schema:\n"
                    for item in json_ld:
                        jld_type = item.get('@type', '(unknown type)')
                        markdown_string += f"  {jld_type}:\n"
                        for key, value in item.items():
                            if key in ['@context', '@type']:
                                continue
                            markdown_string += f"    {key}: {json.dumps(value)}\n"
        markdown_string += "---\n\n"
    return markdown_string

def markdown_content_ast_to_string(nodes: List[SemanticMarkdownAST], options: ConversionOptions, indent_level: int) -> str:
    markdown_string = ""
    for node in nodes:
        # Skip meta nodes as they are already handled
        if isinstance(node, MetaDataNode):
            continue

        # Override node renderer if provided
        if options.override_node_renderer:
            override = options.override_node_renderer(node, options, indent_level)
            if override:
                markdown_string += override
                continue

        if isinstance(node, TextNode):
            markdown_string += f"{node.content}"
        elif isinstance(node, BoldNode):
            content = ast_to_markdown(node.content, options, indent_level)
            markdown_string += f"**{content}**"
        elif isinstance(node, ItalicNode):
            content = ast_to_markdown(node.content, options, indent_level)
            markdown_string += f"*{content}*"
        elif isinstance(node, StrikethroughNode):
            content = ast_to_markdown(node.content, options, indent_level)
            markdown_string += f"~~{content}~~"
        elif isinstance(node, HeadingNode):
            markdown_string += f"\n{'#' * node.level} {node.content}\n\n"
        elif isinstance(node, LinkNode):
            content = ast_to_markdown(node.content, options, indent_level)
            if all(isinstance(c, TextNode) for c in node.content):
                markdown_string += f"[{content}]({node.href})"
            else:
                # Use HTML <a> tag for links with rich content
                markdown_string += f"<a href=\"{node.href}\">{content}</a>"
        elif isinstance(node, ImageNode):
            alt = node.alt or ""
            src = node.src or ""
            if alt.strip() or src.strip():
                markdown_string += f"![{alt}]({src})"
        elif isinstance(node, VideoNode):
            markdown_string += f"\n![Video]({node.src})\n"
            if node.poster:
                markdown_string += f"![Poster]({node.poster})\n"
            if node.controls:
                markdown_string += f"Controls: {node.controls}\n"
            markdown_string += "\n"
        elif isinstance(node, ListNode):
            for idx, item in enumerate(node.items):
                prefix = f"{idx + 1}." if node.ordered else "-"
                content = ast_to_markdown(item.content, options, indent_level + 1).strip()
                markdown_string += f"{' ' * indent_level}{prefix} {content}\n"
            markdown_string += "\n"
        elif isinstance(node, TableNode):
            if not node.rows:
                continue
            max_columns = max(
                sum(cell.colspan or 1 for cell in row.cells) for row in node.rows
            )
            for row_idx, row in enumerate(node.rows):
                for cell in row.cells:
                    content = cell.content if isinstance(cell.content, str) else ast_to_markdown(cell.content, options, indent_level + 1).strip()
                    markdown_string += f"| {content} "
                # Fill remaining columns
                remaining = max_columns - sum(cell.colspan or 1 for cell in row.cells)
                for _ in range(remaining):
                    markdown_string += "| "
                markdown_string += "|\n"
                if row_idx == 0:
                    # Add header separator
                    markdown_string += "|" + "|".join([' --- ' for _ in range(max_columns)]) + "|\n"
            markdown_string += "\n"
        elif isinstance(node, CodeNode):
            if node.inline:
                markdown_string += f"`{node.content}`"
            else:
                language = node.language or ""
                markdown_string += f"\n```{language}\n{node.content}\n```\n\n"
        elif isinstance(node, BlockquoteNode):
            content = ast_to_markdown(node.content, options, indent_level).strip()
            markdown_string += f"> {content}\n\n"
        elif isinstance(node, SemanticHtmlNode):
            if node.htmlType in ["summary", "time", "aside", "nav", "figcaption", "main", "mark", "header", "footer", "details", "figure"]:
                markdown_string += f"\n<-{node.htmlType}->\n{ast_to_markdown(node.content, options, indent_level)}\n\n</-{node.htmlType}->\n\n"
            elif node.htmlType == "article":
                markdown_string += f"\n\n{ast_to_markdown(node.content, options, indent_level)}\n\n"
            elif node.htmlType == "section":
                markdown_string += "---\n\n"
                markdown_string += f"{ast_to_markdown(node.content, options, indent_level)}\n\n---\n\n"
        elif isinstance(node, CustomNode):
            if options.render_custom_node:
                custom_render = options.render_custom_node(node, options, indent_level)
                if custom_render:
                    markdown_string += custom_render
        # Add more node types as needed
    return markdown_string

def ast_to_markdown(content: Union[str, List[SemanticMarkdownAST]], options: ConversionOptions, indent_level: int) -> str:
    if isinstance(content, str):
        return content
    else:
        return markdown_content_ast_to_string(content, options, indent_level)
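A quick way to sanity-check the renderer above is to build a tiny AST by hand and print it. A sketch, assuming the same sibling-module layout this file uses and that ConversionOptions() defaults to include_meta_data=False:

from s_types import HeadingNode, TextNode, BoldNode
from markdown_ast_to_string import markdown_ast_to_string

ast = [
    HeadingNode(level=2, content="Example"),
    TextNode(content="plain, then "),
    BoldNode(content=[TextNode(content="bold")]),
]
print(markdown_ast_to_string(ast))
# Expected shape: "\n## Example\n\nplain, then **bold**"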
App_Function_Libraries/html_to_markdown/s_types.py
ADDED
@@ -0,0 +1,126 @@
# html_to_markdown/s_types.py

from dataclasses import dataclass, field
from typing import List, Optional, Union, Dict, Any

@dataclass
class TextNode:
    type: str = "text"
    content: str = ""

@dataclass
class BoldNode:
    type: str = "bold"
    content: Union[str, List['SemanticMarkdownAST']] = ""

@dataclass
class ItalicNode:
    type: str = "italic"
    content: Union[str, List['SemanticMarkdownAST']] = ""

@dataclass
class StrikethroughNode:
    type: str = "strikethrough"
    content: Union[str, List['SemanticMarkdownAST']] = ""

@dataclass
class HeadingNode:
    type: str = "heading"
    level: int = 1
    content: str = ""

@dataclass
class LinkNode:
    type: str = "link"
    href: str = ""
    content: List['SemanticMarkdownAST'] = field(default_factory=list)

@dataclass
class ImageNode:
    type: str = "image"
    src: str = ""
    alt: Optional[str] = ""

@dataclass
class VideoNode:
    type: str = "video"
    src: str = ""
    poster: Optional[str] = ""
    controls: bool = False

@dataclass
class ListItemNode:
    type: str = "listItem"
    content: List['SemanticMarkdownAST'] = field(default_factory=list)

@dataclass
class ListNode:
    type: str = "list"
    ordered: bool = False
    items: List[ListItemNode] = field(default_factory=list)

@dataclass
class TableCellNode:
    type: str = "tableCell"
    content: Union[str, List['SemanticMarkdownAST']] = ""
    colId: Optional[str] = None
    colspan: Optional[int] = None
    rowspan: Optional[int] = None

@dataclass
class TableRowNode:
    type: str = "tableRow"
    cells: List[TableCellNode] = field(default_factory=list)

@dataclass
class TableNode:
    type: str = "table"
    rows: List[TableRowNode] = field(default_factory=list)
    colIds: Optional[List[str]] = None

@dataclass
class CodeNode:
    type: str = "code"
    language: Optional[str] = ""
    content: str = ""
    inline: bool = False

@dataclass
class BlockquoteNode:
    type: str = "blockquote"
    content: List['SemanticMarkdownAST'] = field(default_factory=list)

@dataclass
class CustomNode:
    type: str = "custom"
    content: Any = None

@dataclass
class SemanticHtmlNode:
    type: str = "semanticHtml"
    htmlType: str = ""
    content: List['SemanticMarkdownAST'] = field(default_factory=list)

@dataclass
class MetaDataNode:
    type: str = "meta"
    content: Dict[str, Any] = field(default_factory=dict)

# Union of all node types
SemanticMarkdownAST = Union[
    TextNode,
    BoldNode,
    ItalicNode,
    StrikethroughNode,
    HeadingNode,
    LinkNode,
    ImageNode,
    VideoNode,
    ListNode,
    TableNode,
    CodeNode,
    BlockquoteNode,
    SemanticHtmlNode,
    CustomNode,
    MetaDataNode
]
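The quoted 'SemanticMarkdownAST' forward references let each dataclass refer to the Union that is only defined after them. A small construction sketch, assuming s_types.py is importable as a sibling module:

from s_types import ListNode, ListItemNode, TextNode

todo = ListNode(ordered=True, items=[
    ListItemNode(content=[TextNode(content="first")]),
    ListItemNode(content=[TextNode(content="second")]),
])
print(todo.ordered, len(todo.items))  # True 2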
App_Function_Libraries/html_to_markdown/url_utils.py
ADDED
@@ -0,0 +1,55 @@
# html_to_markdown/url_utils.py

from typing import Dict, Optional

media_suffixes = [
    "jpeg", "jpg", "png", "gif", "bmp", "tiff", "tif", "svg",
    "webp", "ico", "avi", "mov", "mp4", "mkv", "flv", "wmv",
    "webm", "mpeg", "mpg", "mp3", "wav", "aac", "ogg", "flac",
    "m4a", "pdf", "doc", "docx", "ppt", "pptx", "xls", "xlsx",
    "txt", "css", "js", "xml", "json", "html", "htm"
]

def add_ref_prefix(prefix: str, prefixes_to_refs: Dict[str, str]) -> str:
    if prefix not in prefixes_to_refs:
        prefixes_to_refs[prefix] = f'ref{len(prefixes_to_refs)}'
    return prefixes_to_refs[prefix]

def process_url(url: str, prefixes_to_refs: Dict[str, str]) -> str:
    if not url.startswith('http'):
        return url
    else:
        parts = url.split('/')
        media_suffix = parts[-1].split('.')[-1].lower()
        if media_suffix in media_suffixes:
            prefix = '/'.join(parts[:-1])
            ref_prefix = add_ref_prefix(prefix, prefixes_to_refs)
            return f"{ref_prefix}://{parts[-1]}"
        else:
            if len(parts) > 4:
                return add_ref_prefix(url, prefixes_to_refs)
            else:
                return url

def refify_urls(markdown_elements: list, prefixes_to_refs: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    # Use a fresh dict per top-level call; a mutable default argument would be
    # shared across calls and leak refs between unrelated documents.
    if prefixes_to_refs is None:
        prefixes_to_refs = {}
    # Walks dict-shaped AST nodes in place, rewriting hrefs/srcs to refN tokens.
    for element in markdown_elements:
        if isinstance(element, dict):
            node_type = element.get('type')
            if node_type == 'link':
                original_href = element.get('href', '')
                element['href'] = process_url(original_href, prefixes_to_refs)
                refify_urls(element.get('content', []), prefixes_to_refs)
            elif node_type in ['image', 'video']:
                original_src = element.get('src', '')
                element['src'] = process_url(original_src, prefixes_to_refs)
            elif node_type == 'list':
                for item in element.get('items', []):
                    refify_urls(item.get('content', []), prefixes_to_refs)
            elif node_type == 'table':
                for row in element.get('rows', []):
                    for cell in row.get('cells', []):
                        if isinstance(cell.get('content'), list):
                            refify_urls(cell['content'], prefixes_to_refs)
            elif node_type in ['blockquote', 'semanticHtml']:
                refify_urls(element.get('content', []), prefixes_to_refs)
    return prefixes_to_refs
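How process_url compresses media URLs: the shared path prefix is swapped for a generated refN token, so documents with many links under the same directory shrink considerably. A sketch, assuming it runs next to url_utils.py:

from url_utils import process_url

refs = {}
print(process_url("https://example.com/img/a.png", refs))  # ref0://a.png
print(process_url("https://example.com/img/b.png", refs))  # ref0://b.png (same prefix, same token)
print(refs)  # {'https://example.com/img': 'ref0'}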
App_Function_Libraries/models/pyannote_diarization_config.yaml
ADDED
@@ -0,0 +1,13 @@
pipeline:
  params:
    clustering: AgglomerativeClustering
    embedding: /FULL/PATH/TO/SCRIPT/tldw/App_Function_Libraries/models/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin  # models/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin
    segmentation: /FULL/PATH/TO/SCRIPT/tldw/App_Function_Libraries/models/pyannote_model_segmentation-3.0.bin  # models/pyannote_model_segmentation-3.0.bin

params:
  segmentation:
    min_duration_off: 0.0
  clustering:
    method: centroid
    min_cluster_size: 12
    threshold: 0.7045654963945799
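A sketch of loading the local diarization pipeline from this config, assuming pyannote.audio 3.x is installed and the /FULL/PATH/TO/SCRIPT placeholders above have been edited to real absolute paths; "audio.wav" is a hypothetical input file:

from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("App_Function_Libraries/models/pyannote_diarization_config.yaml")
diarization = pipeline("audio.wav")
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{turn.start:.1f}s-{turn.end:.1f}s: {speaker}")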
App_Function_Libraries/models/pyannote_model_segmentation-3.0.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da85c29829d4002daedd676e012936488234d9255e65e86dfab9bec6b1729298
size 5905440
App_Function_Libraries/models/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:366edf44f4c80889a3eb7a9d7bdf02c4aede3127f7dd15e274dcdb826b143c56
size 26645418
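Both .bin entries are Git LFS pointer files rather than the model weights themselves: the repository stores only the spec version, the sha256 oid, and the byte size, while the actual binaries are fetched from LFS storage on checkout.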
App_Function_Libraries/test.gguf
ADDED
File without changes