File size: 3,080 Bytes
5d3b777
9789b35
5d3b777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os

import requests
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI  # model server
from langchain.prompts import (
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
)
from langchain_groq import ChatGroq

from config import app_config
import mongo_utils as mongo
# Groq API key, read from the GROQ_API_KEY environment variable so the secret
# is never committed to source control (None when the variable is unset).
# NOTE(review): a key was previously hard-coded on this line; treat it as
# leaked and revoke/rotate it.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

def __image2text(image):
    """Generate a short description of the image.

    Posts ``image`` as the request body to ``app_config.I2T_API_URL`` with
    ``app_config.HF_TOKEN`` as the ``Authorization`` header and returns the
    ``"generated_text"`` field of the first element of the JSON response.

    Args:
        image: Request payload for the captioning endpoint (the caller in
            this file passes the raw bytes of an image file).

    Returns:
        The generated caption taken from the response.

    Raises:
        RuntimeError: If the request fails or the response is not the
            expected ``[{"generated_text": ...}]`` structure. The original
            exception is attached as ``__cause__``.
    """
    headers = {"Authorization": app_config.HF_TOKEN}
    try:
        # A timeout keeps a stalled endpoint from blocking the caller forever.
        response = requests.post(
            app_config.I2T_API_URL, headers=headers, data=image, timeout=60
        )
        return response.json()[0]["generated_text"]
    except (
        requests.RequestException,
        ValueError,
        KeyError,
        IndexError,
        TypeError,
    ) as e:
        # Previously the error was only printed and execution fell through to
        # `return response`, which either raised UnboundLocalError or handed
        # the raw Response object back as if it were a caption.
        print(e)
        raise RuntimeError("Failed to generate a description for the image") from e


def __text2story(image_desc, genre, style, word_count, creativity):
    """Generate a short story based on an image description text prompt.

    Builds a two-message chat prompt (system instructions plus the image
    description as story context) and runs it through a Groq-hosted chat
    model via an ``LLMChain``.

    Args:
        image_desc: Text inserted as the ``story-context`` of the human message.
        genre: Genre name substituted into the system prompt.
        style: Writing style substituted into the system prompt.
        word_count: Maximum story length in words, substituted into the
            system prompt.
        creativity: Sampling temperature for the chat model; ``None`` falls
            back to ``0.0``.

    Returns:
        The result of ``LLMChain.run`` for the filled-in prompt.
    """
    # Honour the caller's creativity setting; it was previously accepted but
    # ignored because the temperature was pinned to 0.0.
    temperature = 0.0 if creativity is None else creativity

    ## chat LLM model
    story_model = ChatGroq(
        model="llama3-8b-8192",
        temperature=temperature,
        api_key=GROQ_API_KEY,
    )

    ## chat message prompts
    sys_prompt = PromptTemplate(
        template="""You are an expert story writer, write a maximum of {word_count} 
        words long story in {genre} genre in {style} writing style, based on the user 
        provided story-context.
        """,
        input_variables=["word_count", "genre", "style"],
    )
    system_msg_prompt = SystemMessagePromptTemplate(prompt=sys_prompt)
    human_prompt = PromptTemplate(
        template="story-context: {context}", input_variables=["context"]
    )
    human_msg_prompt = HumanMessagePromptTemplate(prompt=human_prompt)
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_msg_prompt, human_msg_prompt]
    )

    ## LLM chain
    story_chain = LLMChain(llm=story_model, prompt=chat_prompt)
    response = story_chain.run(
        genre=genre, style=style, word_count=word_count, context=image_desc
    )
    return response


def generate_story(image_file, genre, style, word_count, creativity):
    """Produce a story for the image stored at ``image_file``.

    The file is read as bytes and captioned with ``__image2text``; the
    caption is printed between two separator lines and then turned into a
    story by ``__text2story``. Afterwards the access counter is incremented
    through ``mongo.increment_curr_access_count()`` and the count values are
    read from ``app_config``.

    Returns:
        A tuple ``(story, max_count, curr_count, available_count)`` where
        ``available_count`` is ``max_count - curr_count``.
    """
    # load the raw image bytes from disk
    with open(image_file, "rb") as image_handle:
        image_bytes = image_handle.read()

    # caption the image and echo the caption between separator lines
    caption = __image2text(image=image_bytes)
    separator = "++++++++++++++++++++++++++++++++++++++"
    print(separator)
    print(caption)
    print(separator)

    # expand the caption into a story
    story_args = {
        "image_desc": caption,
        "genre": genre,
        "style": style,
        "word_count": word_count,
        "creativity": creativity,
    }
    story = __text2story(**story_args)

    # bump the access counter, then report the usage figures
    mongo.increment_curr_access_count()
    max_count = app_config.openai_max_access_count
    curr_count = app_config.openai_curr_access_count
    return story, max_count, curr_count, max_count - curr_count