import base64
from io import BytesIO

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

from diffusers import AudioLDMPipeline, StableDiffusionPipeline
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# Text-to-image pipeline. The original id "runway/stable-diffusion-v2-1" does not exist on
# the Hugging Face Hub; "stabilityai/stable-diffusion-2-1" is the Stable Diffusion 2.1 checkpoint.
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")

# Text-to-audio pipeline. StableDiffusionPipeline only produces images, so audio is
# generated with diffusers' AudioLDM pipeline instead (assumes diffusers >= 0.15).
audio_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-s-full-v2")

# Tokenizer and classification model for the "chatbot" form below. bert-base-cased ships
# without a fine-tuned classification head, so the head is randomly initialised and its
# predictions are arbitrary until the model is fine-tuned.
config = AutoConfig.from_pretrained("bert-base-cased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", config=config)
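
# Optional performance tweaks (assumptions about the deployment target, not requirements):
# - Streamlit re-executes this script on every interaction; wrapping the loads above in
#   functions decorated with @st.cache_resource (Streamlit >= 1.18) avoids re-initialising
#   the models on each rerun.
# - Generation on CPU is very slow; when a CUDA GPU is available the diffusion pipelines
#   can be moved to it (this falls back to CPU otherwise).
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)
audio_pipe = audio_pipe.to(device)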


def generate_audio(text):
    # AudioLDM returns a batch of float32 waveforms sampled at 16 kHz.
    audio = audio_pipe(text, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
    audio = audio.astype(np.float32)
    # Normalise to [-1, 1] so the raw waveform can be played directly by st.audio.
    audio = audio / np.max(np.abs(audio))
    return audio
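

# Usage note (illustrative, assumes scipy is installed): the returned waveform can be
# written to disk as a 16 kHz WAV file, e.g.
#     from scipy.io import wavfile
#     wavfile.write("generated.wav", 16000, generate_audio("rain on a tin roof"))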


def generate_image(text):
    # The Stable Diffusion pipeline already returns PIL images, so no tensor conversion
    # is needed.
    image = pipe(text).images[0]
    return image
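

# Note (illustrative defaults): image quality and speed can be traded off with the
# standard Stable Diffusion call arguments, e.g.
#     pipe(text, num_inference_steps=30, guidance_scale=7.5).images[0]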


def generate_video(text, frames=30, width=512, height=512):
    # Stable Diffusion has no temporal model, so this simply generates `frames`
    # independent images for the same prompt and writes them to an MJPG AVI file.
    with torch.no_grad():
        video_frames = []
        for _ in range(frames):
            img = pipe(text).images[0]
            img = img.resize((width, height), Image.BILINEAR)
            # OpenCV expects BGR uint8 frames.
            video_frames.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
    # Write at a fixed 10 fps (the original passed `frames` as the frame rate).
    writer = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc(*"MJPG"), 10, (width, height))
    for frame in video_frames:
        writer.write(frame)
    writer.release()
    return "output.avi"
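

# Usage note: generate_video("a drifting nebula", frames=8) writes eight independent
# samples of the same prompt to output.avi; with no temporal conditioning, consecutive
# frames will not be visually consistent.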


def main():
    st.title("Streamlit App with Diffusers and Transformers")
    st.header("Generate Audio, Video, and Images using Diffusers")
    st.header("Chatbot using BERT")

    with st.form("audio_form"):
        text_input = st.text_input("Enter text for audio generation:")
        submit_button = st.form_submit_button("Generate Audio")
        if submit_button:
            audio_output = generate_audio(text_input)
            st.write("Generated Audio:")
            # st.audio can play a raw numpy waveform when given its sample rate
            # (AudioLDM generates at 16 kHz).
            st.audio(audio_output, sample_rate=16000)

    with st.form("image_form"):
        text_input = st.text_input("Enter text for image generation:")
        submit_button = st.form_submit_button("Generate Image")
        if submit_button:
            image_output = generate_image(text_input)
            # Encode the PIL image as PNG bytes before base64-encoding it.
            buffer = BytesIO()
            image_output.save(buffer, format="PNG")
            image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
            st.image(image_output, caption="Generated Image", use_column_width=True)
            st.write(f"Generated Image (base64, truncated): {image_base64[:64]}...")

    with st.form("video_form"):
        text_input = st.text_input("Enter text for video generation:")
        frames = st.number_input("Number of frames:", value=30, step=1)
        width = st.number_input("Image width:", value=512, step=1)
        height = st.number_input("Image height:", value=512, step=1)
        submit_button = st.form_submit_button("Generate Video")
        if submit_button:
            video_path = generate_video(text_input, frames, width, height)
            st.write("Generated Video:")
            # Browsers may not decode MJPG AVI natively, so the raw bytes are handed to
            # st.video as-is; re-encoding to MP4/H.264 would be more portable.
            with open(video_path, "rb") as f:
                st.video(f.read())

    with st.form("chat_form"):
        user_input = st.text_area("Enter your message:", height=100)
        submit_button = st.form_submit_button("Send Message")
        if submit_button:
            message = tokenizer(user_input, padding=True, return_tensors="pt").to("cpu")
            outputs = model(**message)
            prediction = torch.argmax(outputs.logits, dim=-1).item()
            # The classification head is not fine-tuned, so the "response" is just the
            # predicted class label (e.g. LABEL_0 / LABEL_1), not generated text.
            response = model.config.id2label[prediction]
            st.write(f"Assistant Response: {response}")

    st.write("Streamlit App with Diffusers and Transformers")
    st.write("Generated by FallnAI")


if __name__ == "__main__":
    main()