|
|
import os
|
|
|
import json
|
|
|
import gradio as gr
|
|
|
import torch
|
|
|
from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
|
|
|
from transformers import BertTokenizerFast
|
|
|
from huggingface_hub import hf_hub_download, list_repo_files
|
|
|
|
|
|
|
|
|
# Local filesystem locations. The "./" prefix makes them relative to the
# current working directory, not to this file.
MODELS_DIR = "./TeLVE/models"
TOKENIZER_PATH = "./TeLVE/tokenizer"
MODEL_STATE_FILE = "./TeLVE/model_state.json"

# HuggingFace Hub repository that publishes the checkpoints under "models/".
HF_REPO_ID = "outsu/TeLVE"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
|
def load_model_state(state_file=None):
    """Load the model-download state from its JSON file.

    Args:
        state_file: Path to read. Defaults to ``MODEL_STATE_FILE``.

    Returns:
        The stored dict, guaranteed to contain a ``"downloaded_models"`` list.
        A fresh ``{"downloaded_models": []}`` is returned when the file is
        missing, is not valid UTF-8 JSON, or does not hold a JSON object. A
        damaged state file therefore degrades to "nothing recorded" instead
        of crashing start-up.
    """
    path = MODEL_STATE_FILE if state_file is None else state_file
    try:
        with open(path, "r", encoding="utf-8") as f:
            state = json.load(f)
    except FileNotFoundError:
        return {"downloaded_models": []}
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"Warning: ignoring unreadable model state file {path}: {e}")
        return {"downloaded_models": []}

    if not isinstance(state, dict):
        print(f"Warning: ignoring malformed model state file {path}")
        return {"downloaded_models": []}
    # Callers index state["downloaded_models"] directly, so make sure it exists
    # and is a list even if the file was hand-edited.
    if not isinstance(state.get("downloaded_models"), list):
        state["downloaded_models"] = []
    return state
|
|
|
|
|
|
def save_model_state(state, state_file=None):
    """Persist *state* as JSON to the model state file.

    The JSON is first written to a sibling ``.tmp`` file and then moved over
    the destination with ``os.replace``. An interrupted or failed write
    therefore leaves the previous state file intact rather than truncated.

    Args:
        state: JSON-serializable mapping, e.g. ``{"downloaded_models": [...]}``.
        state_file: Destination path. Defaults to ``MODEL_STATE_FILE``.

    Raises:
        TypeError: If *state* is not JSON-serializable.
        OSError: If the directory or file cannot be written.
    """
    path = MODEL_STATE_FILE if state_file is None else state_file
    parent = os.path.dirname(path)
    if parent:  # a bare file name lives in the CWD; os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)

    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(state, f)
        os.replace(tmp_path, path)
    except Exception:
        # Don't leave a stray temp file behind; surface the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
|
|
|
|
|
def list_available_models(models_dir=None):
    """Return the ``.pth`` checkpoint file names found in the models directory.

    Args:
        models_dir: Directory to scan. Defaults to ``MODELS_DIR``.

    Returns:
        Bare file names (not paths) of regular files ending in ``.pth``,
        sorted in reverse lexicographic order. The first entry is what the UI
        pre-selects. An empty list is returned when the directory does not
        exist or is not a directory.
    """
    directory = MODELS_DIR if models_dir is None else models_dir
    try:
        entries = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []

    models = [
        name
        for name in entries
        # Skip sub-directories that merely happen to be named "*.pth".
        if name.endswith(".pth") and os.path.isfile(os.path.join(directory, name))
    ]
    return sorted(models, reverse=True)
|
|
|
|
|
|
def get_hf_model_list():
    """Return the ``.pth`` checkpoints published in ``models/`` of ``HF_REPO_ID``.

    Only files sitting directly in the repository's ``models/`` folder are
    returned. ``download_missing_models`` fetches each name as
    ``models/<name>``, so a file in a nested sub-folder would be requested at
    a path that does not exist.

    Returns:
        Bare file names sorted in reverse lexicographic order. An empty list
        is returned if the repository listing fails. The error is printed,
        not raised, so the app can still start with local models.
    """
    try:
        files = list_repo_files(HF_REPO_ID)
    except Exception as e:
        # Best-effort: network/auth failures must not block start-up.
        print(f"Error fetching models from HuggingFace: {str(e)}")
        return []

    prefix = "models/"
    models = [
        path[len(prefix):]
        for path in files
        if path.startswith(prefix)
        and path.endswith(".pth")
        and "/" not in path[len(prefix):]
    ]
    return sorted(models, reverse=True)
|
|
|
|
|
|
def download_missing_models():
    """Download every HuggingFace checkpoint that is not present locally.

    The ``.pth`` files in ``MODELS_DIR`` are compared with those listed under
    ``models/`` in ``HF_REPO_ID``, and the difference is fetched. The names
    of checkpoints that are on disk and were downloaded (now or earlier) are
    written back to the model state file in sorted order.

    A failure for an individual file is printed and skipped, so one bad
    download does not prevent the others.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)

    state = load_model_state()
    local_models = set(list_available_models())
    # Forget recorded downloads whose file is gone. Otherwise a deleted
    # checkpoint would count as "already downloaded" and never be fetched again.
    downloaded_models = {
        model for model in state.get("downloaded_models", []) if model in local_models
    }
    hf_models = set(get_hf_model_list())

    # hf_hub_download(filename="models/<name>", local_dir=<root>) writes to
    # <root>/models/<name>, so <root> must be the parent of MODELS_DIR. This
    # assumes MODELS_DIR's last path component is "models". normpath strips a
    # trailing slash that would otherwise make dirname() return MODELS_DIR itself.
    download_root = os.path.dirname(os.path.normpath(MODELS_DIR)) or "."

    for model in sorted(hf_models - local_models):
        try:
            print(f"Downloading missing model: {model}")
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=f"models/{model}",
                local_dir=download_root,
                # Deprecated in recent huggingface_hub releases; kept so older
                # versions copy real files instead of symlinking into the cache.
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading {model}: {str(e)}")
            continue
        downloaded_models.add(model)

    # Sorted so the state file content is stable across runs.
    state["downloaded_models"] = sorted(downloaded_models)
    save_model_state(state)
|
|
|
|
|
|
def verify_model_integrity():
    """Prune model-state entries whose checkpoint file no longer exists locally.

    This function only removes entries. Local ``.pth`` files that are not
    recorded in the state are left unrecorded. Surviving entries keep their
    original order, with duplicates dropped, so the state file is rewritten
    deterministically instead of in hash-random set order. The state file is
    always written, and is created if it did not exist.
    """
    state = load_model_state()
    local_models = set(list_available_models())
    recorded = state.get("downloaded_models", [])
    # dict.fromkeys de-duplicates while preserving first-seen order.
    state["downloaded_models"] = [
        model for model in dict.fromkeys(recorded) if model in local_models
    ]
    save_model_state(state)
|
|
|
|
|
|
# Most recently used checkpoint and its tokenizer, keyed on
# (model_path, model file mtime). Only one entry is kept, so switching models
# in the UI does not accumulate checkpoints in memory. A rewritten checkpoint
# file (new mtime) triggers a reload of both model and tokenizer.
_CAPTION_CACHE = {}


def _get_model_and_tokenizer(model_path):
    """Return ``(model, tokenizer)`` for *model_path*, reusing the cached pair if unchanged."""
    key = (model_path, os.path.getmtime(model_path))
    if _CAPTION_CACHE.get("key") != key:
        # Drop the previous model before loading the next to limit peak memory.
        _CAPTION_CACHE.clear()
        model = load_model(model_path)
        tokenizer = BertTokenizerFast.from_pretrained(TOKENIZER_PATH)
        _CAPTION_CACHE.update(key=key, model=model, tokenizer=tokenizer)
    return _CAPTION_CACHE["model"], _CAPTION_CACHE["tokenizer"]


def generate_description(image, model_name):
    """Generate an image caption using the selected checkpoint.

    Args:
        image: Path of the uploaded image (the UI's ``gr.Image`` uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        model_name: Bare ``.pth`` file name inside ``MODELS_DIR``.

    Returns:
        The generated caption, or a human-readable string starting with
        ``"Error"``. Errors are returned rather than raised so they show up
        in the output textbox.
    """
    if image is None:
        return "Error: Please upload an image."

    # The value arrives from the client, and the app may be publicly shared.
    # Accept only a bare "*.pth" file name so a crafted value such as
    # "../../x.pth" cannot make load_model read a file outside MODELS_DIR.
    # NOTE(review): load_model presumably deserializes the checkpoint (e.g.
    # via torch.load/pickle), so checkpoint files must be trusted.
    if (
        not isinstance(model_name, str)
        or not model_name.endswith(".pth")
        or os.path.basename(model_name) != model_name
    ):
        return "Error: Invalid model name."

    try:
        model_path = os.path.join(MODELS_DIR, model_name)
        if not os.path.isfile(model_path):
            return "Error: Selected model file not found."

        if not os.path.exists(TOKENIZER_PATH):
            return "Error: Tokenizer not found. Please make sure you have trained a model first."

        # Reuse the already-loaded model instead of reading it from disk per request.
        model, tokenizer = _get_model_and_tokenizer(model_path)
        return generate_caption(model, image, tokenizer)

    except Exception as e:
        return f"Error occurred: {str(e)}"
|
|
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio interface for captioning with the local checkpoints.

    If ``MODELS_DIR`` holds no ``.pth`` files, a placeholder interface is
    returned that only reports the problem.

    Returns:
        A ``gr.Interface``. Its model dropdown lists the local checkpoints and
        pre-selects the first one returned by ``list_available_models``.
    """
    available_models = list_available_models()
    title = "TeLVE - Turkish efficient Language Vision Engine 🧿"

    if not available_models:
        # Name the directory that was actually scanned, not a hard-coded one.
        message = f"No models found in {MODELS_DIR} directory. Please train a model first."
        return gr.Interface(
            fn=lambda x: message,
            inputs="image",
            outputs="text",
            title=title,
            description="Error: No models available",
        )

    default_model = available_models[0]

    # Offer only examples whose image file exists. A missing example file
    # would otherwise surface as an error when Gradio processes the examples.
    example_images = ["./images/mugla.jpg", "./images/example.jpg"]
    examples = [[path, default_model] for path in example_images if os.path.isfile(path)]

    return gr.Interface(
        fn=generate_description,
        inputs=[
            gr.Image(type="filepath", label="Upload Image"),
            gr.Dropdown(choices=available_models, label="Select Model", value=default_model),
        ],
        outputs=gr.Textbox(label="Generated Caption"),
        title=title,
        description="Upload an image to generate a Turkish description.",
        examples=examples or None,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Start-up sequence: first reconcile the recorded state with the files on
    # disk, then fetch anything missing, then build and serve the UI.
    startup_steps = (
        ("Verifying model integrity...", verify_model_integrity),
        ("Checking for missing models...", download_missing_models),
    )
    for banner, step in startup_steps:
        print(banner)
        step()

    # share=True requests a public Gradio link, and server_name="0.0.0.0"
    # listens on all network interfaces.
    demo = create_interface()
    demo.launch(share=True, server_name="0.0.0.0")
|
|
|
|