import os
import json
import gradio as gr
import torch
from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
from transformers import BertTokenizerFast
from huggingface_hub import hf_hub_download, list_repo_files

# Constants
MODELS_DIR = "./TeLVE/models"
TOKENIZER_PATH = "./TeLVE/tokenizer"
HF_REPO_ID = "outsu/TeLVE"
MODEL_STATE_FILE = "./TeLVE/model_state.json"
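
# The state file records which checkpoints have already been fetched from the
# Hub, e.g. {"downloaded_models": ["TeLVE_v1.pth"]} (the filename here is
# illustrative; actual names depend on the repository contents).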
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model_state():
    """Load the model state from JSON file"""
    if os.path.exists(MODEL_STATE_FILE):
        with open(MODEL_STATE_FILE, 'r') as f:
            return json.load(f)
    return {"downloaded_models": []}
def save_model_state(state):
    """Save the model state to JSON file"""
    os.makedirs(os.path.dirname(MODEL_STATE_FILE), exist_ok=True)
    with open(MODEL_STATE_FILE, 'w') as f:
        json.dump(state, f)


def list_available_models():
    """List all .pth models in the models directory in reverse alphabetical order"""
    if not os.path.exists(MODELS_DIR):
        return []
    models = [f for f in os.listdir(MODELS_DIR) if f.endswith('.pth')]
    return sorted(models, reverse=True)


def get_hf_model_list():
    """Get the list of model files from the HuggingFace repository in reverse alphabetical order"""
    try:
        files = list_repo_files(HF_REPO_ID)
        models = [f.split('/')[-1] for f in files if f.startswith('models/') and f.endswith('.pth')]
        return sorted(models, reverse=True)
    except Exception as e:
        print(f"Error fetching models from HuggingFace: {str(e)}")
        return []


def download_missing_models():
    """Download missing models from HuggingFace with state management"""
    if not os.path.exists(MODELS_DIR):
        os.makedirs(MODELS_DIR)
    
    state = load_model_state()
    downloaded_models = set(state["downloaded_models"])
    local_models = set(list_available_models())
    hf_models = set(get_hf_model_list())
    
    # Models present on the Hub but neither on disk nor already recorded as downloaded
    models_to_download = (hf_models - local_models) - downloaded_models
    
    for model in models_to_download:
        try:
            print(f"Downloading missing model: {model}")
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=f"models/{model}",
                local_dir=os.path.dirname(MODELS_DIR),
                local_dir_use_symlinks=False
            )
            downloaded_models.add(model)
        except Exception as e:
            print(f"Error downloading {model}: {str(e)}")
            continue
    
    # Update state with newly downloaded models
    state["downloaded_models"] = list(downloaded_models)
    save_model_state(state)


def verify_model_integrity():
    """Drop state entries whose model files no longer exist on disk"""
    state = load_model_state()
    local_models = set(list_available_models())
    state["downloaded_models"] = list(set(state["downloaded_models"]) & local_models)
    save_model_state(state)


def generate_description(image, model_name):
    """Generate image caption using selected model"""
    try:
        # Load model and tokenizer
        model_path = os.path.join(MODELS_DIR, model_name)
        if not os.path.exists(model_path):
            return "Error: Selected model file not found."
        
        if not os.path.exists(TOKENIZER_PATH):
            return "Error: Tokenizer not found. Please make sure you have trained a model first."
        model = load_model(model_path)
        tokenizer = BertTokenizerFast.from_pretrained(TOKENIZER_PATH)
        # Generate caption
        caption = generate_caption(model, image, tokenizer)
        return caption
    except Exception as e:
        return f"Error occurred: {str(e)}"
# Create Gradio interface
def create_interface():
    available_models = list_available_models()
    if not available_models:
        return gr.Interface(
            fn=lambda x: "No models found in ./TeLVE/models directory. Please train a model first.",
            inputs="image",
            outputs="text",
            title="TeLVE - Turkish efficient Language Vision Engine 🧿",
            description="Error: No models available"
        )
    interface = gr.Interface(
        fn=generate_description,
        inputs=[
            gr.Image(type="filepath", label="Upload Image"),
            gr.Dropdown(choices=available_models, label="Select Model", value=available_models[0])
        ],
        outputs=gr.Textbox(label="Generated Caption"),
        title="TeLVE - Turkish efficient Language Vision Engine 🧿",
        description="Upload an image to generate a Turkish description.",
        examples=[
            ["./images/mugla.jpg", available_models[0]],
            ["./images/example.jpg", available_models[0]]
        ] if os.path.exists("./images") else None
    )
    return interface


if __name__ == "__main__":
    print("Verifying model integrity...")
    verify_model_integrity()
    print("Checking for missing models...")
    download_missing_models()
    demo = create_interface()
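    # share=True requests a temporary public gradio.live URL; server_name
    # "0.0.0.0" binds all network interfaces so the UI is reachable from
    # other machines on the network.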
    demo.launch(share=True, server_name="0.0.0.0")