File size: 5,218 Bytes
2ddd289
317198d
2ddd289
 
 
 
5b43647
2ddd289
 
 
 
5b43647
317198d
2ddd289
 
317198d
 
 
 
 
 
 
 
 
 
 
 
 
2ddd289
77df2b2
2ddd289
 
77df2b2
 
2ddd289
5b43647
77df2b2
5b43647
 
77df2b2
 
5b43647
 
 
 
 
317198d
5b43647
 
 
317198d
 
5b43647
 
 
317198d
 
 
 
5b43647
 
 
 
 
 
 
 
317198d
5b43647
 
317198d
 
 
 
 
 
 
 
 
 
 
 
5b43647
2ddd289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77df2b2
2ddd289
 
 
 
 
 
 
 
 
 
77df2b2
 
2ddd289
77df2b2
 
2ddd289
 
 
 
 
317198d
 
5b43647
 
2ddd289
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import functools
import json
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, list_repo_files
from transformers import BertTokenizerFast

from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption

# Constants: local paths used by the app and the Hub repo it mirrors.
MODELS_DIR = "./TeLVE/models"  # scanned for *.pth checkpoints
TOKENIZER_PATH = "./TeLVE/tokenizer"  # passed to BertTokenizerFast.from_pretrained
HF_REPO_ID = "outsu/TeLVE"  # Hugging Face repo whose models/ folder is synced
MODEL_STATE_FILE = "./TeLVE/model_state.json"  # JSON record of downloaded models
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_model_state(path=None):
    """Load the persisted model-download state from JSON.

    Args:
        path: JSON file to read; defaults to ``MODEL_STATE_FILE``.

    Returns:
        A dict that always contains a ``"downloaded_models"`` list. A missing,
        unreadable, or malformed state file yields ``{"downloaded_models": []}``
        instead of raising, so startup can continue with an empty state.
    """
    if path is None:
        path = MODEL_STATE_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            state = json.load(f)
    except FileNotFoundError:
        return {"downloaded_models": []}
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError;
        # a truncated/corrupt state file must not crash the app at startup.
        print(f"Warning: ignoring unreadable model state file {path}: {e}")
        return {"downloaded_models": []}

    if not isinstance(state, dict):
        return {"downloaded_models": []}
    # Callers index state["downloaded_models"] directly, so guarantee the key.
    if not isinstance(state.get("downloaded_models"), list):
        state["downloaded_models"] = []
    return state

def save_model_state(state, path=None):
    """Persist the model-download state as JSON.

    The data is first written to ``<path>.tmp`` and then moved over the target
    with ``os.replace``. An interrupted write therefore cannot leave a
    truncated, unparsable state file behind. If serialization fails, the
    temporary file is removed and the error is re-raised.

    Args:
        state: JSON-serializable dict (expected to hold "downloaded_models").
        path: Destination file; defaults to ``MODEL_STATE_FILE``.
    """
    if path is None:
        path = MODEL_STATE_FILE
    directory = os.path.dirname(path)
    if directory:  # os.makedirs("") raises for a bare file name
        os.makedirs(directory, exist_ok=True)

    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(state, f)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def list_available_models():
    """Return the ``.pth`` file names found directly in MODELS_DIR.

    Names are sorted in descending lexicographic order. A missing
    MODELS_DIR yields an empty list.
    """
    if os.path.exists(MODELS_DIR):
        checkpoints = filter(lambda name: name.endswith('.pth'), os.listdir(MODELS_DIR))
        return sorted(checkpoints, reverse=True)
    return []

def get_hf_model_list():
    """Return the ``.pth`` checkpoint names stored directly under ``models/`` in HF_REPO_ID.

    Only direct children of ``models/`` are listed. download_missing_models()
    fetches each name as ``models/<name>``, and list_available_models() only
    scans the top level of MODELS_DIR. A nested file such as
    ``models/old/a.pth`` would otherwise be reported as ``a.pth`` and then
    requested from a path that does not exist.

    Names are de-duplicated and sorted in descending lexicographic order.
    If the Hub query fails, the error is printed and an empty list is returned.
    """
    try:
        files = list_repo_files(HF_REPO_ID)
    except Exception as e:  # network/auth failures degrade to "nothing remote"
        print(f"Error fetching models from HuggingFace: {str(e)}")
        return []

    # rpartition splits "models/a.pth" into ("models", "/", "a.pth").
    models = {
        name
        for directory, _, name in (f.rpartition('/') for f in files)
        if directory == 'models' and name.endswith('.pth')
    }
    return sorted(models, reverse=True)

def download_missing_models():
    """Download Hugging Face checkpoints that are missing from MODELS_DIR.

    "Missing" is decided by the files actually on disk (list_available_models()),
    not by the JSON state. A checkpoint that was recorded as downloaded but
    later deleted is therefore fetched again, whether or not
    verify_model_integrity() ran first.

    Each successful download is added to the state's "downloaded_models" list,
    which is then saved. Individual download failures are printed and skipped,
    so one bad file does not abort the rest.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)

    state = load_model_state()
    downloaded_models = set(state.get("downloaded_models", []))
    local_models = set(list_available_models())
    hf_models = set(get_hf_model_list())

    # Sorted for a stable download order (same ordering as the model lists).
    models_to_download = sorted(hf_models - local_models, reverse=True)

    for model in models_to_download:
        print(f"Downloading missing model: {model}")
        try:
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=f"models/{model}",
                # The repo path "models/<name>" is kept under local_dir, so the
                # file lands in MODELS_DIR as long as its last component is "models".
                local_dir=os.path.dirname(MODELS_DIR),
                # NOTE: deprecated/ignored by recent huggingface_hub releases;
                # kept so older versions write real files instead of cache symlinks.
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading {model}: {str(e)}")
        else:
            downloaded_models.add(model)

    # Update state with newly downloaded models (sorted for a stable file).
    state["downloaded_models"] = sorted(downloaded_models)
    save_model_state(state)

def verify_model_integrity():
    """Prune the saved state so it only lists models whose files are on disk.

    Reads the JSON state and intersects its "downloaded_models" entries with
    the .pth files currently in MODELS_DIR. The result is written back.
    """
    on_disk = set(list_available_models())
    state = load_model_state()
    recorded = set(state["downloaded_models"])
    state["downloaded_models"] = list(recorded & on_disk)
    save_model_state(state)

@functools.lru_cache(maxsize=1)
def _load_cached_model(model_path, mtime):
    """Load a checkpoint once per (path, modification time) and reuse it.

    ``mtime`` only participates in the cache key: if the file is overwritten,
    the new timestamp forces a fresh load. maxsize=1 keeps at most one model
    cached, so switching models drops the cached reference to the previous one.
    """
    return load_model(model_path)


@functools.lru_cache(maxsize=1)
def _load_cached_tokenizer(tokenizer_path):
    """Load the tokenizer at ``tokenizer_path`` once and reuse it across requests."""
    return BertTokenizerFast.from_pretrained(tokenizer_path)


def generate_description(image, model_name):
    """Generate a caption for ``image`` using the checkpoint ``model_name``.

    Args:
        image: The value from the Gradio image input (a file path when the
            component uses ``type="filepath"``), or None if nothing was uploaded.
        model_name: File name of a ``.pth`` checkpoint inside MODELS_DIR.

    Returns:
        The generated caption. On failure, returns a string starting with "Error"
        instead of raising, because this is a UI callback.
    """
    if image is None:
        return "Error: Please upload an image."

    # model_name comes from the client. Accept only a bare .pth file name so it
    # cannot point outside MODELS_DIR; the checkpoint is deserialized by
    # load_model (presumably torch.load/pickle — treat paths as untrusted).
    if (
        not isinstance(model_name, str)
        or os.path.basename(model_name) != model_name
        or not model_name.endswith(".pth")
    ):
        return "Error: Invalid model selection."

    try:
        model_path = os.path.join(MODELS_DIR, model_name)
        if not os.path.exists(model_path):
            return "Error: Selected model file not found."

        if not os.path.exists(TOKENIZER_PATH):
            return "Error: Tokenizer not found. Please make sure you have trained a model first."

        # Reuse already-loaded weights/tokenizer instead of reloading per request.
        model = _load_cached_model(model_path, os.path.getmtime(model_path))
        tokenizer = _load_cached_tokenizer(TOKENIZER_PATH)

        return generate_caption(model, image, tokenizer)

    except Exception as e:
        return f"Error occurred: {str(e)}"

# Create Gradio interface
def create_interface():
    """Build the Gradio captioning UI.

    If MODELS_DIR contains no ``.pth`` checkpoints, returns a placeholder
    interface that only reports that error. Otherwise returns an interface with
    an image upload and a model dropdown, wired to generate_description. The
    dropdown defaults to the first entry of list_available_models().
    """
    available_models = list_available_models()
    title = "TeLVE - Turkish efficient Language Vision Engine 🧿"

    if not available_models:
        message = f"No models found in {MODELS_DIR} directory. Please train a model first."
        return gr.Interface(
            fn=lambda _image: message,
            inputs="image",
            outputs="text",
            title=title,
            description="Error: No models available"
        )

    default_model = available_models[0]
    # Only offer example images that actually exist. Checking just the folder
    # lets a missing file reach Gradio's example handling
    # (NOTE(review): may fail at construction depending on Gradio version).
    example_paths = ["./images/mugla.jpg", "./images/example.jpg"]
    examples = [[path, default_model] for path in example_paths if os.path.isfile(path)]

    interface = gr.Interface(
        fn=generate_description,
        inputs=[
            gr.Image(type="filepath", label="Upload Image"),
            gr.Dropdown(choices=available_models, label="Select Model", value=default_model)
        ],
        outputs=gr.Textbox(label="Generated Caption"),
        title=title,
        description="Upload an image to generate a Turkish description.",
        examples=examples or None
    )
    return interface

if __name__ == "__main__":
    # Startup: reconcile the recorded state with the files on disk, fetch
    # checkpoints missing locally, then serve the UI.
    startup_steps = (
        ("Verifying model integrity...", verify_model_integrity),
        ("Checking for missing models...", download_missing_models),
    )
    for banner, step in startup_steps:
        print(banner)
        step()
    demo = create_interface()
    # share=True requests a public Gradio link; 0.0.0.0 listens on all interfaces.
    demo.launch(share=True, server_name="0.0.0.0")