Benjamin Gonzalez committed
Commit 4d07925 · 1 Parent(s): ff04433

fix token length
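Context for the fix: the tokenizer returns a dict-like BatchEncoding, so the old guard's len(inputs) counted the encoding's fields (just input_ids here, i.e. 1), not the prompt's tokens, and the floor on max_length never engaged. The new code measures the real prompt length with inputs.tokens(). A minimal sketch of the difference, assuming a fast tokenizer (tokens() is unavailable on slow ones); the example prompt is arbitrary:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
inputs = tokenizer("Write a detailed analogy.", return_tensors="pt", return_attention_mask=False)

# BatchEncoding is dict-like: len() counts its fields, not tokens.
print(len(inputs))                    # 1 -- only "input_ids" is present
# tokens() exposes the tokenized prompt itself (fast tokenizers only).
print(len(inputs.tokens()))           # the actual prompt token count
# Equivalent measure taken straight off the tensor:
print(inputs["input_ids"].shape[-1])  # same count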

Files changed (1): app.py (+8 −6)
app.py CHANGED
@@ -2,18 +2,21 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import gradio as gr
 
+if torch.cuda.is_available():
+    torch.set_default_device("cuda")
+
 tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
     "microsoft/phi-2",
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    device_map="cuda" if torch.cuda.is_available() else "cpu",
     trust_remote_code=True,
 )
 
 
 def generate(prompt, length):
     inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
-    outputs = model.generate(**inputs, max_length=length if length >= len(inputs) else len(inputs))
+    input_token_len = len(inputs.tokens())
+    outputs = model.generate(**inputs, max_length=length if length >= input_token_len else input_token_len)
     return tokenizer.batch_decode(outputs)[0]
 
 
@@ -24,7 +27,7 @@ demo = gr.Interface(
             label="prompt",
             value="Write a detailed analogy between mathematics and a lighthouse.",
         ),
-        gr.Number(value=100, label="max length", maximum=1000),
+        gr.Number(value=100, label="max length", maximum=500),
     ],
     outputs="text",
     examples=[
@@ -41,12 +44,11 @@ demo = gr.Interface(
             150,
         ],
         [
-            '''```python
-def print_prime(n):
+            '''def print_prime(n):
    """
    Print all primes between 1 and n
    """\n''',
-            125,
+            100,
         ],
     ],
     title="Microsoft Phi-2",
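The clamp itself is easy to sanity-check in isolation; a minimal sketch with made-up token counts (clamp_max_length is a hypothetical helper, not part of the commit):

def clamp_max_length(requested: int, prompt_tokens: int) -> int:
    # Same guard as the commit: generate()'s max_length counts prompt
    # tokens too, so it must never be smaller than the prompt itself.
    return requested if requested >= prompt_tokens else prompt_tokens

assert clamp_max_length(100, 7) == 100  # short prompt: the request stands
assert clamp_max_length(50, 80) == 80   # long prompt: floored at prompt length

An alternative worth noting: generate() also accepts max_new_tokens, which counts only newly generated tokens and avoids this interaction with prompt length entirely.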