File size: 5,316 Bytes
19aec01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# -*- coding: utf-8 -*-
"""Untitled3.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1BltKPv_n-glCuuIIYSBA6GHK-tmwbl20
"""

!pip install fastai
!pip install gradio --upgrade
!pip install lida
!pip install diffusers --upgrade
!pip install tensorflow-probability --upgrade
!pip install invisible_watermark transformers accelerate safetensors
!pip install torch --upgrade
!pip install transformers --upgrade
!pip install datasets
#!pip install typing_extensions==4.5.0
#!pip install fastapi==0.103.0

import functools
import json
import urllib, urllib.request

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from torchvision import transforms
from transformers import pipeline

# First Page
def demo_tab(image):
    """Echo the uploaded image back unchanged; placeholder handler for the Demo tab."""
    passthrough = image
    return passthrough

# Second Page
@functools.lru_cache(maxsize=1)
def _load_sdxl_pipelines():
  """Load the SDXL base and refiner pipelines once and keep them on the GPU.

  Loading both fp16 checkpoints and moving them to CUDA is expensive, so the
  pair is cached and reused across requests instead of being rebuilt on every
  button click. A failed load raises and is not cached, so it is retried on
  the next call.

  Returns:
    tuple: ``(base, refiner)`` diffusers pipelines, both moved to ``"cuda"``.
  """
  base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
  )
  base.to("cuda")
  # The refiner reuses the base's second text encoder and VAE objects, so
  # only one copy of each is held.
  refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
  )
  refiner.to("cuda")
  return base, refiner


def generate_image(Prompt, Negative_prompt, Steps):
  """Generate an image from text with the SDXL base + refiner pipelines.

  The base model runs the first ``high_noise_frac`` of the denoising schedule
  and returns latents; the refiner starts at the same point and finishes the
  remaining steps, producing the final image.

  Args:
    Prompt: Text describing the desired image.
    Negative_prompt: Text describing what to avoid; applied in both stages.
    Steps: Total number of denoising steps. Coerced to ``int`` (a UI slider
      may deliver a float) and clamped to at least 1.

  Returns:
    The refined image (first element of the refiner's ``.images``).
  """
  base, refiner = _load_sdxl_pipelines()

  # Fraction of the schedule handled by the base model; the refiner runs the
  # rest (80/20 split).
  high_noise_frac = 0.8
  # num_inference_steps must be a positive integer.
  n_steps = max(1, int(Steps))

  # Stage 1: base model denoises up to high_noise_frac and returns latents.
  latents = base(
      prompt=Prompt,
      negative_prompt=Negative_prompt,
      num_inference_steps=n_steps,
      denoising_end=high_noise_frac,
      output_type="latent",
  ).images
  # Stage 2: refiner resumes from the same point; the negative prompt is
  # forwarded so it is honoured for the whole schedule, not just stage 1.
  return refiner(
      prompt=Prompt,
      negative_prompt=Negative_prompt,
      num_inference_steps=n_steps,
      denoising_start=high_noise_frac,
      image=latents,
  ).images[0]

@functools.lru_cache(maxsize=1)
def _load_danbooru_model_and_labels():
  """Load the RF5 Danbooru ResNet-50 model and its tag names once.

  Returns:
    tuple: ``(model, labels)``. ``model`` is in eval mode and moved to CUDA
    when available. ``labels`` is the list parsed from
    ``class_names_6000.json``, which ``predict`` pairs positionally with the
    model's outputs.
  """
  model = torch.hub.load('RF5/danbooru-pretrained', 'resnet50')
  model.eval()
  if torch.cuda.is_available():
    model.to('cuda')

  # Load the JSON label file from GitHub; the connection is closed as soon
  # as the body has been read.
  with urllib.request.urlopen("https://github.com/RF5/danbooru-pretrained/raw/master/config/class_names_6000.json") as url:
    labels = json.loads(url.read().decode())
  return model, labels


def predict(input_image):
  """Tag an image with the RF5 Danbooru ResNet-50 model and return the top 10 tags.

  Args:
    input_image: Array-like image supporting ``.astype`` (e.g. a numpy
      array). Grayscale, RGB and RGBA inputs are all converted to RGB.

  Returns:
    dict: ``{tag: probability}`` for the 10 highest-scoring tags, ordered
    from most to least likely, with plain Python ``float`` values.
  """
  model, labels = _load_danbooru_model_and_labels()

  # Let Pillow infer the mode from the array shape, then convert to RGB.
  # Forcing mode='RGB' in fromarray mis-reads (H, W) grayscale or
  # (H, W, 4) RGBA arrays.
  pil_image = Image.fromarray(input_image.astype('uint8')).convert('RGB')

  preprocess = transforms.Compose([
      transforms.Resize(360),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
  ])
  input_batch = preprocess(pil_image).unsqueeze(0)  # add a batch dimension

  # The cached model lives on CUDA when available; keep the input alongside it.
  if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')

  with torch.no_grad():
    output = model(input_batch)

  # Sigmoid gives an independent probability per tag (multi-label), not a
  # softmax distribution over tags.
  probs = torch.sigmoid(output[0]).cpu().tolist()

  top_tags = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)[:10]
  return {label: float(prob) for label, prob in top_tags}

def image_classify(input_image, model):
    model_mapping= {
      "Resnet 50": "microsoft/resnet-50",
      "Vit Base Patch16-224": "google/vit-base-patch16-224",
      "NSFW Image Detection": "Falconsai/nsfw_image_detection",
      "Vit Age Classifier": "nateraw/vit-age-classifier"
    }
    classifier = pipeline("image-classification", model=model_mapping[model])
    img = input_image
    result = classifier(img)
    #Sort the perccentage confident from highest to lowest
    highest_confidence_result = sorted(result, key=lambda x: x['score'], reverse=True)[0]
    # Format the score as a percentage and combine it with the label
    output = f"{highest_confidence_result['score']*100:.2f}% confident : {highest_confidence_result['label']}"
    return output

# Build the Gradio UI: one tab per feature; event handlers are wired at the
# Blocks level after all tabs are defined.
with gr.Blocks() as demo:
  with gr.Tab("Demo"):
    image_input = gr.Image(type='pil')
    image_output = gr.Image()
    demo_button = gr.Button("Generate")

  with gr.Tab("Text2Image"):
    SD_text_input = gr.Textbox(lines=5, label="Prompt")
    SD_text2_input = gr.Textbox(lines=5, label="Negative Prompt")
    # Forwarded to generate_image as the number of denoising steps, so it
    # must be a positive integer.
    Slider_input = gr.Slider(1, 100, value=40, step=1, label="Steps")
    SD_output = gr.Image()
    SD_button = gr.Button("Generate")

  with gr.Tab("Image Classification"):
    # Choices must match the model-name keys that image_classify accepts.
    option_input = gr.Dropdown(
        ["Resnet 50", "Vit Base Patch16-224", "Vit Age Classifier", "NSFW Image Detection"],
        value="Resnet 50",
        label="Model",
    )
    # Hand the transformers pipeline a PIL image rather than a numpy array.
    t2i_input = gr.Image(type='pil', label="Image")
    # image_classify returns a formatted string, so display it as text.
    t2i_output = gr.Textbox(label="Prediction")
    t2i_button = gr.Button("Classify")

  demo_button.click(demo_tab, inputs=image_input, outputs=image_output)
  SD_button.click(generate_image, inputs=[SD_text_input, SD_text2_input, Slider_input], outputs=SD_output)
  # image_classify's signature is (input_image, model): image first.
  t2i_button.click(image_classify, inputs=[t2i_input, option_input], outputs=t2i_output)

demo.launch(debug=True)