import streamlit as st

from PIL import Image
import io

from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from torchvision import transforms
import open_clip
from openai import OpenAI
from diffusers import StableDiffusionPipeline


# Placeholder at the top of the page for transient status messages.
top_notification = st.empty()
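
# Keep generated results in session state so they survive Streamlit reruns.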
if 'simplified_text' not in st.session_state:
    st.session_state['simplified_text'] = ''
if 'new_caption' not in st.session_state:
    st.session_state['new_caption'] = ''
if 'model_clip' not in st.session_state:
    st.session_state['model_clip'] = None
if 'transform_clip' not in st.session_state:
    st.session_state['transform_clip'] = None
if 'openai_api_key' not in st.session_state:
    st.session_state['openai_api_key'] = ''
if 'message_content_from_caption' not in st.session_state:
    st.session_state['message_content_from_caption'] = ''
if 'message_content_from_simplified_text' not in st.session_state:
    st.session_state['message_content_from_simplified_text'] = ''
if 'image_from_caption' not in st.session_state:
    st.session_state['image_from_caption'] = None
if 'image_from_simplified_text' not in st.session_state:
    st.session_state['image_from_simplified_text'] = None
if 'image_from_press_text' not in st.session_state:
    st.session_state['image_from_press_text'] = None
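
# Load the T5 text-simplification model and tokenizer once per session.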
model_name = "mrm8488/t5-small-finetuned-text-simplification"
tokenizer_name = "mrm8488/t5-small-finetuned-text-simplification"

if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
    st.session_state['model'] = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    st.session_state['tokenizer'] = AutoTokenizer.from_pretrained(tokenizer_name)
    st.session_state['simplifier'] = pipeline(
        "text2text-generation",
        model=st.session_state['model'],
        tokenizer=st.session_state['tokenizer'],
    )

simplifier = st.session_state['simplifier']
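
# CoCa (open_clip) captioning: generate a short caption for the uploaded image.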
def load_clip_model():
    # CoCa ViT-L/14 captioning model with its matching preprocessing transform.
    model_clip, _, transform_clip = open_clip.create_model_and_transforms(
        model_name="coca_ViT-L-14",
        pretrained="mscoco_finetuned_laion2B-s13B-b90k"
    )
    return model_clip, transform_clip


def generate_caption(image_path):
    # Load the captioning model lazily and cache it in session state.
    if st.session_state['model_clip'] is None or st.session_state['transform_clip'] is None:
        st.session_state['model_clip'], st.session_state['transform_clip'] = load_clip_model()

    im = Image.open(image_path).convert("RGB")
    im = st.session_state['transform_clip'](im).unsqueeze(0)

    with torch.no_grad(), torch.cuda.amp.autocast():
        generated = st.session_state['model_clip'].generate(im)

    # Decode the generated tokens and strip the special start/end tokens.
    new_caption = open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", "")[:-2]
    return new_caption
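
# Streamlit UI: text input, image upload, and a preview of the uploaded files.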
st.title("ARTSPEAK")

user_input = st.text_area("Enter text here")

uploaded_image = st.file_uploader("Upload an image (jpg or png)", type=["jpg", "png"])

with st.expander("Display of Uploaded Files"):
    st.write("These are your uploaded files:")

    if user_input:
        st.write("Original Text:")
        st.write(user_input)

    if uploaded_image is not None:
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

st.markdown("---")

if st.button("Simplify"):
    if user_input:
        simplified_text = simplifier(user_input, min_length=20, max_length=50, do_sample=True)
        st.session_state['simplified_text'] = simplified_text[0]['generated_text']
    else:
        st.warning("Please enter text in the input field before clicking 'Simplify'")

if st.session_state['simplified_text']:
    st.write("Simplified Text:")
    st.write(st.session_state['simplified_text'])

st.markdown("---")

if st.button("Get Caption"):
    if uploaded_image is not None:
        caption = generate_caption(uploaded_image)
        st.session_state['new_caption'] = caption
    else:
        st.warning("Please upload an image before clicking 'Get Caption'")

if st.session_state['new_caption']:
    st.write("New Caption for Artwork:")
    st.write(st.session_state['new_caption'])

st.markdown("---")

api_key_input = st.text_input("Enter your OpenAI API key to enable press-text generation", type="password")

if st.button('Save API Key'):
    st.session_state['openai_api_key'] = api_key_input
    st.success("API Key saved temporarily for this session.")

st.markdown("---")
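
# Ask GPT-3.5 to turn a short description into an "International Art English" press text.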
def get_openai_completion(api_key, prompt_message):
    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You will be given a short description of an artwork. Write a complex press text for an exhibition in International Art English, dealing with post-colonialism, identity politics, the military-industrial complex and queerness through the language of writers like Rancière, Deleuze, Trevor Paglen, Hito Steyerl, Slavoj Žižek, Claire Fontaine, Michel Foucault, Donna Haraway and Paul Preciado, without too much name-dropping. Output only the press text, with no surrounding or explanatory messages."},
            {"role": "user", "content": prompt_message}
        ]
    )
    return completion.choices[0].message.content

if st.button("Generate Press Text from New Caption"):
    if st.session_state['new_caption'] and st.session_state['openai_api_key']:
        try:
            st.session_state['message_content_from_caption'] = get_openai_completion(st.session_state['openai_api_key'], st.session_state['new_caption'])
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        st.warning("Please ensure a caption is generated and an API key is entered.")

if st.session_state['message_content_from_caption']:
    st.write("Generated Press Text from New Caption:")
    st.write(st.session_state['message_content_from_caption'])

if st.button("Generate Press Text from Simplified Text"):
    if st.session_state['simplified_text'] and st.session_state['openai_api_key']:
        try:
            st.session_state['message_content_from_simplified_text'] = get_openai_completion(st.session_state['openai_api_key'], st.session_state['simplified_text'])
        except Exception as e:
            st.error(f"An error occurred: {e}")
    else:
        st.warning("Please ensure simplified text is available and an API key is entered.")

if st.session_state['message_content_from_simplified_text']:
    st.write("Generated Press Text from Simplified Text:")
    st.write(st.session_state['message_content_from_simplified_text'])

st.markdown("---")
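
# Stable Diffusion: generate an image from the caption, the simplified text, or the press text.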
@st.cache_resource
def load_diffusion_model():
    # Cache the pipeline so it is only loaded once per server process; requires a CUDA GPU.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    return pipe


def generate_image(pipe, prompt):
    # Show a status message at the top of the page while the image is being generated.
    top_notification.text('Generating image...')
    image = pipe(prompt).images[0]
    top_notification.empty()
    return image

if st.button("Generate Image from New Caption"):
    if st.session_state['new_caption']:
        pipe = load_diffusion_model()
        prompt_caption = f"contemporary art of {st.session_state['new_caption']}"
        st.session_state['image_from_caption'] = generate_image(pipe, prompt_caption)

if st.session_state['image_from_caption'] is not None:
    st.image(st.session_state['image_from_caption'], caption="Image from New Caption", use_column_width=True)

if st.button("Generate Image from Simplified Text"):
    if st.session_state['simplified_text']:
        pipe = load_diffusion_model()
        prompt_summary = f"contemporary art of {st.session_state['simplified_text']}"
        st.session_state['image_from_simplified_text'] = generate_image(pipe, prompt_summary)

if st.session_state['image_from_simplified_text'] is not None:
    st.image(st.session_state['image_from_simplified_text'], caption="Image from Simplified Text", use_column_width=True)

if st.button("Generate Image from Press Text"):
    if st.session_state['message_content_from_simplified_text']:
        pipe = load_diffusion_model()
        prompt_press_text = f"contemporary art of {st.session_state['message_content_from_simplified_text']}"
        st.session_state['image_from_press_text'] = generate_image(pipe, prompt_press_text)

if st.session_state['image_from_press_text'] is not None:
    st.image(st.session_state['image_from_press_text'], caption="Image from Press Text", use_column_width=True)