import streamlit as st from streamlit.delta_generator import DeltaGenerator import os import time import json import re from typing import List, Literal, TypedDict, Tuple from transformers import AutoTokenizer from gradio_client import Client from openai import OpenAI import anthropic from groq import Groq import constants as C import utils as U from helpers.auth import runWithAuth from helpers.sidebar import showSidebar from helpers.activities import saveLatestActivity from helpers.imageCdn import getCdnUrl from dotenv import load_dotenv load_dotenv() ModelType = Literal["GPT4", "CLAUDE", "LLAMA"] ModelConfig = TypedDict("ModelConfig", { "client": OpenAI | Groq | anthropic.Anthropic, "model": str, "max_context": int, "tokenizer": AutoTokenizer }) modelType: ModelType = os.environ.get("MODEL_TYPE") or "CLAUDE" MODEL_CONFIG: dict[ModelType, ModelConfig] = { "GPT4": { "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")), "model": "gpt-4o-mini", "max_context": 128000, "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o") }, "CLAUDE": { "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")), "model": "claude-3-5-sonnet-20240620", "max_context": 128000, "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer") }, "LLAMA": { "client": Groq(api_key=os.environ.get("GROQ_API_KEY")), "model": "llama-3.1-70b-versatile", "max_context": 128000, "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer") } } client = MODEL_CONFIG[modelType]["client"] MODEL = MODEL_CONFIG[modelType]["model"] MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"] tokenizer = MODEL_CONFIG[modelType]["tokenizer"] isClaudeModel = modelType == "CLAUDE" def __countTokens(text): text = str(text) tokens = tokenizer.encode(text, add_special_tokens=False) return len(tokens) st.set_page_config( page_title="Kommuneity Story Creator", page_icon=C.AI_ICON, # menu_items={"About": None} ) def __isInvalidResponse(response: str): if len(re.findall(r'\n((?!http)[a-z])', response)) > 3 and "```" not in response: U.pprint("new line followed by small case char") return True if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1: U.pprint("lot of consecutive repeating words") return True if len(re.findall(r'\n\n', response)) > 30: U.pprint("lots of paragraphs") return True if C.EXCEPTION_KEYWORD in response: U.pprint("LLM API threw exception") if 'roles must alternate between "user" and "assistant"' in str(response): U.pprint("Removing last msg from context...") st.session_state.messages.pop(-2) return True if ('{\n "options"' in response) and (C.JSON_SEPARATOR not in response): U.pprint("JSON response without json separator") return True if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response): U.pprint("JSON response without json separator") return True if response.startswith(C.JSON_SEPARATOR): U.pprint("only options with no text") return True def __matchingKeywordsCount(keywords: List[str], text: str): return sum([ 1 if keyword in text else 0 for keyword in keywords ]) def __getRawImagePromptDetails(prompt: str, response: str) -> Tuple[str, str, str]: regex = r'[^a-z0-9 \n\.\-\:\/]|((the) +)' cleanedResponse = re.sub(regex, '', response.lower()) U.pprint(f"{cleanedResponse=}") cleanedPrompt = re.sub(regex, '', prompt.lower()) if (st.session_state.selectedStory): imageText = st.session_state.selectedStory return ( f"Extract the story from this text and add few more details about this story:\n{imageText}", "Effect: dramatic, bokeh", "Painting your story character ...", ) if ( __matchingKeywordsCount( [C.BOOKING_LINK], cleanedResponse ) > 0 and "storytelling coach" not in cleanedPrompt ): aiResponses = [ chat.get("content") for chat in st.session_state.chatHistory if chat.get("role") == "assistant" ] relevantResponse = f""" {aiResponses[-1]} {response} """ return ( f"Extract the story from this text:\n{relevantResponse}", """ Style: In a storybook, surreal """, "Imagining your story scene ...", ) return (None, None, None) def __getImagePromptDetails(prompt: str, response: str): (enhancePrompt, imagePrompt, loaderText) = __getRawImagePromptDetails(prompt, response) if imagePrompt or enhancePrompt: # U.pprint(f"[Raw] {enhancePrompt=} | {imagePrompt=}") promptEnhanceModelType: ModelType = "LLAMA" U.pprint(f"{promptEnhanceModelType=}") modelConfig = MODEL_CONFIG[promptEnhanceModelType] client = modelConfig["client"] model = modelConfig["model"] isClaudeModel = promptEnhanceModelType == "CLAUDE" systemPrompt = "You help in creating prompts for image generation" promptPrefix = f"{enhancePrompt}\nAnd then use the above to" if enhancePrompt else "Use the text below to" enhancePrompt = f""" {promptPrefix} create a prompt for image generation. {imagePrompt} Return only the final Image Generation Prompt, and nothing else """ U.pprint(f"[Raw] {enhancePrompt=}") llmArgs = { "model": model, "messages": [{ "role": "user", "content": enhancePrompt }], "temperature": 1, "max_tokens": 2000 } if isClaudeModel: llmArgs["system"] = systemPrompt response = client.messages.create(**llmArgs) imagePrompt = response.content[0].text else: llmArgs["messages"] = [ {"role": "system", "content": systemPrompt}, *llmArgs["messages"] ] response = client.chat.completions.create(**llmArgs) responseMessage = response.choices[0].message imagePrompt = responseMessage.content U.pprint(f"[Enhanced] {imagePrompt=}") return (imagePrompt, loaderText) def __getMessages(): def getContextSize(): currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(st.session_state.messages) + 100 U.pprint(f"{currContextSize=}") return currContextSize while getContextSize() > MAX_CONTEXT: U.pprint("Context size exceeded, removing first message") st.session_state.messages.pop(0) return st.session_state.messages def __logLlmRequest(messagesFormatted: list): contextSize = __countTokens(messagesFormatted) U.pprint(f"{contextSize=} | {MODEL}") # U.pprint(f"{messagesFormatted=}") def __predict(): messagesFormatted = [] try: if isClaudeModel: messagesFormatted.extend(__getMessages()) __logLlmRequest(messagesFormatted) with client.messages.stream( model=MODEL, messages=messagesFormatted, system=C.SYSTEM_MSG, temperature=0.9, max_tokens=4000, ) as stream: for text in stream.text_stream: yield text else: messagesFormatted.append( {"role": "system", "content": C.SYSTEM_MSG} ) messagesFormatted.extend(__getMessages()) __logLlmRequest(messagesFormatted) response = client.chat.completions.create( model=MODEL, messages=messagesFormatted, temperature=1, max_tokens=4000, stream=True ) for chunk in response: choices = chunk.choices if not choices: U.pprint("Empty chunk") continue chunkContent = chunk.choices[0].delta.content if chunkContent: yield chunkContent except Exception as e: U.pprint(f"LLM API Error: {e}") yield f"{C.EXCEPTION_KEYWORD} | {e}" def __generateImage(prompt: str): fluxClient = Client( "black-forest-labs/FLUX.1-schnell", os.environ.get("HF_FLUX_CLIENT_TOKEN") ) result = fluxClient.predict( prompt=prompt, seed=0, randomize_seed=True, width=1024, height=768, num_inference_steps=4, api_name="/infer" ) U.pprint(f"imageResult={result}") return result def __paintImageIfApplicable( imageContainer: DeltaGenerator, prompt: str, response: str, ): imagePath = None try: (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response) if imagePrompt: imgContainer = imageContainer.container() imgContainer.write( f"""