Nick088 commited on
Commit
f6bfd01
1 Parent(s): 427cf92

testing model precision type option

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -4,7 +4,7 @@ import random
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
 
6
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
7
- model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
8
 
9
  if torch.cuda.is_available():
10
  device = "cuda"
@@ -24,8 +24,11 @@ def generate(
24
  top_p,
25
  top_k,
26
  seed,
 
27
  ):
28
 
 
 
29
  input_text = f"{system_prompt}, {prompt}"
30
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
31
 
@@ -49,6 +52,9 @@ def generate(
49
  better_prompt = better_prompt.replace("<pad»", "").replace("</s>", "")
50
  return better_prompt
51
 
 
 
 
52
  prompt = gr.Textbox(label="Prompt", interactive=True)
53
 
54
  system_prompt = gr.Textbox(label="System Prompt", interactive=True)
@@ -65,7 +71,6 @@ top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, lab
65
 
66
  seed = gr.Number(value=42, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
67
 
68
-
69
  examples = [
70
  [
71
  "A storefront with 'Text to Image' written on it.",
 
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
 
6
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
7
+
8
 
9
  if torch.cuda.is_available():
10
  device = "cuda"
 
24
  top_p,
25
  top_k,
26
  seed,
27
+ precision_model
28
  ):
29
 
30
+ model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=precision_model)
31
+
32
  input_text = f"{system_prompt}, {prompt}"
33
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
34
 
 
52
  better_prompt = better_prompt.replace("<pad»", "").replace("</s>", "")
53
  return better_prompt
54
 
55
+
56
+ precision_model = gr.Radio([('fp32', torch.float32), ('fp16', toch.float16)], label="Model Precision Type", info="fp32 is more precised but slower, fp16 is faster and less resource consuming but less pricse")
57
+
58
  prompt = gr.Textbox(label="Prompt", interactive=True)
59
 
60
  system_prompt = gr.Textbox(label="System Prompt", interactive=True)
 
71
 
72
  seed = gr.Number(value=42, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
73
 
 
74
  examples = [
75
  [
76
  "A storefront with 'Text to Image' written on it.",