from __future__ import annotations

import json
import logging
import re
import urllib.request
from dataclasses import dataclass

from src.exceptions import GenerationError

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TextGenResult:
    text: str


def _sanitize_text(s: str) -> str:
    """Remove common failure patterns (echoed rules, bullets, repeated lines)."""
    s = s.strip()

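    # Strip leading bullet markers that some models emit despite the prompt's rules.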
| s = re.sub(r"^\s*[-*•]\s+", "", s, flags=re.MULTILINE) |
|
|
| |
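    # Remove text where the model echoed the prompt's rules back verbatim.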
    bad_patterns = [
        r"(?i)\blength\s*:\s*\d+\s*[-–]\s*\d+\s*sentences\b.*",
        r"(?i)\brules\s*:\b.*",
        r"(?i)\bno bullet points\b.*",
        r"(?i)\bno repetition\b.*",
        r"(?i)\bno meta commentary\b.*",
        r"(?i)\bdescribe only\b.*",
    ]
    for pat in bad_patterns:
        s = re.sub(pat, "", s).strip()

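    # Collapse runs of blank lines and repeated spaces/tabs.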
| s = re.sub(r"\n{3,}", "\n\n", s) |
| s = re.sub(r"[ \t]{2,}", " ", s) |
|
|
| |
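    # Drop consecutive duplicate lines, a common repetition failure.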
    lines = [ln.strip() for ln in s.splitlines() if ln.strip()]
    deduped = []
    for ln in lines:
        if not deduped or deduped[-1] != ln:
            deduped.append(ln)
    s = "\n".join(deduped).strip()

    return s


def _ollama_generate(
    prompt: str,
    model: str = "qwen2:7b",
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_predict: int = 180,
    host: str = "http://localhost:11434",
) -> str:
    """Call the local Ollama server (POST /api/generate) and return the response text."""
    url = f"{host.rstrip('/')}/api/generate"
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": temperature,
            "top_p": top_p,
            "num_predict": num_predict,
        },
    }

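    # Build the POST request with stdlib urllib; no third-party HTTP client needed.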
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    try:
        with urllib.request.urlopen(req, timeout=600) as resp:
            data = json.loads(resp.read().decode("utf-8"))
            text = data.get("response", "").strip()
            logger.debug("Ollama generated %d chars", len(text))
            return text
    except Exception as e:
        logger.error("Ollama call failed on %s: %s", host, e)
        raise GenerationError(
            f"Ollama call failed. Is Ollama running on {host}? Error: {e}",
            modality="text",
            backend=f"ollama/{model}",
        ) from e


class TextGenerator:
    """
    Option A (recommended): Ollama text generator (instruction-following).
    Falls back to a Hugging Face pipeline when use_ollama=False.
    """

    def __init__(
        self,
        use_ollama: bool = True,
        ollama_model: str = "qwen2:7b",
        ollama_host: str = "http://localhost:11434",
        max_new_tokens: int = 160,
        hf_model_name: str = "gpt2",
    ):
        self.use_ollama = use_ollama
        self.ollama_model = ollama_model
        self.ollama_host = ollama_host
        self.max_new_tokens = max_new_tokens
        self.hf_model_name = hf_model_name

        self._hf_pipe = None
        if not self.use_ollama:
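            # Lazy import: transformers is only required for the HF fallback.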
            from transformers import pipeline

            self._hf_pipe = pipeline("text-generation", model=self.hf_model_name)

    def generate(self, prompt: str, deterministic: bool = True) -> TextGenResult:
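        # Wrap the caller's scene plan in a fixed instruction prompt; the rules
        # below mirror the failure patterns that _sanitize_text strips out.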
        wrapped_prompt = """You are a concise descriptive writer.

Write a literal description of the scene planned below. Follow these rules:
- Write 3 to 5 natural sentences.
- No bullet points, no numbered lists.
- No repetition.
- No meta commentary (do not mention rules, prompts, or constraints).
- Focus on concrete visual details AND the likely audio ambience.

SCENE PLAN:
"""
        wrapped_prompt = f"{wrapped_prompt}{prompt}\n\nNow write the description:\n"

        if self.use_ollama:
            raw = _ollama_generate(
                prompt=wrapped_prompt,
                model=self.ollama_model,
                host=self.ollama_host,
                temperature=0.0 if deterministic else 0.7,
                top_p=1.0 if deterministic else 0.9,
                num_predict=max(self.max_new_tokens, 120),
            )
            clean = _sanitize_text(raw)

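            # If sanitization stripped everything, fall back to the raw output.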
            return TextGenResult(text=clean if clean else raw)

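        # Fallback path: Hugging Face text-generation pipeline.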
        gen_kwargs = dict(
            max_new_tokens=self.max_new_tokens,
            num_return_sequences=1,
            return_full_text=False,  # do not echo the prompt back in the output
        )
        if deterministic:
            gen_kwargs["do_sample"] = False  # greedy decoding
        else:
            gen_kwargs.update(do_sample=True, temperature=0.9, top_p=0.95)

        outputs = self._hf_pipe(wrapped_prompt, **gen_kwargs)
        text = _sanitize_text(outputs[0]["generated_text"])
        return TextGenResult(text=text)


def generate_text(
    prompt: str,
    use_ollama: bool = True,
    deterministic: bool = True,
    ollama_model: str = "qwen2:7b",
    ollama_host: str = "http://localhost:11434",
    max_new_tokens: int = 160,
    hf_model_name: str = "gpt2",
) -> str:
    """Convenience wrapper: build a TextGenerator and return the generated text."""
    generator = TextGenerator(
        use_ollama=use_ollama,
        ollama_model=ollama_model,
        ollama_host=ollama_host,
        max_new_tokens=max_new_tokens,
        hf_model_name=hf_model_name,
    )
    return generator.generate(prompt, deterministic=deterministic).text


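if __name__ == "__main__":
    # Minimal smoke test (assumes an Ollama server with the default qwen2:7b
    # model is already running on localhost:11434; adjust host/model as needed).
    print(generate_text("A quiet harbor at dawn; gulls circle a moored fishing boat."))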