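"""Character Captions: a small Streamlit demo.

Captions an uploaded image with a BLIP image-captioning model, then asks a
Llama chat model on the Hugging Face Inference API to restate that caption in
the voice of a chosen character. Expects a Hugging Face API token in the
HUGGINGFACE_TOKEN environment variable; launch with `streamlit run <script>.py`.
"""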
import streamlit as st
from transformers import pipeline
from huggingface_hub import InferenceClient
from PIL import Image
import os


def initialize():
    """Create the Hugging Face Inference API client once per session."""
    if 'initialized' not in st.session_state:
        print("Initializing...")
        st.session_state['initialized'] = True
        # The token is read from the environment so it never lives in the source.
        st.session_state['api_key'] = os.getenv("HUGGINGFACE_TOKEN")
        st.session_state['client'] = InferenceClient(api_key=st.session_state['api_key'])


def main():
    initialize()
    st.header("Character Captions")
    st.write("Have a character caption any image you upload!")
    character = st.selectbox("Choose a character", ["artist", "elmo", "unintelligible", "goku"])

    uploaded_img = st.file_uploader("Upload an image here")

    if uploaded_img is not None:
        image = Image.open(uploaded_img)
        st.image(image)

        # Generate a plain-English caption for the uploaded image with BLIP.
        # The pipeline is rebuilt on every Streamlit rerun; wrapping it in
        # st.cache_resource would load the model only once.
        image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
        response = image_captioner(image)
        caption = response[0]['generated_text']

        # Prompt templates that ask the chat model to restate the caption in character.
        character_prompts = {
            "artist": f"Describe this caption like you're an artist: {caption}.",
            "elmo": f"Describe this caption like you're elmo: {caption}.",
            "unintelligible": f"Describe this caption in a way that makes no sense: {caption}.",
            "goku": f"Describe this caption like you're goku: {caption}."
        }

        prompt = character_prompts[character]
        messages = [
            {"role": "user", "content": prompt}
        ]

        # Stream the character-voiced caption from the Inference API chat endpoint.
        stream = st.session_state['client'].chat.completions.create(
            model="meta-llama/Llama-3.2-3B-Instruct",
            messages=messages,
            max_tokens=500,
            stream=True
        )

        # Accumulate the streamed chunks into a single response string.
        response = ''
        for chunk in stream:
            # Some chunks (e.g. the final one) can carry no content; skip them.
            if chunk.choices[0].delta.content is not None:
                response += chunk.choices[0].delta.content

        st.write(response)


if __name__ == '__main__':
    main()