# NOTE(review): removed non-code artifacts from a web-based file viewer
# (page status text, file size, commit hash, and a line-number gutter)
# that had been pasted into the source and made the file unparseable.
# app/ai_core.py
import warnings
from transformers import pipeline, AutoProcessor
from PIL import Image
import torch
# Ignore the specific warning about the slow processor, as it's informational.
warnings.filterwarnings(
    "ignore",
    message="Using a slow image processor as `use_fast` is unset*"
)

# Determine the device to use (GPU if available, otherwise CPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Load Models ---
# Models are loaded once when the script starts to improve performance
# and avoid reloading on every request.
# NOTE(review): loading happens at import time, so merely importing this
# module downloads/initializes both models as a side effect — confirm this
# is intended for every consumer of the module.
print("Loading image-to-text model (GIT)...")
# Processor is constructed explicitly and handed to the pipeline —
# presumably to pin the GIT-specific image preprocessing; verify.
git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
captioner = pipeline(
    "image-to-text",
    model="microsoft/git-large-coco",
    processor=git_processor,
    device=DEVICE
)
print("Image-to-text model loaded.")

print("Loading text-generation model (GPT-2)...")
text_generator = pipeline(
    "text-generation",
    model="gpt2",
    device=DEVICE
)
print("Text-generation model loaded.")
# --- Core Functions ---
def get_image_caption(image: Image.Image) -> str:
    """Produce a natural-language caption for *image*.

    Best-effort: any failure is logged and mapped to a fixed sentinel
    string, so callers detect failure by inspecting the return value
    rather than handling exceptions.
    """
    try:
        outputs = captioner(image)
        text = outputs[0]["generated_text"]
        print(f"Generated Caption: {text}")
        return text
    except Exception as err:
        # Swallow everything deliberately — downstream code checks for the
        # sentinel instead of catching exceptions.
        print(f"Error in get_image_caption: {err}")
        return "Could not generate image caption."
def generate_story_from_caption(caption: str) -> str:
    """Expand *caption* into a short fantasy story using the GPT-2 pipeline.

    Best-effort: on any failure a fixed sentinel string is returned so the
    caller never has to handle exceptions from generation.
    """
    prompt = f"A fantasy adventure set in a land that once was {caption}:"
    try:
        outputs = text_generator(
            prompt,
            max_new_tokens=150,
            num_return_sequences=1,
            temperature=0.9,
            repetition_penalty=1.2,
            truncation=True,
            pad_token_id=text_generator.tokenizer.eos_token_id,
        )
        story = outputs[0]["generated_text"]
        print("\n--- Generated Story ---")
        print(story)
        print("-----------------------\n")
        return story
    except Exception as err:
        # Deliberate catch-all: callers look for the sentinel, not exceptions.
        print(f"Error in generate_story_from_caption: {err}")
        return "Could not generate story."
# --- Direct Test Block ---
# This block runs only when the script is executed directly
# (e.g. `python app/ai_core.py`). It's useful for testing the AI logic
# without running the full API.
# FIX(review): removed a stray trailing "|" artifact after the final
# print call that made the line a syntax error.
if __name__ == "__main__":
    test_image_file = "test_image.jpeg"
    print(f"Starting test with '{test_image_file}'...")
    try:
        # Normalize to a 3-channel RGB image regardless of the source format.
        image_obj = Image.open(test_image_file).convert("RGB")
        image_caption = get_image_caption(image_obj)
        # get_image_caption signals failure via a sentinel string, not an exception.
        if "Could not generate" not in image_caption:
            generate_story_from_caption(image_caption)
        else:
            print("Story could not be generated because caption creation failed.")
    except FileNotFoundError:
        print(f"Error: Test file '{test_image_file}' not found. Please add an image to the project's root directory.")