# NOTE: the following header is residue from the file-viewer page this was
# copied from (author, commit message, commit hash, viewer links, file size).
# Kept as a comment so the module remains valid Python:
#   Ashhar — "fixed regex for image prompt matching" — commit 1675db5
#   raw · history · blame · 13 kB
import streamlit as st
import os
import time
import json
import re
from typing import List, Literal, TypedDict, Tuple
from transformers import AutoTokenizer
from gradio_client import Client
import constants as C
import utils as U
from openai import OpenAI
import anthropic
from groq import Groq
from dotenv import load_dotenv
load_dotenv()
# Supported LLM backend identifiers; selected at startup via MODEL_TYPE env var.
ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]

# Per-backend bundle: API client, model id, context-window size, and a
# locally loaded tokenizer used only for counting tokens (not for inference).
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer
})

# Defaults to CLAUDE when MODEL_TYPE is unset or empty.
modelType: ModelType = os.environ.get("MODEL_TYPE") or "CLAUDE"

MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        # Community-published tokenizers approximating each provider's own
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    },
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-sonnet-20240229",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
    },
    "LLAMA": {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
    }
}

# Active backend, used by predict() and the token-budget helpers below.
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
# Anthropic's API takes the system prompt as a separate argument, so the
# request-building code branches on this flag.
isClaudeModel = modelType == "CLAUDE"
def __countTokens(text):
    """Return the token count of *text* under the active backend's tokenizer."""
    return len(tokenizer.encode(str(text), add_special_tokens=False))
# Configure the Streamlit page; must run before any other st.* rendering call.
st.set_page_config(
    page_title="Kommuneity Story Creator",
    page_icon=C.AI_ICON,
    # menu_items={"About": None}
)
# Blank log lines to visually separate consecutive app reruns.
U.pprint("\n")
U.pprint("\n")
def __isInvalidResponse(response: str) -> bool:
    """Heuristically detect a malformed or degenerate LLM response.

    Returns True when the (possibly partial) streamed text matches a known
    failure pattern, False otherwise.  The original fell off the end and
    implicitly returned None for the valid case; callers only test
    truthiness, so the explicit ``return False`` is backward-compatible.
    """
    # Newline immediately followed by a lowercase char: broken mid-sentence wraps
    if len(re.findall(r'\n[a-z]', response)) > 3:
        return True
    # Same word repeated 3+ times in a row: degenerate looping output
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        return True
    # Implausibly many paragraphs
    if len(re.findall(r'\n\n', response)) > 20:
        return True
    # predict() yields this sentinel when the LLM API raised
    if C.EXCEPTION_KEYWORD in response:
        return True
    # JSON payload present without the expected separator marker
    if ('{\n "options"' in response) and (C.JSON_SEPARATOR not in response):
        return True
    if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response):
        return True
    # Only the options JSON, with no narrative text before it
    if response.startswith(C.JSON_SEPARATOR):
        return True
    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
    """Count how many of *keywords* occur in *text* (plain substring match)."""
    # bool is an int subclass, so summing the membership tests counts the hits
    return sum(keyword in text for keyword in keywords)
def __getRawImagePromptDetails(prompt: str, response: str) -> Tuple[str, str, str]:
    """Decide whether *response* warrants an image and with what prompts.

    Returns a (enhancePrompt, imagePrompt, loaderText) triple, or
    (None, None, None) when no image should be generated.
    """
    # Lowercase and strip everything but alphanumerics / a few separators,
    # and drop the word "the", to make keyword matching robust.
    regex = r'[^a-z0-9 \n\.\-\:\/]|((the) +)'
    cleanedResponse = re.sub(regex, '', response.lower())
    U.pprint(f"{cleanedResponse=}")
    cleanedPrompt = re.sub(regex, '', prompt.lower())
    U.pprint(f"{cleanedPrompt=}")

    # Case 1: the model just presented story choices and the user picked one
    storyJustChosen = (
        __matchingKeywordsCount(
            ["adapt", "personal branding", "purpose", "use case"],
            cleanedResponse
        ) > 2
        and "story so far" not in cleanedResponse
    )
    if storyJustChosen:
        return (
            f"Extract the name of selected story from this text and add few more details about this story:\n{response}",
            "Effect: dramatic, bokeh",
            "Painting your character ...",
        )

    # Case 2: the wrap-up message containing the booking link
    if __matchingKeywordsCount([C.BOOKING_LINK], cleanedResponse) > 0:
        # Include the previous chat message for fuller story context
        relevantResponse = f"""
        {st.session_state.chatHistory[-1].get("content")}
        {response}
        """
        return (
            f"Extract the story from this text:\n{relevantResponse}",
            """
            Style: In a storybook, surreal
            """,
            "Imagining your scene (beta) ...",
        )

    return (None, None, None)
def __getImagePromptDetails(prompt: str, response: str):
    """Build the final image-generation prompt and loader text for a response.

    Returns (imagePrompt, loaderText); both are None when the response does
    not warrant an image.  The raw prompt is enhanced into a detailed
    image-generation prompt via a second LLM call.
    """
    (enhancePrompt, imagePrompt, loaderText) = __getRawImagePromptDetails(prompt, response)
    if imagePrompt or enhancePrompt:
        U.pprint(f"[Raw] {enhancePrompt=} | {imagePrompt=}")
        # Prompt enhancement always uses the LLAMA backend, regardless of
        # the chat backend selected at startup.
        promptEnhanceModelType: ModelType = "LLAMA"
        U.pprint(f"{promptEnhanceModelType=}")
        modelConfig = MODEL_CONFIG[promptEnhanceModelType]
        # NOTE: these locals deliberately shadow the module-level names
        client = modelConfig["client"]
        model = modelConfig["model"]
        isClaudeModel = promptEnhanceModelType == "CLAUDE"
        systemPrompt = "You help in creating prompts for image generation"
        # When an enhancePrompt exists, it becomes a pre-step instruction;
        # otherwise the imagePrompt text is used directly.
        promptPrefix = f"{enhancePrompt}\nAnd then use the above to" if enhancePrompt else "Use the text below to"
        llmArgs = {
            "model": model,
            "messages": [{
                "role": "user",
                "content": f"""
                {promptPrefix} create a prompt for image generation (limit to less than 500 words)
                {imagePrompt}
                Return only the final Image Generation Prompt, and nothing else
                """
            }],
            "temperature": 1,
            "max_tokens": 2000
        }
        if isClaudeModel:
            # Anthropic: system prompt is a separate top-level argument
            llmArgs["system"] = systemPrompt
            response = client.messages.create(**llmArgs)
            imagePrompt = response.content[0].text
        else:
            # OpenAI/Groq: system prompt is prepended as the first message
            llmArgs["messages"] = [
                {"role": "system", "content": systemPrompt},
                *llmArgs["messages"]
            ]
            response = client.chat.completions.create(**llmArgs)
            responseMessage = response.choices[0].message
            imagePrompt = responseMessage.content
        U.pprint(f"[Enhanced] {imagePrompt=}")
    return (imagePrompt, loaderText)
def __getMessages():
    """Return the chat messages, evicting the oldest until they fit MAX_CONTEXT.

    The budget counts the system prompt plus all messages, with a small
    fixed margin for message-framing overhead.
    """
    def getContextSize():
        # +100 is headroom for per-message framing tokens
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(st.session_state.messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    while getContextSize() > MAX_CONTEXT:
        # Fix: bail out when there is nothing left to evict — the original
        # called pop(0) unconditionally, which would raise IndexError (and
        # loop forever) if the system prompt alone exceeded the budget.
        if not st.session_state.messages:
            U.pprint("Context size exceeded but no messages left to remove")
            break
        U.pprint("Context size exceeded, removing first message")
        st.session_state.messages.pop(0)
    return st.session_state.messages
def __logLlmRequest(messagesFormatted: list):
    """Log the token count of an outgoing LLM request alongside the model id."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {MODEL}")
    # U.pprint(f"{messagesFormatted=}")
def predict():
    """Stream the assistant's reply for the current conversation.

    Generator yielding text deltas.  On any API failure it yields
    C.EXCEPTION_KEYWORD so the caller can detect the error via
    __isInvalidResponse() instead of crashing the Streamlit script.
    """
    messagesFormatted = []
    try:
        if isClaudeModel:
            # Anthropic: system prompt passed separately, not as a message
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)
            with client.messages.stream(
                model=MODEL,
                messages=messagesFormatted,
                temperature=0.8,
                system=C.SYSTEM_MSG,
                max_tokens=4000,
            ) as stream:
                for text in stream.text_stream:
                    yield text
        else:
            # OpenAI/Groq: system prompt is the first chat message
            messagesFormatted.append(
                {"role": "system", "content": C.SYSTEM_MSG}
            )
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)
            response = client.chat.completions.create(
                model=MODEL,
                messages=messagesFormatted,
                temperature=1,
                max_tokens=4000,
                stream=True
            )
            for chunk in response:
                choices = chunk.choices
                if not choices:
                    # Some providers send keep-alive chunks with no choices
                    U.pprint("Empty chunk")
                    continue
                chunkContent = chunk.choices[0].delta.content
                if chunkContent:
                    yield chunkContent
    except Exception as e:
        # Surface the failure as a sentinel token rather than raising
        U.pprint(f"LLM API Error: {e}")
        yield C.EXCEPTION_KEYWORD
def __generateImage(prompt: str):
    """Generate an image for *prompt* via the FLUX.1-schnell Hugging Face Space.

    Returns the Space's raw result (the caller unpacks it as
    (image_path, seed) — presumably the /infer endpoint's output tuple).
    """
    generationArgs = {
        "prompt": prompt,
        "seed": 0,
        "randomize_seed": True,
        "width": 1024,
        "height": 768,
        "num_inference_steps": 4,
        "api_name": "/infer",
    }
    result = Client("black-forest-labs/FLUX.1-schnell").predict(**generationArgs)
    U.pprint(f"imageResult={result}")
    return result
# Apply shared CSS tweaks and render the page heading.
U.applyCommonStyles()
st.title("Kommuneity Story Creator 🪄")
def __resetButtonState():
    """Clear the last clicked option-button value from session state."""
    st.session_state.buttonValue = ""
def __resetSelectedStory():
    """Clear the story selected on the popular-stories page."""
    st.session_state.selectedStory = {}
def __setStartMsg(msg):
    """Set (or clear, with "") the start-button message in session state."""
    st.session_state.startMsg = msg
# --- Session-state bootstrap (first run of this browser session only) ---
if "chatHistory" not in st.session_state:
    st.session_state.chatHistory = []  # display history; entries may carry an image
if "messages" not in st.session_state:
    st.session_state.messages = []  # raw LLM conversation sent to the API
if "buttonValue" not in st.session_state:
    __resetButtonState()
if "selectedStory" not in st.session_state:
    __resetSelectedStory()
if "storyChosen" not in st.session_state:
    st.session_state.storyChosen = False
if "startMsg" not in st.session_state:
    __setStartMsg("")

# Kick-off button: stores C.START_MSG, which is then treated as a user prompt
st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))

# Re-render the full chat history on every Streamlit rerun
for chat in st.session_state.chatHistory:
    role = chat["role"]
    content = chat["content"]
    imagePath = chat.get("image")
    avatar = C.AI_ICON if role == "assistant" else C.USER_ICON
    with st.chat_message(role, avatar=avatar):
        st.markdown(content)
        if imagePath:
            st.image(imagePath)

# U.pprint(f"{st.session_state.buttonValue=}")
# U.pprint(f"{st.session_state.selectedStory=}")
# U.pprint(f"{st.session_state.startMsg=}")
# --- Main interaction ---
# A new prompt can come from the chat box, a clicked option button, a story
# selected on the stories page, or the start button.
if prompt := (
    st.chat_input()
    or st.session_state["buttonValue"]
    or st.session_state["selectedStory"].get("title")
    or st.session_state["startMsg"]
):
    # Consume the one-shot prompt sources so a rerun doesn't replay them
    __resetButtonState()
    __resetSelectedStory()
    __setStartMsg("")

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
    U.pprint(f"{prompt=}")
    st.session_state.chatHistory.append({"role": "user", "content": prompt })
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            """Stream the LLM reply into the UI; return it, or None if invalid."""
            response = ""
            responseContainer.image(C.TEXT_LOADER)
            responseGenerator = predict()
            for chunk in responseGenerator:
                response += chunk
                if __isInvalidResponse(response):
                    U.pprint(f"InvalidResponse={response}")
                    return
                # Hide the trailing JSON payload from the user while streaming
                if C.JSON_SEPARATOR not in response:
                    responseContainer.markdown(response)
            return response

        response = __printAndGetResponse()
        # Retry (with a short pause) until a valid, non-empty response arrives
        while not response:
            U.pprint("Empty response. Retrying..")
            time.sleep(0.7)
            response = __printAndGetResponse()
        U.pprint(f"{response=}")

        def selectButton(optionLabel):
            """Option-button click handler: stash the label as the next prompt."""
            st.session_state["buttonValue"] = optionLabel
            U.pprint(f"Selected: {optionLabel}")

        # Split the visible text from the trailing JSON payload, if present
        rawResponse = response
        responseParts = response.split(C.JSON_SEPARATOR)
        jsonStr = None
        if len(responseParts) > 1:
            [response, jsonStr] = responseParts

        imagePath = None
        imageContainer = st.empty()
        try:
            (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
            if imagePrompt:
                imgContainer = imageContainer.container()
                imgContainer.write(
                    f"""
                    <div class='blinking code'>
                    {loaderText}
                    </div>
                    """,
                    unsafe_allow_html=True
                )
                # imgContainer.markdown(f"`{loaderText}`")
                imgContainer.image(C.IMAGE_LOADER)
                (imagePath, seed) = __generateImage(imagePrompt)
                imageContainer.image(imagePath)
        except Exception as e:
            # Image generation is best-effort; never break the chat over it
            U.pprint(e)
            imageContainer.empty()

        # Display history keeps the cleaned text (and image); the raw
        # response (with JSON tail) goes back into the LLM context
        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
        })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            try:
                # Fix: parse once — the original called json.loads(jsonStr)
                # twice and discarded the first result
                jsonObj = json.loads(jsonStr)
                options = jsonObj.get("options")
                action = jsonObj.get("action")
                if options:
                    for option in options:
                        # Bind label via a default arg to avoid the
                        # late-binding closure pitfall inside the loop
                        st.button(
                            option["label"],
                            key=option["id"],
                            on_click=lambda label=option["label"]: selectButton(label)
                        )
                elif action:
                    U.pprint(f"{action=}")
                    if action == "SHOW_STORY_DATABASE":
                        st.switch_page("pages/popular-stories.py")
                # st.code(jsonStr, language="json")
            except Exception as e:
                U.pprint(e)