"""Gradio demo app for TeLVE (Turkish efficient Language Vision Engine)."""
import os
import json
import gradio as gr
import torch
from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
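# Note: ImageCaptioningModel is not called directly in this file; it is assumed the
# class must be importable so load_model can reconstruct checkpoints saved with it.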
from transformers import BertTokenizerFast
from huggingface_hub import hf_hub_download, list_repo_files
# Constants
MODELS_DIR = "./TeLVE/models"
TOKENIZER_PATH = "./TeLVE/tokenizer"
HF_REPO_ID = "outsu/TeLVE"
MODEL_STATE_FILE = "./TeLVE/model_state.json"
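
# Run inference on the GPU when available, otherwise fall back to the CPU.
# Device placement itself is assumed to be handled inside the imported TeLVE utilities.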
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model_state():
    """Load the model state from JSON file"""
    if os.path.exists(MODEL_STATE_FILE):
        with open(MODEL_STATE_FILE, 'r') as f:
            return json.load(f)
    return {"downloaded_models": []}

def save_model_state(state):
    """Save the model state to JSON file"""
    os.makedirs(os.path.dirname(MODEL_STATE_FILE), exist_ok=True)
    with open(MODEL_STATE_FILE, 'w') as f:
        json.dump(state, f)

def list_available_models():
    """List all .pth models in the models directory in reverse order"""
    if not os.path.exists(MODELS_DIR):
        return []
    models = [f for f in os.listdir(MODELS_DIR) if f.endswith('.pth')]
    return sorted(models, reverse=True)

def get_hf_model_list():
    """Get list of model files from HuggingFace repository in reverse order"""
    try:
        files = list_repo_files(HF_REPO_ID)
        models = [f.split('/')[-1] for f in files if f.startswith('models/') and f.endswith('.pth')]
        return sorted(models, reverse=True)
    except Exception as e:
        print(f"Error fetching models from HuggingFace: {str(e)}")
        return []

def download_missing_models():
    """Download missing models from HuggingFace with state management"""
    if not os.path.exists(MODELS_DIR):
        os.makedirs(MODELS_DIR)
    state = load_model_state()
    downloaded_models = set(state["downloaded_models"])
    local_models = set(list_available_models())
    hf_models = set(get_hf_model_list())
    # Check for models that need downloading: anything on the Hub that is neither
    # present locally nor recorded in the state file as already downloaded
    # (checkpoints downloaded once and later deleted are therefore not re-fetched).
    models_to_download = (hf_models - local_models) - downloaded_models
    for model in models_to_download:
        try:
            print(f"Downloading missing model: {model}")
            # The filename keeps its "models/" prefix, so downloading into the parent
            # directory ("./TeLVE") places the file inside MODELS_DIR.
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=f"models/{model}",
                local_dir=os.path.dirname(MODELS_DIR),
                local_dir_use_symlinks=False  # ignored by newer huggingface_hub releases
            )
            downloaded_models.add(model)
        except Exception as e:
            print(f"Error downloading {model}: {str(e)}")
            continue
    # Update state with newly downloaded models
    state["downloaded_models"] = list(downloaded_models)
    save_model_state(state)

def verify_model_integrity():
    """Verify that all models recorded in the state file actually exist on disk"""
    state = load_model_state()
    local_models = set(list_available_models())
    # Keep only the state entries whose .pth files are still present locally.
    state["downloaded_models"] = list(set(state["downloaded_models"]) & local_models)
    save_model_state(state)

def generate_description(image, model_name):
    """Generate image caption using selected model"""
    try:
        # Load model and tokenizer (note: both are reloaded on every request)
        model_path = os.path.join(MODELS_DIR, model_name)
        if not os.path.exists(model_path):
            return "Error: Selected model file not found."
        if not os.path.exists(TOKENIZER_PATH):
            return "Error: Tokenizer not found. Please make sure you have trained a model first."
        model = load_model(model_path)
        tokenizer = BertTokenizerFast.from_pretrained(TOKENIZER_PATH)
        # Generate caption
        caption = generate_caption(model, image, tokenizer)
        return caption
    except Exception as e:
        return f"Error occurred: {str(e)}"

# Create the Gradio interface
def create_interface():
    available_models = list_available_models()
    if not available_models:
        # Fallback interface shown when no .pth checkpoint is available.
        return gr.Interface(
            fn=lambda x: f"No models found in {MODELS_DIR}. Please train or download a model first.",
            inputs="image",
            outputs="text",
            title="TeLVE - Turkish efficient Language Vision Engine 🧿",
            description="Error: No models available"
        )
    interface = gr.Interface(
        fn=generate_description,
        inputs=[
            gr.Image(type="filepath", label="Upload Image"),
            gr.Dropdown(choices=available_models, label="Select Model", value=available_models[0])
        ],
        outputs=gr.Textbox(label="Generated Caption"),
        title="TeLVE - Turkish efficient Language Vision Engine 🧿",
        description="Upload an image to generate a Turkish description.",
        # Only list example rows whose image files actually exist.
        examples=[
            [path, available_models[0]]
            for path in ["./images/mugla.jpg", "./images/example.jpg"]
            if os.path.exists(path)
        ] or None
    )
    return interface

if __name__ == "__main__":
    print("Verifying model integrity...")
    verify_model_integrity()
    print("Checking for missing models...")
    download_missing_models()
    demo = create_interface()
    demo.launch(share=True, server_name="0.0.0.0")
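
# Typical local run (assuming the TeLVE package, tokenizer files under ./TeLVE/tokenizer,
# and at least one .pth checkpoint under ./TeLVE/models are in place):
#   python app.py
# Gradio binds to all interfaces (server_name="0.0.0.0") on its default port (7860)
# and also prints a temporary public share link because share=True.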