# executor.py — CodeDebugger executor for the jina-code-debugger Space.
import threading
import os
from jina import Executor, requests
from docarray import BaseDoc, DocList
# transformers imports are done lazily in _ensure_model to prevent heavy import on module load
class CodeInput(BaseDoc):
    """Request schema: a single source-code snippet to debug."""
    code: str  # raw code text sent by the client
class CodeOutput(BaseDoc):
    """Response schema: the model's debugging output for one snippet."""
    result: str  # decoded model output (or a placeholder when loading is skipped)
class CodeDebugger(Executor):
    """
    Jina Executor that lazy-loads a Hugging Face seq2seq model on first request.

    Set the environment variable ``JINA_SKIP_MODEL_LOAD=1`` to skip model
    loading (useful in CI/builds); requests then return a placeholder message.
    ``MAX_NEW_TOKENS`` optionally overrides the generation length (default 256).
    """

    def __init__(self, model_name: str = "Girinath11/aiml_code_debug_model", **kwargs):
        """Store configuration only; the model is loaded lazily in `_ensure_model`.

        :param model_name: Hugging Face model id to load on first request.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        self._lock = threading.Lock()
        self.tokenizer = None
        self.model = None
        # Robustness fix: a malformed MAX_NEW_TOKENS previously crashed the
        # constructor with ValueError; fall back to the documented default.
        try:
            self.max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", "256"))
        except ValueError:
            self.max_new_tokens = 256

    def _ensure_model(self):
        """
        Load tokenizer & model once, thread-safely (double-checked locking).
        If JINA_SKIP_MODEL_LOAD is set to "1", skip loading (helpful for hub builds).
        """
        if os.environ.get("JINA_SKIP_MODEL_LOAD", "0") == "1":
            self.logger.warning("JINA_SKIP_MODEL_LOAD=1 set — skipping HF model load.")
            return
        if self.model is None or self.tokenizer is None:
            # Lazy import keeps module import cheap when the model is never used.
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

            with self._lock:
                # Re-check under the lock so only one thread loads the model.
                if self.model is None or self.tokenizer is None:
                    self.logger.info(f"Loading model {self.model_name} ...")
                    # If HF_TOKEN is set, transformers will use it automatically via huggingface-cli login
                    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                    self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
                    self.logger.info("Model loaded successfully.")

    @requests
    def debug(self, docs: DocList[CodeInput], **kwargs) -> DocList[CodeOutput]:
        """Run the debugger model over each input document.

        :param docs: documents carrying ``code`` snippets.
        :return: one ``CodeOutput`` per input, in order.
        """
        # Lazy load model at request time.
        self._ensure_model()
        if self.model is None or self.tokenizer is None:
            # Model loading was skipped — return a helpful message per doc.
            return DocList[CodeOutput](
                [CodeOutput(result="Model not loaded (JINA_SKIP_MODEL_LOAD=1).") for _ in docs]
            )
        import torch  # local import: transformers/pt tensors imply torch is present

        results = []
        for doc in docs:
            # Coerce to string defensively in case a non-str slipped through validation.
            code_text = doc.code if isinstance(doc.code, str) else str(doc.code)
            inputs = self.tokenizer(code_text, return_tensors="pt", padding=True, truncation=True)
            # Fix: disable autograd bookkeeping during generation — pure
            # inference previously tracked gradients, wasting memory.
            with torch.inference_mode():
                outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            results.append(CodeOutput(result=result))
        return DocList[CodeOutput](results)