Kumarkishalaya committed on
Commit
0a8b7e2
1 Parent(s): 1062c1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -8
app.py CHANGED
@@ -11,17 +11,34 @@ trained_model.to(device)
11
  untrained_model.to(device)
12
 
13
  def generate(commentary_text, max_length, temperature):
14
- if temperature <= 0:
15
- return "Error: Temperature must be a strictly positive float.", "Error: Temperature must be a strictly positive float."
16
-
17
- # Generate text using the finetuned model
18
- input_ids = trained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
19
- trained_output = trained_model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=True, temperature=temperature)
 
 
 
 
 
 
 
20
  trained_text = trained_tokenizer.decode(trained_output[0], skip_special_tokens=True)
21
 
22
  # Generate text using the base model
23
- input_ids = untrained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
24
- untrained_output = untrained_model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=True,temperature=temperature)
 
 
 
 
 
 
 
 
 
 
25
  untrained_text = untrained_tokenizer.decode(untrained_output[0], skip_special_tokens=True)
26
 
27
  return trained_text, untrained_text
 
11
  untrained_model.to(device)
12
 
13
def generate(commentary_text, max_length, temperature):
    """Generate continuations of *commentary_text* with both models.

    Runs the finetuned model and the untouched base model on the same
    prompt so the two outputs can be compared side by side in the UI.

    Args:
        commentary_text: Prompt string to continue.
        max_length: Maximum total token length passed to ``model.generate``.
        temperature: Sampling temperature; must be strictly positive
            because sampling is enabled (``do_sample=True``).

    Returns:
        A ``(trained_text, untrained_text)`` tuple of decoded strings.
        If *temperature* is not strictly positive, both elements are the
        same human-readable error message instead of generated text.
    """
    # Guard restored from the previous revision: a non-positive temperature
    # makes sampling in model.generate raise, so fail gracefully for the UI.
    if temperature <= 0:
        error_message = "Error: Temperature must be a strictly positive float."
        return error_message, error_message

    # Generate text using the finetuned model
    trained_text = _generate_with(
        trained_model, trained_tokenizer, commentary_text, max_length, temperature
    )

    # Generate text using the base model
    untrained_text = _generate_with(
        untrained_model, untrained_tokenizer, commentary_text, max_length, temperature
    )

    return trained_text, untrained_text


def _generate_with(model, tokenizer, prompt, max_length, temperature):
    """Tokenize *prompt*, sample a continuation from *model*, and decode it."""
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=5,
        do_sample=True,
        temperature=temperature,
        attention_mask=attention_mask,
        # Use EOS as padding so generate does not warn when the tokenizer
        # has no dedicated pad token (common for GPT-style models).
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)