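# Gradio demo Space: takes an English text prompt and generates an image by
# calling the hosted stabilityai/stable-diffusion Space (loaded below via
# gr.Interface.load). An older local diffusers pipeline path is kept,
# commented out, for reference.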
import gradio as gr
import requests
import os
import torch as th
from torch import autocast
from diffusers import StableDiffusionPipeline
# Hugging Face token for authenticated access (read from the environment)
HF_TOKEN = os.environ["HF_TOKEN"]

# Select GPU if available, otherwise fall back to CPU
has_cuda = th.cuda.is_available()
device = th.device('cuda' if has_cuda else 'cpu')
print(f"device is: {device}")
# Older, local inference path: load Stable Diffusion directly with diffusers.
# Left commented out; the hosted Space loaded below is used instead.
#pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=th.float32, use_auth_token=HF_TOKEN).to(device)  #revision="fp16",


def get_sd_old(translated_txt):
    # Generate an image locally via the commented-out `pipe` above
    scale = 7.5
    steps = 45
    with autocast('cuda' if has_cuda else 'cpu'):
        image = pipe(translated_txt, guidance_scale=scale, num_inference_steps=steps)["sample"][0]
    return image
# Leftover Inference API setup for bigscience/bloom (unused)
#API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
#headers = {"Authorization": f"Bearer {HF_TOKEN}"}

# Load the hosted Stable Diffusion Space as a callable interface
sd_inf = gr.Interface.load(name="spaces/stabilityai/stable-diffusion", api_key=HF_TOKEN, use_auth_token=HF_TOKEN)
def get_sd(translated_txt):
    print("******** Inside get_sd ********")
    print(f"translated_txt is: {translated_txt}")
    print(f"Stable Diffusion interface is: {sd_inf}")
    # Call the Space with (prompt, samples, steps, scale, seed) routed through fn_index=2
    sd_img_gallery = sd_inf(translated_txt, float(4), float(45), float(7.5), 1024, fn_index=2)
    # The Space returns a gallery of images; use the first one
    return sd_img_gallery[0]
demo = gr.Blocks()

with demo:
    gr.Markdown("Testing Diffusion models. STILL VERY MUCH WORK IN PROGRESS !!!!!!!!")
    with gr.Row():
        in_text_prompt = gr.Textbox(label="Enter English text here")
        #out_text_chinese = gr.Textbox(label="Your Chinese language output")
        b1 = gr.Button("Generate SD")
    out_sd = gr.Image(type="pil", label="SD output for the given prompt")
    b1.click(get_sd, in_text_prompt, out_sd)

demo.launch(enable_queue=True, debug=True)