# mr_shitposter / app.py
# Last change by nihonium286: "Update app.py" (commit d253c2c).
import os

import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer, pipeline
st.set_page_config(layout="wide")

# Streamlit re-executes this whole script on every widget interaction, so
# without caching all three models would be rebuilt on each rerun.
# st.cache_resource only exists in newer Streamlit releases; on older ones fall
# back to a no-op decorator (i.e. the previous load-every-run behaviour).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_models():
    """Build the image classifier, text generator and keyword-to-sentence model.

    Returns:
        Tuple ``(image_pipe, text_pipe, k2t_tokenizer, k2t_model)``.
    """
    image_classifier = pipeline("image-classification")
    text_generator = pipeline("text-generation")
    k2t_name = "mrm8488/t5-base-finetuned-common_gen"
    tokenizer = AutoTokenizer.from_pretrained(k2t_name)
    # T5 is an encoder-decoder model; AutoModelForSeq2SeqLM is the
    # non-deprecated replacement for AutoModelWithLMHead and yields the same
    # model class for this checkpoint.
    model = AutoModelForSeq2SeqLM.from_pretrained(k2t_name)
    return image_classifier, text_generator, tokenizer, model


# Keep the original module-level names so the rest of the script is unchanged.
image_pipe, text_pipe, k2t_tokenizer, k2t_model = _load_models()
def gen_sentence(words, max_length=32):
    """Generate one sentence from a keyword string with the CommonGen T5 model.

    Args:
        words: Keyword string fed to ``k2t_tokenizer`` as a single-item batch.
        max_length: Forwarded to ``k2t_model.generate`` as ``max_length``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    encoded = k2t_tokenizer([words], return_tensors='pt')
    generated_ids = k2t_model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
    )
    first_sequence = generated_ids[0]
    return k2t_tokenizer.decode(first_sequence, skip_special_tokens=True)
img = st.file_uploader(label='Upload jpg or png to create post', type=['jpg', 'png'])
if img is None:
    # No upload yet: fall back to a sample image. Download it only once,
    # because Streamlit reruns this script on every interaction.
    if not os.path.exists("img.jpg"):
        torch.hub.download_url_to_file(
            'https://assets.epicurious.com/photos/5761d0268accf290434553aa/master/pass/panna-cotta.jpg',
            "img.jpg",
        )
    img = "img.jpg"

with Image.open(img) as image:
    # Decode the pixels now: leaving the `with` block closes the underlying
    # file, but the image is still rendered by st.image() further down.
    image.load()
    results = image_pipe(image)

# The CommonGen model expects space-separated concepts (e.g. "tree plant
# ground"); previously the labels were glued together with no separator.
# Keep only the first comma-separated synonym of each predicted label.
keywords = " ".join(result["label"].split(',')[0] for result in results)
post_text = text_pipe(gen_sentence(keywords))[0]["generated_text"]

col1, col2 = st.columns(2)
with col1:
    st.subheader("Your image")
    st.image(image)
with col2:
    st.subheader("Generated text")
    st.write(post_text)