Karzan commited on
Commit
9d0dd09
·
verified ·
1 Parent(s): d0f8d63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -4
app.py CHANGED
@@ -11,15 +11,59 @@ model = GemmaForCausalLM.from_pretrained('google/gemma-2-9b',
11
  # pipe = pipeline('text-generation', model=model,tokenizer = tokenizer)
12
 
13
  @spaces.GPU(duration=120)
14
- def generate(prompt):
 
 
 
 
 
 
 
15
  input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
16
- outputs = model.generate(**input_ids)
17
  return tokenizer.decode(outputs[0]);
18
  # return pipe(prompt)[0]['generated_text']
19
 
20
  gr.Interface(
21
  fn=generate,
22
- inputs=gr.Text(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  outputs="text",
24
- examples=[['Write me a poem about Machine Learning.']]
 
25
  ).launch()
 
11
  # pipe = pipeline('text-generation', model=model,tokenizer = tokenizer)
12
 
13
  @spaces.GPU(duration=120)
14
+ def generate(
15
+ message: str,
16
+ max_new_tokens: int = 1024,
17
+ temperature: float = 0.6,
18
+ top_p: float = 0.9,
19
+ top_k: int = 50,
20
+ repetition_penalty: float = 1.2,
21
+ ):
22
  input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
23
+ outputs = model.generate(**input_ids,top_p=top_p,max_new_tokens=max_new_tokens,top_k=top_k,repetition_penalty=repetition_penalty,temperature=temperature)
24
  return tokenizer.decode(outputs[0]);
25
  # return pipe(prompt)[0]['generated_text']
26
 
27
  gr.Interface(
28
  fn=generate,
29
+ inputs=[
30
+ gr.Text(),
31
+ gr.Slider(
32
+ label="Max new tokens",
33
+ minimum=1,
34
+ maximum=MAX_MAX_NEW_TOKENS,
35
+ step=1,
36
+ value=DEFAULT_MAX_NEW_TOKENS,
37
+ ),
38
+ gr.Slider(
39
+ label="Temperature",
40
+ minimum=0.1,
41
+ maximum=4.0,
42
+ step=0.1,
43
+ value=0.6,
44
+ ),
45
+ gr.Slider(
46
+ label="Top-p (nucleus sampling)",
47
+ minimum=0.05,
48
+ maximum=1.0,
49
+ step=0.05,
50
+ value=0.9,
51
+ ),
52
+ gr.Slider(
53
+ label="Top-k",
54
+ minimum=1,
55
+ maximum=1000,
56
+ step=1,
57
+ value=50,
58
+ ),
59
+ gr.Slider(
60
+ label="Repetition penalty",
61
+ minimum=1.0,
62
+ maximum=2.0,
63
+ step=0.05,
64
+ value=1.2,
65
+ ),],
66
  outputs="text",
67
+ examples=[['Write me a poem about Machine Learning.']],
68
+
69
  ).launch()