import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from app.logger import get_logger

logger = get_logger(__name__)

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Move the model to GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate_story(caption: str, max_length: int = 256) -> str:
    logger.info("Generating story...")
    try:
        # Turn the caption into a story prompt using TinyLlama's chat format
        prompt = (
            "<|system|>\n"
            "You are a helpful assistant.</s>\n"
            "<|user|>\n"
            f"Write a complete, short story about {caption}. Make sure the story has a clear ending.\n</s>\n"
            "<|assistant|>\n"
        )

        # Tokenize and run through model
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,  # honour the caller-supplied cap instead of a hard-coded value
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            top_k=50,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        generated_story = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return generated_story.strip()
    except Exception as e:
        logger.exception("Failed to generate story: %s", e)
        raise
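

# A minimal usage sketch (illustrative only): the sample caption and the
# __main__ guard below are assumptions added for demonstration, not part of
# the original module.
if __name__ == "__main__":
    sample_caption = "a lighthouse keeper who befriends a stray cat"  # hypothetical caption
    story = generate_story(sample_caption, max_length=256)
    logger.info("Generated story:\n%s", story)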