"""Utility helpers: markdown cleanup, GPU/host memory management, text
chunking, and .env / secret-file handling."""

import gc
import os
import re
import shutil

import torch
import transformers
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter

# Locations of the .env file and the secret files consumed by update_env_vars().
ENV_FILE_PATH = os.path.join(os.getenv("WRITABLE_DIR", "/tmp"), ".env")
WEBHOOK_PATH = os.path.join(os.getcwd(), ".webhook_secret")
SLACK_CREDENTIALS_PATH = os.path.join(os.getcwd(), ".slack_credentials")


def remove_markdown(text: str) -> str:
    """Strip common Markdown syntax and return plain text."""
    # Code fences (``` with an optional language tag).
    text = re.sub(r'```[a-zA-Z]*\n', '', text)
    text = re.sub(r'```', '', text)

    # ATX headers (#, ##, ...).
    text = re.sub(r'^\s*#+\s+', '', text, flags=re.MULTILINE)

    # Bold and italic emphasis.
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
    text = re.sub(r'__(.*?)__', r'\1', text)
    text = re.sub(r'\*(.*?)\*', r'\1', text)
    text = re.sub(r'_(.*?)_', r'\1', text)

    # Strikethrough.
    text = re.sub(r'~~(.*?)~~', r'\1', text)

    # Inline code.
    text = re.sub(r'`(.*?)`', r'\1', text)

    # Images must be dropped before links, otherwise the link rule below
    # consumes the [alt](url) part first and leaves a stray "!alt".
    text = re.sub(r'!\[(.*?)\]\((.*?)\)', '', text)

    # Links: keep the link text, drop the URL.
    text = re.sub(r'\[(.*?)\]\((.*?)\)', r'\1', text)

    # Blockquotes.
    text = re.sub(r'^\s*>\s+', '', text, flags=re.MULTILINE)

    # Bulleted and numbered list markers.
    text = re.sub(r'^\s*[\*\+-]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)

    # Horizontal rules.
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)

    # Any leftover emphasis / code characters.
    text = re.sub(r'[*_~`]', '', text)

    return text.strip()


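# Illustrative call (hypothetical input, not taken from the original module):
# >>> remove_markdown("## Title\n**bold** text with a [link](https://example.com)")
# 'Title\nbold text with a link'

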
def remove_outer_markdown_block(chunk, _acc={"b": ""}):
    """Strip outer ```markdown fences from streamed text.

    Incoming chunks are appended to a buffer; the body of every complete
    ```markdown ... ``` block is emitted, and any remaining text without an
    opening fence is flushed as-is. Note the mutable default argument: the
    buffer deliberately persists across calls for the life of the process.
    """
    _acc["b"] += chunk
    pattern = re.compile(r'```markdown\s*\n(.*?)\n?```', re.DOTALL | re.IGNORECASE)
    out = []

    # Emit every complete fenced block, keeping only its body.
    while True:
        match = pattern.search(_acc["b"])
        if not match:
            break
        start, end = match.span()
        out.append(_acc["b"][:start] + match.group(1))
        _acc["b"] = _acc["b"][end:]

    # With no opening fence left, the rest of the buffer can be flushed.
    if '```markdown' not in _acc["b"].lower():
        out.append(_acc["b"])
        _acc["b"] = ""

    return "".join(out)


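# Streaming sketch (hypothetical chunks; assumes the shared buffer starts empty):
# for piece in ["```markdown\nHel", "lo there\n``` tail"]:
#     print(remove_outer_markdown_block(piece), end="")
# -> prints "Hello there tail" once the closing fence has arrived

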
def clear_gpu_memory():
    """Release cached CUDA memory and reset peak-memory statistics on every GPU."""
    if torch.cuda.is_available():
        try:
            print("Starting the GPU memory cleanup process...")
            torch.cuda.empty_cache()

            device_count = torch.cuda.device_count()
            print(f"Number of GPUs: {device_count}")
            for device_id in range(device_count):
                print(f"Clearing GPU memory and cache for device {device_id}...")
                torch.cuda.set_device(device_id)
                torch.cuda.reset_peak_memory_stats(torch.cuda.current_device())
                torch.cuda.empty_cache()

            torch.cuda.synchronize()
            torch.cuda.ipc_collect()
        except Exception as e:
            raise RuntimeError(f"Error clearing GPU memory and cache: {e}") from e


def clear_memory():
    """Drop references to tensors and transformer models, then run the garbage collector."""
    print("Deleting all tensors and models...")
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                del obj
            elif isinstance(obj, (transformers.PreTrainedModel,
                                  transformers.tokenization_utils_base.PreTrainedTokenizerBase)) or \
                    "SentenceTransformer" in str(type(obj)):
                if hasattr(obj, "name_or_path"):
                    model_name = obj.name_or_path
                elif hasattr(obj, "config") and hasattr(obj.config, "_name_or_path"):
                    model_name = obj.config._name_or_path
                else:
                    model_name = str(type(obj))

                print(f"Deleting model: {model_name}")
                # `del obj` only removes this local reference; the object itself is
                # reclaimed by gc.collect() below once nothing else holds it.
                del obj
        except Exception as e:
            print(f"Error during deletion: {e}")

    gc.collect()


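# A plausible cleanup sequence (an assumption; the call order is not fixed in this module):
# clear_memory()      # drop Python references to models/tensors and collect garbage
# clear_gpu_memory()  # then release the CUDA caching allocator's memory on every device

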
def chunk_text(input_text, max_chunk_length=100, overlap=0, context_length=None):
    """Split text into chunks and return (chunks, span_annotations).

    If context_length is given, only a character-based pass with that chunk size
    is performed; otherwise each character chunk is re-split by tokens. Span
    annotations are cumulative offsets over the lengths of the returned chunks.
    """
    chunk_size = context_length if isinstance(context_length, int) and context_length > 0 else max_chunk_length

    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=overlap,
        length_function=len
    )
    chunks = splitter.split_text(input_text)

    # Only re-split by tokens when no explicit context length was requested.
    token_splitter = TokenTextSplitter(chunk_size=max_chunk_length, chunk_overlap=overlap) \
        if not context_length else None

    final_chunks = []
    span_annotations = []
    current_position = 0

    for chunk in chunks:
        current_chunks = token_splitter.split_text(chunk) if token_splitter else [chunk]
        final_chunks.extend(current_chunks)

        for tc in current_chunks:
            span_annotations.append((current_position, current_position + len(tc)))
            current_position += len(tc)

    return final_chunks, span_annotations


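# Minimal usage sketch (hypothetical text; TokenTextSplitter needs tiktoken installed):
# chunks, spans = chunk_text("First paragraph.\n\nSecond paragraph.", max_chunk_length=50)
# for (start, end), piece in zip(spans, chunks):
#     print(start, end, repr(piece))

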
def read_env():
    """Parse ENV_FILE_PATH into a dict, skipping blank lines and comments."""
    env_dict = {}
    if not os.path.exists(ENV_FILE_PATH):
        return env_dict

    with open(ENV_FILE_PATH, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if "=" in line:
                var, val = line.split("=", 1)
                env_dict[var.strip()] = val.strip()
    return env_dict


def update_env_vars(new_values: dict):
    """Merge new_values into the .env file, refreshing the webhook URL and
    Slack credentials from their secret files on every call."""

    def load_webhook_url_securely() -> str:
        if os.path.exists(WEBHOOK_PATH):
            with open(WEBHOOK_PATH, "r", encoding="utf-8") as f:
                return f.read().strip()
        raise FileNotFoundError(f"Webhook secret file not found at {WEBHOOK_PATH}")

    webhook_url = load_webhook_url_securely()
    new_values["PIPEDREAM_WEBHOOK_URL"] = webhook_url

    if os.path.exists(SLACK_CREDENTIALS_PATH):
        with open(SLACK_CREDENTIALS_PATH, "r", encoding="utf-8") as f:
            lines = [line.strip() for line in f if line.strip()]
        if len(lines) >= 2:
            new_values["SLACK_CLIENT_ID"] = lines[0]
            new_values["SLACK_CLIENT_SECRET"] = lines[1]

    # Merge with the existing file so variables written by earlier calls are not lost.
    env_vars = read_env()
    env_vars.update(new_values)

    with open(ENV_FILE_PATH, "w", encoding="utf-8") as f:
        for var, val in env_vars.items():
            f.write(f"{var}={val}\n")


def prepare_provider_key_updates(provider: str, multiline_keys: str) -> dict:
    """Turn one API key per line into numbered <PROVIDER>_API_KEY_<n> entries."""
    lines = [ln.strip() for ln in multiline_keys.splitlines() if ln.strip()]

    prefixes = {
        "openai": "OPENAI_API_KEY",
        "google": "GOOGLE_API_KEY",
        "xai": "XAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
    }
    prefix = prefixes.get(provider)
    if prefix is None:
        # Unknown providers produce no updates.
        return {}

    return {f"{prefix}_{i}": key for i, key in enumerate(lines, start=1)}


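# Likely wiring with update_env_vars (an assumption; the calling code is not part of this module):
# updates = prepare_provider_key_updates("openai", "key-one\nkey-two")
# update_env_vars(updates)  # persists OPENAI_API_KEY_1 and OPENAI_API_KEY_2 to the .env file

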
def prepare_proxy_list_updates(proxy_list: str) -> dict:
    """Turn one proxy per line into numbered PROXY_<n> entries."""
    lines = [proxy.strip() for proxy in proxy_list.splitlines() if proxy.strip()]
    proxies = {}

    for i, proxy in enumerate(lines, start=1):
        proxies[f"PROXY_{i}"] = proxy

    return proxies


def get_folder_size(folder_path: str) -> int:
    """Return the total size in bytes of the files under folder_path."""
    total_size = 0
    if not os.path.exists(folder_path):
        return 0
    for entry in os.scandir(folder_path):
        if entry.is_file():
            total_size += entry.stat().st_size
        elif entry.is_dir():
            # Recurse so files in subdirectories are counted as well.
            total_size += get_folder_size(entry.path)
    return total_size


def clear_folder(folder_path: str):
    """Delete folder_path and everything inside it."""
    if os.path.exists(folder_path):
        try:
            shutil.rmtree(folder_path)
            print(f"Successfully cleared upload directory: {folder_path}")
        except Exception as e:
            print(f"Failed to delete {folder_path}. Reason: {e}")
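

# Possible housekeeping check (assumed path and threshold; neither is defined in this module):
# if get_folder_size("/tmp/uploads") > 500 * 1024 * 1024:  # over ~500 MB
#     clear_folder("/tmp/uploads")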