"""Image-to-report inference entry point.

Importing this module builds the visual encoder and the DistilGPT2 decoder,
restores their weights from a Hugging Face Hub checkpoint, and exposes
``generate_report`` for turning raw image bytes into a structured report.
"""
import io
import os

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

# Project-local modules, kept in their original relative order.
from CNN_encoder import CNN_Encoder
from distil_gpt2 import DistilGPT2
from configs import argHandler
from utils import load_image, split_report_sections
from tokenizer_wrapper import TokenizerWrapper
from api import API_call

# from src.models.cnn_encoder import
# from src.models.distil_gpt2 import DistilGPT2
# from src.configs import argHandler

FLAGS = argHandler()


def init_model():
    """Build the tokenizer, encoder, decoder and optimizer, then load weights.

    Populates the module globals ``tokenizer_wrapper``, ``encoder``,
    ``decoder`` and ``optimizer``. The checkpoint is fetched with
    ``hf_hub_download``; when the returned path exists, the encoder, decoder
    and optimizer state dicts are restored from it. Both models are left in
    evaluation mode because this module only performs inference.
    """
    global tokenizer_wrapper, encoder, decoder, optimizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("✅ Initializing model components...")

    FLAGS.setDefaults()

    tokenizer_wrapper = TokenizerWrapper(
        FLAGS.csv_label_columns[0],
        FLAGS.max_sequence_length,
        FLAGS.tokenizer_vocab_size,
    )

    encoder_model_dir = 'pretrained_visual_model'
    encoder = CNN_Encoder(
        encoder_model_dir,
        FLAGS.visual_model_name,
        FLAGS.visual_model_pop_layers,
        FLAGS.encoder_layers,
        FLAGS.tags_threshold,
        num_tags=len(FLAGS.tags),
    )
    decoder = DistilGPT2.from_pretrained('distilgpt2')
    # The optimizer is only needed so its state can be restored alongside the
    # model weights; it remains a module global for existing callers.
    optimizer = torch.optim.Adam(decoder.parameters(), lr=FLAGS.learning_rate)

    encoder.to(device)
    decoder.to(device)

    checkpoint_path = hf_hub_download(
        repo_id="TransformingBerry/CDGPT2_checkpoint",
        filename="checkpoint.pth",
    )

    if os.path.exists(checkpoint_path):
        print(f"✅ Restoring from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("⚠️ No checkpoint found. Starting from scratch.")

    # Inference only: disable dropout / batch-norm updates so repeated calls
    # on the same image give the same report.
    encoder.eval()
    decoder.eval()

    print("✅ Model initialized.")


init_model()


def generate_report(image_bytes):
    """Generate a structured report for one encoded image.

    Args:
        image_bytes: Raw bytes of an image file readable by ``PIL.Image.open``.

    Returns:
        Whatever ``split_report_sections`` returns for the API-structured
        report text (presumably the report split into named sections —
        confirm against ``utils.split_report_sections``).
    """
    image = Image.open(io.BytesIO(image_bytes))
    image_tensor = load_image(image)

    # The models were moved to a device in init_model(); inputs must follow,
    # otherwise CUDA hosts hit a CPU/GPU tensor mismatch.
    device = next(decoder.parameters()).device
    if isinstance(image_tensor, torch.Tensor):
        image_tensor = image_tensor.to(device)

    dec_input = torch.unsqueeze(
        torch.tensor(tokenizer_wrapper.GPT2_encode('startseq', pad=False)), 0
    ).to(device)

    # No gradients are needed for generation; skipping autograd saves memory.
    with torch.no_grad():
        visual_features, tags_embedding = encoder(image_tensor)

        generation_config = {
            "visual_features": visual_features,
            "tags_embedding": tags_embedding,
            "num_beams": 1,
            "max_length": FLAGS.max_sequence_length,
            "min_length": 3,
            "eos_token_ids": tokenizer_wrapper.GPT2_eos_token_id(),
            "pad_token_id": tokenizer_wrapper.GPT2_pad_token_id(),
            "do_sample": False,
            "early_stopping": True,
        }

        tokens = decoder.generate(dec_input, **generation_config)

    sentence = tokenizer_wrapper.GPT2_decode(tokens[0])
    sentence = tokenizer_wrapper.filter_special_words(sentence)
    print(sentence)

    # Call the API to structure the report
    structured_report = API_call(sentence)
    print(structured_report)

    structured_report = split_report_sections(structured_report)
    return structured_report