| | |
| | import os |
| | |
# Process-wide environment defaults, applied before the heavy ML imports below.
# ``setdefault`` leaves any value that is already present in the environment
# untouched, so an operator can still override either setting from outside.
for _env_name, _env_default in (
    ("CUDA_VISIBLE_DEVICES", ""),  # empty list of visible devices (CPU-only app)
    ("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1"),
):
    os.environ.setdefault(_env_name, _env_default)
| |
|
| | import sys |
| | import traceback |
| | import json |
| | from typing import List, Optional |
| |
|
| | import requests |
| | import torch |
| | import torch.nn.functional as F |
| | from datasets import load_dataset |
| | from transformers import ( |
| | AutoProcessor, |
| | AutoModel, |
| | AutoTokenizer, |
| | AutoModelForCausalLM, |
| | ) |
| | from PIL import Image |
| | import gradio as gr |
| | from tqdm import tqdm |
| |
|
| | |
| | |
| | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Checkpoint used to embed both the corpus texts and the query image.
SIGLIP_MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
# Checkpoint used for answer generation (locally or through the HF router).
LLAVA_MODEL_ID = "liuhaotian/llava-v1.6-vicuna-7b"
# "{}" is filled with the 1-based dataset index at startup.
DATASET_TEMPLATE = "EYEDOL/AGRILLAVA-image-text{}"
# Number of datasets to load, i.e. indices 1..NUM_DATASETS.
NUM_DATASETS = 1
# Number of texts encoded per SigLip forward pass.
BATCH_SIZE = 16
# Default number of retrieved texts (also the slider's initial value).
TOP_K_DEFAULT = 3

# Remote generation endpoint and its credential; the token is None when the
# HUGGINGFACE_TOKEN environment variable is not set.
HF_API_URL = "https://router.huggingface.co/hf-inference"
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")

# Everything in this app runs on the CPU.
device = torch.device("cpu")
print("Running on device:", device)
| |
|
| | |
| | |
| | |
# ---------------------------------------------------------------------------
# Startup: load the text corpus and pre-compute its SigLip text embeddings
# ---------------------------------------------------------------------------
print("Loading datasets and computing SigLip text embeddings (startup)...")

# Flat list holding the "text" column of every dataset's train split, in
# dataset order. Row i of ``text_embeds_all`` corresponds to ``texts_all[i]``.
texts_all: List[str] = []
for dataset_index in range(1, NUM_DATASETS + 1):
    train_split = load_dataset(DATASET_TEMPLATE.format(dataset_index), split="train")
    texts_all.extend(train_split["text"])

siglip_processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
siglip_model = AutoModel.from_pretrained(SIGLIP_MODEL_ID).to(device)
siglip_model.eval()

# Encode the corpus in BATCH_SIZE chunks; each chunk is L2-normalised and
# stored on the CPU so the chunks can be concatenated afterwards.
# NOTE(review): the processor is called with padding=True; confirm this
# matches the padding the fine-tuned checkpoint was trained with.
text_embeds_parts = []
with torch.no_grad():
    for start in tqdm(range(0, len(texts_all), BATCH_SIZE), desc="Encoding texts (CPU)"):
        encoded = siglip_processor(
            text=texts_all[start : start + BATCH_SIZE],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        features = siglip_model.get_text_features(**encoded)
        features = features / features.norm(p=2, dim=-1, keepdim=True)
        text_embeds_parts.append(features.cpu())
        # Drop the per-batch tensors right away instead of holding them until
        # the next iteration rebinds the names.
        del encoded, features

# With no texts there is nothing to concatenate; fall back to an empty tensor.
text_embeds_all = (
    torch.cat(text_embeds_parts, dim=0) if text_embeds_parts else torch.empty((0, 0))
)
print(f"Encoded {len(texts_all)} texts. Embeddings shape: {text_embeds_all.shape}")
| |
|
| | |
| | |
| | |
# ---------------------------------------------------------------------------
# Llava loading: try each backend in order and remember which one worked
# ---------------------------------------------------------------------------
# Non-ASCII characters in the log messages below are written as escapes
# ("\u2705" = check mark, "\u2014" = em dash) so the source stays ASCII-only
# and cannot be corrupted by a wrong file encoding.
llava_tokenizer: Optional[AutoTokenizer] = None
llava_model = None
# One of "local", "trust_remote_code", "hf_api", or None when nothing is usable.
llava_mode: Optional[str] = None
# (stage name, formatted traceback) for every loading attempt that failed.
load_errors = []

# Attempt 1: model class provided by an installed ``llava`` package.
try:
    # NOTE(review): confirm the installed package really exports this name; if
    # it does not, the ImportError is caught below and the fallback runs.
    from llava.model import LlavaForCausalLM

    print("Loading LlavaForCausalLM from installed 'llava' package (CPU)...")
    llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
    llava_model = LlavaForCausalLM.from_pretrained(
        LLAVA_MODEL_ID,
        device_map={"": "cpu"},
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    llava_model.to(device)
    llava_model.eval()
    llava_mode = "local"
    print("\u2705 Llava loaded from installed package.")
except Exception:
    # Deliberately broad: any failure here (missing package, download error,
    # out of memory, ...) should fall through to the next backend.
    tb_local = traceback.format_exc()
    load_errors.append(("local_llava_import", tb_local))
    print("Local llava import failed \u2014 will try trust_remote_code fallback. See logs for details.")

# Attempt 2: let transformers load the model with the repository's own code.
if llava_mode is None:
    try:
        print("Attempting AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) (CPU)...")
        llava_tokenizer = AutoTokenizer.from_pretrained(LLAVA_MODEL_ID, use_fast=False)
        llava_model = AutoModelForCausalLM.from_pretrained(
            LLAVA_MODEL_ID,
            trust_remote_code=True,
            device_map={"": "cpu"},
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        llava_model.to(device)
        llava_model.eval()
        llava_mode = "trust_remote_code"
        print("\u2705 Llava loaded via trust_remote_code fallback.")
    except Exception:
        tb_trust = traceback.format_exc()
        load_errors.append(("fallback_trust_remote_code", tb_trust))
        print("trust_remote_code fallback failed \u2014 will try HF router if token provided.")

# Attempt 3: no local model; generate remotely when a token is configured.
if llava_mode is None and HUGGINGFACE_TOKEN:
    llava_mode = "hf_api"
    print("No usable local model found. Will use Hugging Face router Inference API for generation (HUGGINGFACE_TOKEN detected).")

# Nothing worked: dump every collected traceback so the startup log explains why.
if llava_mode is None:
    print("WARNING: No Llava model available and no HUGGINGFACE_TOKEN supplied. Generation will return an actionable error.")
    for name, tb in load_errors:
        print(f"--- {name} traceback ---\n{tb}")
| |
|
| | |
| | |
| | |
def call_hf_inference_api(prompt: str, max_new_tokens: int = 256, temperature: float = 0.0):
    """Generate text for *prompt* through the Hugging Face router endpoint.

    Args:
        prompt: Full prompt text sent as the request's ``inputs``.
        max_new_tokens: Generation length limit forwarded to the backend.
        temperature: Sampling temperature. It is only forwarded when it is
            strictly positive; text-generation backends validate this field as
            strictly positive, so the default of ``0.0`` is expressed by
            omitting it and leaving the backend on its default decoding.

    Returns:
        The ``generated_text`` field of the response (taken from the first
        item when the response is a list), the response itself when it is a
        plain string, or otherwise the whole response re-serialised as JSON.

    Raises:
        RuntimeError: If ``HUGGINGFACE_TOKEN`` is not set, the endpoint answers
            with a non-200 status, or the response body is not valid JSON.
    """
    if not HUGGINGFACE_TOKEN:
        raise RuntimeError("HUGGINGFACE_TOKEN not set; cannot call Hugging Face Inference API.")

    parameters = {"max_new_tokens": max_new_tokens}
    if temperature is not None and temperature > 0:
        parameters["temperature"] = temperature

    headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}", "Content-Type": "application/json"}
    payload = {
        "model": LLAVA_MODEL_ID,
        "inputs": prompt,
        "parameters": parameters,
        "options": {"wait_for_model": True},
    }
    resp = requests.post(HF_API_URL, headers=headers, json=payload, timeout=300)
    if resp.status_code != 200:
        raise RuntimeError(f"HF Inference API error {resp.status_code}: {resp.text}")

    try:
        data = resp.json()
    except ValueError as exc:
        # Report a malformed body the same way as every other API failure.
        raise RuntimeError(f"HF Inference API returned a non-JSON body: {resp.text[:500]}") from exc

    # The answer may be a list of result objects or a single result object.
    candidate = data[0] if isinstance(data, list) and data else data
    if isinstance(candidate, dict) and "generated_text" in candidate:
        return candidate["generated_text"]
    if isinstance(data, str):
        return data
    return json.dumps(data)
| |
|
| | |
| | |
| | |
def retrieve_top_k_texts(image: Image.Image, k: int = TOP_K_DEFAULT):
    """Return the indexed texts most similar to *image*.

    The image is embedded with the SigLip model and compared, by cosine
    similarity, against the text embeddings pre-computed at startup.

    Args:
        image: Query image.
        k: Number of results requested. It is clamped to the range
            ``0..number of indexed texts`` so that a request larger than the
            corpus returns every text instead of raising.

    Returns:
        A list of ``(text, score)`` tuples ordered from most to least similar,
        where ``score`` is a Python float. The list is empty when nothing has
        been indexed or when ``k`` is not positive.
    """
    # Nothing indexed: skip the forward pass and avoid comparing against the
    # placeholder empty embedding matrix.
    if text_embeds_all.numel() == 0:
        return []

    inputs = siglip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        img_embed = siglip_model.get_image_features(**inputs)
        img_embed = img_embed / img_embed.norm(p=2, dim=-1, keepdim=True)

    # The single image row is broadcast against every text row, giving one
    # similarity score per indexed text.
    sims = F.cosine_similarity(img_embed.cpu(), text_embeds_all)

    # torch.topk raises when k exceeds the number of candidates.
    top_count = max(0, min(int(k), sims.numel()))
    topk = torch.topk(sims, top_count)
    return [
        (texts_all[index], score)
        for index, score in zip(topk.indices.tolist(), topk.values.tolist())
    ]
| |
|
def llava_answer(image: Image.Image, retrieved_texts, question: str, max_tokens: int = 256):
    """Answer *question* using the retrieved texts as context.

    Args:
        image: Accepted for interface compatibility; this function builds a
            text-only prompt and does not read the image.
        retrieved_texts: Iterable of ``(text, score)`` pairs; only the text
            part is placed in the prompt.
        question: The user's question.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded output of the local model, the result of the remote
        inference call, or an explanatory message when no backend is
        available (selected by the module-level ``llava_mode``).
    """
    context_text = "\n".join(f"Retrieved Text: {text}" for text, _score in retrieved_texts)
    prompt = "\n\n".join(
        [
            "You are an agricultural assistant. Use the provided retrieved texts to answer concisely.",
            f"Retrieved texts:\n{context_text}",
            f"User question: {question}",
            "Provide a concise, actionable answer and crop suggestions when applicable.",
        ]
    )

    # Remote backend: delegate generation to the inference API helper.
    if llava_mode == "hf_api":
        return call_hf_inference_api(prompt, max_new_tokens=max_tokens)

    # No backend at all: return instructions instead of raising.
    if llava_mode not in ("local", "trust_remote_code"):
        return "\n".join(
            [
                "No Llava model is available for generation.",
                "",
                "Fix options:",
                "1) Install the LLaVA repo in requirements.txt and rebuild the Space:",
                " git+https://github.com/haotian-liu/LLaVA.git@main",
                "2) Or add a valid Hugging Face API token as HUGGINGFACE_TOKEN in Space secrets to use the router.",
                "",
                "Check Space logs for detailed tracebacks printed at startup.",
            ]
        )

    # Local backend: tokenize, move the tensors to the target device, generate
    # and decode the first returned sequence.
    encoded = llava_tokenizer(prompt, return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        output_ids = llava_model.generate(**model_inputs, max_new_tokens=max_tokens)
    return llava_tokenizer.decode(output_ids[0], skip_special_tokens=True)
| |
|
| | |
| | |
| | |
def gradio_pipeline(image: Image.Image, question: str, k: int = TOP_K_DEFAULT):
    """Run retrieval and generation for one UI request.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        question: The user's question.
        k: Number of texts to retrieve; converted with ``int`` before use.

    Returns:
        ``(image, answer_text)``. When the image or the question is missing
        the result is ``(None, <prompt to provide both>)``. When generation
        raises, the answer text carries the error and its traceback; errors
        raised by the retrieval step are not caught here.
    """
    inputs_present = image is not None and bool(question)
    if not inputs_present:
        return None, "Please provide both an image and a question."

    top_k = int(k)
    retrieved = retrieve_top_k_texts(image, k=top_k)
    try:
        return image, llava_answer(image, retrieved, question)
    except Exception as exc:
        details = traceback.format_exc()
        return image, f"Error during generation: {exc}\n\nTraceback:\n{details}"
| |
|
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# "\u2192" is a right arrow, written as an escape so the source stays
# ASCII-only and the title cannot be corrupted by a wrong file encoding.
# NOTE(review): the original glyph was lost to an encoding error; the arrow is
# the inferred intent. Confirm.
# NOTE(review): which widgets sit inside the Row was reconstructed (the two
# image components side by side). Confirm against the intended layout.
with gr.Blocks(title="Agri Image + Question \u2192 Llava Response (robust)") as demo:
    gr.Markdown(
        "## Agri Image QA\n\nThis app preloads SigLip embeddings at startup. "
        "Generation uses a local Llava model if available, otherwise the Hugging Face router Inference API "
        "(requires HUGGINGFACE_TOKEN secret in Space settings)."
    )
    with gr.Row():
        img_in = gr.Image(type="pil")
        out_img = gr.Image(type="pil", label="Image")
    question_input = gr.Textbox(label="Question about the image", lines=2)
    k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=TOP_K_DEFAULT, label="Top-k retrieval")
    txt_out = gr.Textbox(label="Llava Response", lines=12)
    run_btn = gr.Button("Generate Answer")
    # One click runs retrieval and generation, then fills both outputs.
    run_btn.click(fn=gradio_pipeline, inputs=[img_in, question_input, k_slider], outputs=[out_img, txt_out])

if __name__ == "__main__":
    # Listen on all interfaces; no public share link is requested.
    demo.launch(server_name="0.0.0.0", share=False)
| |
|