yash3056/tmp / llama-cpp.py
yash3056's picture
download
raw
9.64 kB
import os
import asyncio
import json
import re
import pandas as pd
import httpx
from tqdm import tqdm
from pathlib import Path
# ── Config ────────────────────────────────────────────────────────────────────
# Every setting below can be overridden with an environment variable of the
# same name; the defaults are the values this script was written against.
# A trailing "/" is stripped so CHAT_URL never ends up with "//v1/...".
LLAMA_BASE_URL = os.environ.get("LLAMA_BASE_URL", "http://localhost:8099").rstrip("/")
INPUT_CSV = os.environ.get(
    "INPUT_CSV", "/home/aza/workspace/textai-reason/Suicide_Detection.csv"
)
OUTPUT_DIR = Path(
    os.environ.get("OUTPUT_DIR", "/home/aza/workspace/textai-reason/test_llama")
)
# Defaults to a file inside OUTPUT_DIR so the merged CSV sits next to its shards.
OUTPUT_CSV = os.environ.get("OUTPUT_CSV", str(OUTPUT_DIR / "output.csv"))
# How many requests to fire concurrently — tune to match --parallel on server
CONCURRENCY = int(os.environ.get("CONCURRENCY", "4"))
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "8192"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))
# Total attempts per row (the first request included), not extra retries.
MAX_RETRIES = int(os.environ.get("MAX_RETRIES", "3"))
VALID_LABELS = {"depression", "non-depression"}
CHAT_URL = f"{LLAMA_BASE_URL}/v1/chat/completions"
SYSTEM_PROMPT = """You are a clinical text classifier. Classify input text as exactly one of two labels: "depression" or "non-depression".
Classify as "depression" if the text shows:
- Suicidal ideation, self-harm urges, or wishes to die
- Hopelessness, worthlessness, or pervasive emptiness
- Anhedonia — loss of interest in previously enjoyed activities
- Inability to function, extreme fatigue, or emotional numbness
- Expressions of being trapped, a burden, or without a future
Classify as "non-depression" if the text shows:
- Neutral, factual, or informational content
- Situational stress without persistent clinical markers
- Positive, stable, or curious emotional states
STRICT OUTPUT RULE: Output exactly one of these two strings and nothing else:
depression
non-depression"""
# User turn template; only the raw input text is substituted into it.
USER_PROMPT = """<input_text>
{input_text}
</input_text>
Classification:"""
# ── Helpers ───────────────────────────────────────────────────────────────────
def build_messages(text: str) -> list[dict]:
    """Assemble the two-turn chat prompt for one input text.

    The system turn carries the fixed classification instructions; the user
    turn wraps *text* in USER_PROMPT. Nothing else about the CSV row is sent.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": USER_PROMPT.format(input_text=text)}
    return [system_turn, user_turn]
def parse_output(raw: str) -> tuple[str, str]:
    """Split a raw model completion into ``(thinking, label)``.

    Reasoning wrappers are recognised in this order:

    * ``<|channel>thought ... <channel|>``: a leading ``thought`` marker is
      dropped from the reasoning text;
    * ``<think> ... </think>``;
    * a bare ``</think>`` with no opening tag. Everything before it is
      treated as reasoning. Some chat templates put ``<think>`` into the
      prompt, so it never appears in the completion (TODO confirm for the
      served model).

    The text after the wrapper (or the whole text if there is none) is reduced
    to its first whitespace-separated token. That token is lower-cased, and
    surrounding quotes, markdown emphasis and punctuation are stripped, so
    ``"**Depression**."`` becomes ``"depression"``. Inner hyphens are kept, so
    ``non-depression`` survives intact.

    Args:
        raw: Completion text. ``None`` or empty is treated as ``""``.

    Returns:
        ``(thinking, label)``. Either is ``""`` when nothing could be extracted.
    """
    raw = raw.strip() if raw else ""
    thinking = ""
    label = raw

    open_ch, close_ch = "<|channel>", "<channel|>"
    ch_start = raw.find(open_ch)
    # Only accept a closing marker that comes *after* the opening one.
    ch_end = raw.find(close_ch, ch_start + len(open_ch)) if ch_start != -1 else -1
    think_end = raw.find("</think>")

    if ch_end != -1:
        think_body = raw[ch_start + len(open_ch):ch_end]
        if think_body.startswith("thought"):
            think_body = think_body[len("thought"):].lstrip("\n")
        thinking = think_body.strip()
        label = raw[ch_end + len(close_ch):]
    elif think_end != -1:
        # The opening tag is optional; without it the reasoning starts at 0.
        think_start = raw.find("<think>", 0, think_end)
        body_start = think_start + len("<think>") if think_start != -1 else 0
        thinking = raw[body_start:think_end].strip()
        label = raw[think_end + len("</think>"):]

    # Normalise label: first token, lower-case, without decoration around it.
    tokens = label.split()
    label = tokens[0].lower().strip("\"'`*_.,;:!?()[]") if tokens else ""
    return thinking, label
def get_resume_index(output_dir: Path) -> int:
    """Return how many input rows are already stored in the JSONL shards.

    Counts the non-blank lines across all ``shard_*.jsonl`` files in
    *output_dir*; the caller resumes processing at that row index. Returns 0
    when the directory does not exist.

    Shards are read as UTF-8, the encoding they are written with. Lines are
    split only at real line endings, by iterating the file. ``str.splitlines()``
    would also split on characters such as U+2028 or U+0085. Those characters
    are left unescaped inside strings by ``json.dumps(..., ensure_ascii=False)``,
    so it would over-count rows.

    A shard that cannot be read is skipped with a warning (best effort). The
    returned index is then too low, and some rows will be processed again.
    """
    if not output_dir.exists():
        return 0
    total = 0
    for fpath in sorted(output_dir.glob("shard_*.jsonl")):
        try:
            with open(fpath, encoding="utf-8") as f:
                total += sum(1 for line in f if line.strip())
        except (OSError, UnicodeDecodeError) as exc:
            print(f"[WARN] could not read {fpath}: {exc} — skipping")
    if total > 0:
        print(f"Resuming from row {total:,} (found existing shards).")
    return total
def next_shard_path(output_dir: Path) -> Path:
    """Return the path for a new shard, one past the highest existing index.

    Naming by ``max index + 1`` rather than by the number of shards avoids
    reusing an existing file name when the numbering has a gap (e.g. after a
    shard was deleted). The caller opens the returned path for writing, so a
    reused name would overwrite a shard. Files whose name does not carry a
    numeric index are ignored.
    """
    indices = []
    for path in output_dir.glob("shard_*.jsonl"):
        suffix = path.stem[len("shard_"):]
        if suffix.isdecimal():
            indices.append(int(suffix))
    next_idx = max(indices, default=-1) + 1
    return output_dir / f"shard_{next_idx:06d}.jsonl"
# ── Async request ─────────────────────────────────────────────────────────────
async def classify_one(
    client: httpx.AsyncClient,
    sem: asyncio.Semaphore,
    text: str,  # only the raw text — no label/class column
    row_index: int,
) -> tuple[str, str, str]:
    """Classify one text via the llama.cpp chat-completions endpoint.

    Up to MAX_RETRIES attempts are made. An attempt is retried when the
    request fails (transport error, non-2xx status, malformed response body)
    or when the extracted label is not in VALID_LABELS.

    Args:
        client: Shared HTTP client.
        sem: Bounds the number of requests in flight. It is held only while a
            request is outstanding, not during back-off sleeps.
        text: Raw input text; only this goes into the prompt.
        row_index: Absolute row number in the input CSV, used in log messages.

    Returns:
        ``(text, thinking, label)``. ``label`` may still be invalid after the
        last attempt; it is then stored as-is and an error is printed.

    Raises:
        httpx.HTTPError, KeyError, IndexError, ValueError: when the request on
        the final attempt itself fails.
    """
    payload = {
        "messages": build_messages(text),  # prompt uses text only
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "stream": False,
    }
    thinking = ""
    label = ""
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            async with sem:
                resp = await client.post(CHAT_URL, json=payload, timeout=300.0)
                resp.raise_for_status()
                message = resp.json()["choices"][0]["message"]
        except (httpx.HTTPError, KeyError, IndexError, ValueError) as exc:
            if attempt >= MAX_RETRIES:
                raise
            print(f"\n[WARN] row {row_index}: request failed ({exc!r}) "
                  f"(attempt {attempt}/{MAX_RETRIES}) — retrying")
            # Back off 1 s, 2 s, ... outside the semaphore so other rows can
            # keep the server busy meanwhile.
            await asyncio.sleep(2 ** (attempt - 1))
            continue

        # llama.cpp with --reasoning on exposes separate fields; otherwise the
        # reasoning may be inline in ``content`` and is split off by parse_output.
        reasoning = (message.get("reasoning_content") or "").strip()
        parsed_thinking, label = parse_output(message.get("content") or "")
        thinking = reasoning or parsed_thinking
        if label in VALID_LABELS:
            break  # valid — no retry needed
        if attempt < MAX_RETRIES:
            print(f"\n[WARN] row {row_index}: invalid label {label!r} "
                  f"(attempt {attempt}/{MAX_RETRIES}) — retrying")
    if label not in VALID_LABELS:
        print(f"\n[ERROR] row {row_index}: still invalid label {label!r} "
              f"after {MAX_RETRIES} attempts — storing as-is")
    return text, thinking, label
# ── Main ──────────────────────────────────────────────────────────────────────
def _write_shard(output_dir: Path, rows: list[dict]) -> Path:
    """Atomically write *rows* as the next JSONL shard in *output_dir*.

    Rows are written to a ``.tmp`` file that is renamed into place only once
    complete. An interrupted run therefore cannot leave a truncated
    ``shard_*.jsonl`` that would be miscounted on resume or break the merge.
    """
    shard_path = next_shard_path(output_dir)
    tmp_path = shard_path.with_name(shard_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in rows)
    tmp_path.replace(shard_path)  # atomic rename on POSIX (same filesystem)
    return shard_path


async def main_async():
    """Classify every pending row of INPUT_CSV, checkpointing to OUTPUT_DIR.

    Rows are processed in batches of ``CONCURRENCY * 4``. Each batch is
    written as one JSONL shard, so a restarted run resumes after the last
    completed batch. Resuming relies on the shards holding exactly one line
    per processed input row, in input order. That is why a row whose request
    ultimately fails is still written, with empty ``thinking`` and ``label``.
    Finally all shards are merged into OUTPUT_CSV.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Reading {INPUT_CSV} ...")
    df = pd.read_csv(INPUT_CSV)
    total = len(df)
    print(f"Total rows: {total:,}")
    print(f"Llama server : {CHAT_URL}")
    print(f"Concurrency : {CONCURRENCY}")
    print(f"Output dir : {OUTPUT_DIR}")
    start_row = get_resume_index(OUTPUT_DIR)
    if start_row >= total:
        print("All rows already processed.")
        _merge_and_save(OUTPUT_DIR, total)
        return
    sem = asyncio.Semaphore(CONCURRENCY)
    # Fire requests in fixed-size batches so we can checkpoint often
    batch_size = CONCURRENCY * 4
    # Only the text column is sent; the class/label column never reaches the prompt
    text_col = "text"
    pending_texts = df[text_col].iloc[start_row:].tolist()
    with tqdm(total=total, initial=start_row, unit="row",
              dynamic_ncols=True, desc="Classifying") as pbar:
        async with httpx.AsyncClient() as client:
            for batch_start in range(0, len(pending_texts), batch_size):
                batch_texts = pending_texts[batch_start:batch_start + batch_size]
                first_row = start_row + batch_start
                results = await asyncio.gather(
                    *(classify_one(client, sem, txt, first_row + offset)
                      for offset, txt in enumerate(batch_texts)),
                    return_exceptions=True,
                )
                shard_rows = []
                for offset, (txt, res) in enumerate(zip(batch_texts, results)):
                    if isinstance(res, BaseException):
                        # Keep a placeholder so the shard line count matches
                        # the number of processed rows (resume depends on it).
                        print(f"\n[WARN] row {first_row + offset}: request failed: "
                              f"{res} — storing empty label")
                        res = (txt, "", "")
                    out_text, thinking, label = res
                    # Output row: text (exact CSV value) + thinking + label only
                    shard_rows.append({
                        "text": out_text,
                        "thinking": thinking,
                        "label": label,
                    })
                _write_shard(OUTPUT_DIR, shard_rows)
                pbar.update(len(shard_rows))
    print(f"\nAll rows processed. Merging shards → {OUTPUT_CSV}")
    _merge_and_save(OUTPUT_DIR, total)
def _merge_and_save(output_dir: Path, total: int):
"""Merge all JSONL shards into a single CSV."""
shards = sorted(output_dir.glob("shard_*.jsonl"))
if not shards:
print("No shard files found — nothing to merge.")
return
frames = []
for s in shards:
lines = [l for l in s.read_text().splitlines() if l.strip()]
if lines:
frames.append(pd.DataFrame([json.loads(l) for l in lines]))
if not frames:
print("Shards are empty — nothing to merge.")
return
final = pd.concat(frames, ignore_index=True)
final.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")
print(f"Saved {len(final):,} / {total:,} rows → {OUTPUT_CSV}")
print("\nLabel distribution:")
print(final["label"].value_counts())
def main():
    """Synchronous entry point: run ``main_async`` to completion."""
    pipeline = main_async()
    asyncio.run(pipeline)


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
9.64 kB
·
Xet hash:
6dcbb71debcef7434b4a3f3545524e4a80824c877a33cdd9e86c8277dfcda266

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.