"""Kommuneity Story Creator page.

A Streamlit chat page that streams replies from the configured LLM, turns the
JSON that may follow ``C.JSON_SEPARATOR`` in a reply into option buttons or a
page switch, and for some turns generates an illustrative image.
"""

import streamlit as st
from streamlit.delta_generator import DeltaGenerator
import os
import time
import json
import re
from typing import List, Literal, TypedDict, Tuple

from transformers import AutoTokenizer
from gradio_client import Client
from openai import OpenAI
import anthropic
from groq import Groq

import constants as C
import utils as U
from helpers.auth import runWithAuth
from helpers.sidebar import showSidebar
from helpers.activities import saveLatestActivity
from helpers.imageCdn import getCdnUrl

from dotenv import load_dotenv

load_dotenv()

ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer
})

# How many times one prompt is sent to the LLM when the reply is empty, is
# flagged by __isInvalidResponse, or the API call failed.
MAX_RESPONSE_ATTEMPTS = 3
RETRY_DELAY_SECONDS = 0.7


@st.cache_resource(show_spinner=False)
def __loadModelConfig() -> dict[ModelType, ModelConfig]:
    """Build the API client, model id, context size and tokenizer for every model.

    Streamlit re-executes this script on every interaction. Caching the result
    avoids re-creating three API clients and re-loading three tokenizers on
    each rerun.
    """
    return {
        "GPT4": {
            "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
            "model": "gpt-4o-mini",
            "max_context": 128000,
            "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
        },
        "CLAUDE": {
            "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
            "model": "claude-3-5-sonnet-20240620",
            "max_context": 128000,
            "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
        },
        "LLAMA": {
            "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
            "model": "llama-3.1-70b-versatile",
            "max_context": 128000,
            "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"),
        },
    }


modelType: ModelType = os.environ.get("MODEL_TYPE") or "CLAUDE"
MODEL_CONFIG: dict[ModelType, ModelConfig] = __loadModelConfig()
if modelType not in MODEL_CONFIG:
    raise ValueError(
        f"Unsupported MODEL_TYPE={modelType!r}; expected one of {list(MODEL_CONFIG)}"
    )

client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"


def __countTokens(text) -> int:
    """Return the token count of ``str(text)`` under the active model's tokenizer.

    Non-string input (e.g. a list of message dicts) is stringified first, so the
    count also includes the container syntax and is only an approximation.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))


st.set_page_config(
    page_title="Kommuneity Story Creator",
    page_icon=C.AI_ICON,
)


def __isInvalidResponse(response: str) -> bool:
    """Return True if a (partially) streamed reply should be discarded and retried.

    A reply is rejected when it:
    - looks degenerate (many lowercase line starts, repeated words, too many
      paragraphs);
    - carries the ``C.EXCEPTION_KEYWORD`` marker emitted by ``__predict``;
    - contains a JSON options/action payload without ``C.JSON_SEPARATOR``;
    - starts with the separator.

    Side effect: if the API complained that roles must alternate, the
    second-to-last context message is dropped so that the retry can succeed.
    """
    if len(re.findall(r'\n((?!http)[a-z])', response)) > 3 and "```" not in response:
        U.pprint("new line followed by small case char")
        return True

    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True

    if len(re.findall(r'\n\n', response)) > 30:
        U.pprint("lots of paragraphs")
        return True

    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        messages = st.session_state.messages
        if (
            'roles must alternate between "user" and "assistant"' in str(response)
            and len(messages) >= 2
        ):
            U.pprint("Removing last msg from context...")
            messages.pop(-2)
        return True

    if ('{\n "options"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True

    if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True

    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True

    return False


def __matchingKeywordsCount(keywords: List[str], text: str) -> int:
    """Return how many of ``keywords`` occur as substrings of ``text``."""
    return sum(1 for keyword in keywords if keyword in text)


def __getRawImagePromptDetails(
    prompt: str, response: str
) -> Tuple[str | None, str | None, str | None]:
    """Decide whether this turn gets an image, and return the raw prompt inputs.

    Returns ``(enhancePrompt, imagePrompt, loaderText)``:
    - if a story is selected, the details are built from that story;
    - if the reply contains ``C.BOOKING_LINK`` and the user's prompt does not
      mention the storytelling coach, they are built from the previous assistant
      reply (if any) plus the current one;
    - otherwise ``(None, None, None)``.
    """
    regex = r'[^a-z0-9 \n\.\-\:\/]|((the) +)'
    cleanedResponse = re.sub(regex, '', response.lower())
    U.pprint(f"{cleanedResponse=}")
    cleanedPrompt = re.sub(regex, '', prompt.lower())

    if st.session_state.selectedStory:
        imageText = st.session_state.selectedStory
        return (
            f"Extract the story from this text and add few more details about this story:\n{imageText}",
            "Effect: dramatic, bokeh",
            "Painting your story character ...",
        )

    if (
        __matchingKeywordsCount([C.BOOKING_LINK], cleanedResponse) > 0
        and "storytelling coach" not in cleanedPrompt
    ):
        aiResponses = [
            chat.get("content")
            for chat in st.session_state.chatHistory
            if chat.get("role") == "assistant"
        ]
        # The current reply is not in chatHistory yet. On the very first reply
        # there is no previous assistant message to add as context.
        previousResponse = aiResponses[-1] if aiResponses else ""
        relevantResponse = f"""
{previousResponse}
{response}
"""
        return (
            f"Extract the story from this text:\n{relevantResponse}",
            """
Style: In a storybook, surreal
""",
            "Imagining your story scene ...",
        )

    return (None, None, None)


def __getImagePromptDetails(prompt: str, response: str) -> Tuple[str | None, str | None]:
    """Return ``(imagePrompt, loaderText)`` for this turn, or ``(None, None)``.

    The raw details from ``__getRawImagePromptDetails`` are rewritten into one
    image-generation prompt by the LLAMA model. API errors propagate to the caller.
    """
    (enhancePrompt, imagePrompt, loaderText) = __getRawImagePromptDetails(prompt, response)
    if not (imagePrompt or enhancePrompt):
        return (imagePrompt, loaderText)

    promptEnhanceModelType: ModelType = "LLAMA"
    U.pprint(f"{promptEnhanceModelType=}")
    modelConfig = MODEL_CONFIG[promptEnhanceModelType]
    enhanceClient = modelConfig["client"]
    useClaudeApi = promptEnhanceModelType == "CLAUDE"

    systemPrompt = "You help in creating prompts for image generation"
    promptPrefix = (
        f"{enhancePrompt}\nAnd then use the above to"
        if enhancePrompt
        else "Use the text below to"
    )
    enhancePrompt = f"""
{promptPrefix} create a prompt for image generation.
{imagePrompt}

Return only the final Image Generation Prompt, and nothing else
"""
    U.pprint(f"[Raw] {enhancePrompt=}")

    llmArgs = {
        "model": modelConfig["model"],
        "messages": [{"role": "user", "content": enhancePrompt}],
        "temperature": 1,
        "max_tokens": 2000,
    }
    if useClaudeApi:
        # Anthropic takes the system prompt as a separate argument.
        llmArgs["system"] = systemPrompt
        completion = enhanceClient.messages.create(**llmArgs)
        imagePrompt = completion.content[0].text
    else:
        llmArgs["messages"] = [
            {"role": "system", "content": systemPrompt},
            *llmArgs["messages"],
        ]
        completion = enhanceClient.chat.completions.create(**llmArgs)
        imagePrompt = completion.choices[0].message.content

    U.pprint(f"[Enhanced] {imagePrompt=}")
    return (imagePrompt, loaderText)


def __getMessages():
    """Trim ``st.session_state.messages`` in place to fit ``MAX_CONTEXT``, and return it.

    The oldest messages are dropped first. Any leading non-user messages are
    then dropped as well, because Anthropic's Messages API requires the first
    message to have the user role.
    """
    messages = st.session_state.messages

    def getContextSize():
        # NOTE(review): the +100 looks like headroom for formatting overhead; confirm.
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    while messages and getContextSize() > MAX_CONTEXT:
        U.pprint("Context size exceeded, removing first message")
        messages.pop(0)

    while messages and messages[0].get("role") != "user":
        U.pprint("Removing leading non-user message")
        messages.pop(0)

    return messages


def __logLlmRequest(messagesFormatted: list):
    """Log the approximate token size of the outgoing request and the model name."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {MODEL}")


def __predict():
    """Stream the active model's reply to the current conversation as text chunks.

    Any exception from the API is logged and turned into one final chunk that
    starts with ``C.EXCEPTION_KEYWORD``, so the caller can detect it and retry.
    """
    messagesFormatted = []
    try:
        if isClaudeModel:
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)
            with client.messages.stream(
                model=MODEL,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.9,
                max_tokens=4000,
            ) as stream:
                yield from stream.text_stream
        else:
            messagesFormatted.append({"role": "system", "content": C.SYSTEM_MSG})
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)
            response = client.chat.completions.create(
                model=MODEL,
                messages=messagesFormatted,
                temperature=1,
                max_tokens=4000,
                stream=True,
            )
            for chunk in response:
                if not chunk.choices:
                    U.pprint("Empty chunk")
                    continue
                chunkContent = chunk.choices[0].delta.content
                if chunkContent:
                    yield chunkContent
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        yield f"{C.EXCEPTION_KEYWORD} | {e}"


def __generateImage(prompt: str):
    """Generate an image for ``prompt`` via the FLUX.1-schnell Gradio Space.

    Returns the ``/infer`` endpoint's raw result. The caller unpacks it as
    ``(imagePath, seed)``.
    """
    fluxClient = Client(
        "black-forest-labs/FLUX.1-schnell",
        os.environ.get("HF_FLUX_CLIENT_TOKEN"),
    )
    result = fluxClient.predict(
        prompt=prompt,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=768,
        num_inference_steps=4,
        api_name="/infer",
    )
    U.pprint(f"imageResult={result}")
    return result


def __paintImageIfApplicable(
    imageContainer: DeltaGenerator,
    prompt: str,
    response: str,
):
    """Show a loader, generate an image for this turn if one applies, and display it.

    Returns the generated image path, or None if no image applies or generation
    failed. This step is best-effort: failures are logged and the container is
    cleared.
    """
    imagePath = None
    try:
        (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
        if imagePrompt:
            imgContainer = imageContainer.container()
            imgContainer.write(
                f"""
{loaderText}
""",
                unsafe_allow_html=True,
            )
            imgContainer.image(C.IMAGE_LOADER)
            (imagePath, _seed) = __generateImage(imagePrompt)
            imageContainer.image(imagePath)
    except Exception as e:
        U.pprint(e)
        imageContainer.empty()

    return imagePath


def __selectButton(optionLabel: str):
    """Button callback: store the clicked label so the next rerun uses it as the prompt."""
    st.session_state["buttonValue"] = optionLabel
    U.pprint(f"Selected: {optionLabel}")


def __showButtons(options: list):
    """Render one button per option dict; each needs "label" and "id" keys."""
    for option in options:
        st.button(
            option["label"],
            key=option["id"],
            on_click=__selectButton,
            args=(option["label"],),
        )


def __resetButtonState():
    st.session_state.buttonValue = ""


def __resetButtons():
    st.session_state.buttons = []


def __resetSelectedStory():
    st.session_state.selectedStory = {}


def __setStartMsg(msg):
    st.session_state.startMsg = msg


# Per-session defaults. Each value is initialised only once per browser session.
if "ipAddress" not in st.session_state:
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")

if "chatHistory" not in st.session_state:
    st.session_state.chatHistory = []

if "messages" not in st.session_state:
    st.session_state.messages = []

if "buttonValue" not in st.session_state:
    __resetButtonState()

if "selectedStory" not in st.session_state:
    __resetSelectedStory()

if "selectedStoryTitle" not in st.session_state:
    st.session_state.selectedStoryTitle = ""

if "isStoryChosen" not in st.session_state:
    st.session_state.isStoryChosen = False

if "buttons" not in st.session_state:
    st.session_state.buttons = []

if "activityId" not in st.session_state:
    st.session_state.activityId = None

if "userActivitiesLog" not in st.session_state:
    st.session_state.userActivitiesLog = []

U.pprint("\n")
U.pprint("\n")

U.applyCommonStyles()
st.title("Kommuneity Story Creator 🪄")


def __renderChatHistory():
    """Redraw every stored chat turn: its text, an optional image, and pending buttons."""
    for chat in st.session_state.chatHistory:
        role = chat["role"]
        imagePath = chat.get("image")
        buttons = chat.get("buttons")
        avatar = C.AI_ICON if role == "assistant" else C.USER_ICON

        with st.chat_message(role, avatar=avatar):
            st.markdown(chat["content"])
            if imagePath and U.isValidImageUrl(imagePath):
                st.image(imagePath)
            if buttons:
                __showButtons(buttons)
                # Clear them so they are not rendered again on later reruns.
                chat["buttons"] = []


def __streamResponse(responseContainer: DeltaGenerator) -> str | None:
    """Stream one LLM reply into ``responseContainer``.

    Returns the full raw reply (possibly ""), or None if it was flagged invalid.
    """
    response = ""
    responseContainer.image(C.TEXT_LOADER)
    responseGenerator = __predict()
    try:
        for chunk in responseGenerator:
            response += chunk
            if __isInvalidResponse(response):
                U.pprint(f"InvalidResponse={response}")
                return None
            # The JSON tail after the separator is never shown to the user.
            if C.JSON_SEPARATOR not in response:
                responseContainer.markdown(response)
    finally:
        # Run __predict's cleanup (e.g. exiting the Anthropic stream context)
        # right away when we stop reading early.
        responseGenerator.close()

    if response:
        # Text that arrived in the same chunk as the separator was not rendered
        # by the loop above, so render the visible part once more.
        responseContainer.markdown(response.split(C.JSON_SEPARATOR, 1)[0])
    return response


def __getResponseWithRetries(responseContainer: DeltaGenerator) -> str | None:
    """Stream a reply, retrying up to ``MAX_RESPONSE_ATTEMPTS`` times.

    Returns the reply, or None if every attempt was empty or invalid. The bound
    keeps a persistent API failure (e.g. an invalid key) from hanging the page
    in an endless retry loop.
    """
    for attempt in range(1, MAX_RESPONSE_ATTEMPTS + 1):
        response = __streamResponse(responseContainer)
        if response:
            return response
        if attempt < MAX_RESPONSE_ATTEMPTS:
            U.pprint("Empty response. Retrying..")
            time.sleep(RETRY_DELAY_SECONDS)
    return None


def __handleJsonPayload(jsonStr: str):
    """Act on the JSON after ``C.JSON_SEPARATOR``: render option buttons or run an action.

    This step is best-effort: malformed JSON or options are logged and ignored.
    """
    try:
        jsonObj = json.loads(jsonStr)
        options = jsonObj.get("options")
        action = jsonObj.get("action")
        if options:
            __showButtons(options)
            st.session_state.buttons = options
        elif action:
            U.pprint(f"{action=}")
            if action == "SHOW_STORY_DATABASE":
                time.sleep(0.5)
                st.switch_page("pages/popular-stories.py")
    except Exception as e:
        U.pprint(e)


def __runChatTurn(prompt: str):
    """Handle one new prompt.

    Shows the prompt, streams the reply, adds an image and buttons, then saves
    the activity.
    """
    __resetButtonState()
    __resetButtons()
    __setStartMsg("")
    if st.session_state["selectedStoryTitle"] != prompt:
        __resetSelectedStory()
    # Clear the title so the same story is not re-submitted as the prompt on the next rerun.
    st.session_state.selectedStoryTitle = ""

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)

    U.pprint(f"{prompt=}")
    userHistoryEntry = {"role": "user", "content": prompt}
    userMessage = {"role": "user", "content": prompt}
    st.session_state.chatHistory.append(userHistoryEntry)
    st.session_state.messages.append(userMessage)

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()
        response = __getResponseWithRetries(responseContainer)

        if not response:
            U.pprint("No valid response after retries. Giving up.")
            # Roll back the unanswered user turn. Otherwise the next prompt would
            # add two consecutive user messages to the LLM context.
            for history, entry in (
                (st.session_state.messages, userMessage),
                (st.session_state.chatHistory, userHistoryEntry),
            ):
                if history and history[-1] is entry:
                    history.pop()
            responseContainer.error("Sorry, I couldn't generate a response. Please try again.")
            return

        U.pprint(f"{response=}")
        rawResponse = response
        # maxsplit=1 so a repeated separator cannot break the split into text + JSON.
        responseParts = response.split(C.JSON_SEPARATOR, 1)
        response = responseParts[0]
        jsonStr = responseParts[1] if len(responseParts) > 1 else None

        imageContainer = st.empty()
        imagePath = __paintImageIfApplicable(imageContainer, prompt, response)
        if imagePath:
            imagePath = getCdnUrl(imagePath)

        # The chat history shows only the visible text. The LLM context keeps the
        # raw reply, including the JSON tail.
        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
        })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            __handleJsonPayload(jsonStr)

    saveLatestActivity()


def mainApp():
    """Render the chat page and, if there is a new prompt, run one chat turn.

    The prompt is taken from the first non-empty of: the chat input, a clicked
    option button, ``selectedStoryTitle``, or the start message.
    """
    if "startMsg" not in st.session_state:
        __setStartMsg("")
        st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))

    __renderChatHistory()

    if prompt := (
        st.chat_input()
        or st.session_state["buttonValue"]
        or st.session_state["selectedStoryTitle"]
        or st.session_state["startMsg"]
    ):
        __runChatTurn(prompt)


runWithAuth(mainApp)
showSidebar()