Nick088 committed on
Commit 244f082
1 Parent(s): 31328c7

Update app.py

Files changed (1): app.py (+119 -24)
app.py CHANGED
@@ -2,46 +2,141 @@ import gradio as gr
 import torch
 from transformers import T5Tokenizer, T5ForConditionalGeneration
 
-if torch.cuda.is_available():
-    device = "cuda"
-    print("Using GPU")
-else:
-    device = "cpu"
-    print("Using CPU")
-
-tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
-model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
-
-model.to(device)
-
+def load_model(model_path, dtype):
+    if dtype == "fp32":
+        torch_dtype = torch.float32
+    elif dtype == "fp16":
+        torch_dtype = torch.float16
+    else:
+        raise ValueError("Invalid dtype. Only 'fp32' or 'fp16' are supported.")
+
+    model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch_dtype)
+    return model
 
 def generate(
-    prompt, history, max_new_tokens=512, repetition_penalty=1.2, temperature=0.5, top_p=1, top_k=1, seed=42
+    prompt,
+    history,
+    max_new_tokens,
+    repetition_penalty,
+    temperature,
+    top_p,
+    top_k,
+    seed,
+    model_path="roborovski/superprompt-v1",
+    dtype="fp16",
 ):
-
+    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
+    model = load_model(model_path, dtype)
+
+    if torch.cuda.is_available():
+        device = "cuda"
+        print("Using GPU")
+    else:
+        device = "cpu"
+        print("Using CPU")
+
+    model.to(device)
+
     input_text = f"{prompt}, {history}"
     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
-    outputs = model.generate(input_ids, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
+
+    torch.manual_seed(seed)
+    outputs = model.generate(
+        input_ids,
+        max_new_tokens=max_new_tokens,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )
+
     better_prompt = tokenizer.decode(outputs[0])
     return better_prompt
 
-additional_inputs=[
-    gr.Slider(value=512, minimum=250, maximum=512, step=1, interactive=True, label="Max New Tokens", info="The maximum numbers of new tokens, controls how long is the output"),
-    gr.Slider(value=1.2, minimum=0, maximum=2, step=0.05, interactive=True, label="Repetition Penalty", info="Penalize repeated tokens, making the AI repeat less itself"),
-    gr.Slider(value=0.5, minimum=0, maximum=1, step=0.05, interactive=True, label="Temperature", info="Higher values produce more diverse outputs"),
-    gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, label="Top P", info="Higher values sample more low-probability tokens"),
-    gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens"),
-    gr.Number(value=42, interactive=True, label="Seed", info="A starting point to initiate the generation process"),
+additional_inputs = [
+    gr.Slider(
+        value=512,
+        minimum=250,
+        maximum=512,
+        step=1,
+        interactive=True,
+        label="Max New Tokens",
+        info="The maximum numbers of new tokens, controls how long is the output",
+    ),
+    gr.Slider(
+        value=1.2,
+        minimum=0,
+        maximum=2,
+        step=0.05,
+        interactive=True,
+        label="Repetition Penalty",
+        info="Penalize repeated tokens, making the AI repeat less itself",
+    ),
+    gr.Slider(
+        value=0.5,
+        minimum=0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        label="Temperature",
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        value=1,
+        minimum=0,
+        maximum=2,
+        step=0.05,
+        interactive=True,
+        label="Top P",
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        value=1,
+        minimum=1,
+        maximum=100,
+        step=1,
+        interactive=True,
+        label="Top K",
+        info="Higher k means more diverse outputs by considering a range of tokens",
+    ),
+    gr.Number(
+        value=42,
+        interactive=True,
+        label="Seed",
+        info="A starting point to initiate the generation process",
+    ),
+    gr.Radio(
+        choices=["fp32", "fp16"],
+        value="fp16",
+        label="Model Precision",
+        info="Select the precision of the model: fp32 or fp16",
+    ),
 ]
 
-examples=[["Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it.", None, None ]]
+examples = [
+    [
+        "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it.",
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        "roborovski/superprompt-v1",
+        "fp16",
+    ]
+]
 
 gr.ChatInterface(
     fn=generate,
-    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+    chatbot=gr.Chatbot(
+        show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"
+    ),
     additional_inputs=additional_inputs,
     title="SuperPrompt-v1",
-    description="Make your prompts more detailed! Especially for AI Art!!!",
+    description="Make your prompts more detailed!",
     examples=examples,
     concurrency_limit=20,
 ).launch(show_api=False)
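
For reference, the refactored generate() can be exercised without the Gradio UI. A minimal sketch, assuming it runs in the same module as the updated app.py; this driver code is not part of the commit, and the argument values simply mirror the defaults the sliders expose:

# Hypothetical driver code, not part of the commit.
better_prompt = generate(
    prompt="Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it.",
    history="",
    max_new_tokens=512,
    repetition_penalty=1.2,
    temperature=0.5,
    top_p=1,
    top_k=1,
    seed=42,
    model_path="roborovski/superprompt-v1",
    # fp32 is the safer choice on CPU-only machines, since some fp16 ops
    # are not implemented on CPU.
    dtype="fp16" if torch.cuda.is_available() else "fp32",
)
print(better_prompt)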
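
The added torch.manual_seed(seed) call is what makes the Seed input meaningful: with do_sample=True generation is stochastic, and reseeding PyTorch's global RNG before model.generate makes repeated calls reproducible. An illustrative check, hypothetical and not in the commit, assuming the same model, dtype, and device for both calls:

# Positional arguments follow the new signature:
# prompt, history, max_new_tokens, repetition_penalty, temperature, top_p, top_k, seed.
first = generate("A cosy cabin in the woods.", "", 256, 1.2, 0.5, 1, 1, 42)
second = generate("A cosy cabin in the woods.", "", 256, 1.2, 0.5, 1, 1, 42)
assert first == second  # same seed, same sampled tokens

Note that generate() now reloads the tokenizer and model on every call, so each invocation pays the full load cost; that is the trade-off this commit makes to let the UI switch model precision per request.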