Roaoch commited on
Commit
5340fbe
·
1 Parent(s): ee98053

Config Changes

Browse files
Files changed (2) hide show
  1. main.py +1 -2
  2. src/cyberclaasic.py +1 -7
main.py CHANGED
@@ -8,8 +8,7 @@ warnings.simplefilter("ignore", UserWarning)
8
  app = FastAPI()
9
 
10
  text_generator = CyberClassic(
11
- min_length=30,
12
- max_length=50,
13
  startings_path='./startings.csv'
14
  )
15
 
 
8
  app = FastAPI()
9
 
10
  text_generator = CyberClassic(
11
+ max_length=60,
 
12
  startings_path='./startings.csv'
13
  )
14
 
src/cyberclaasic.py CHANGED
@@ -13,12 +13,10 @@ import numpy as np
13
  class CyberClassic(torch.nn.Module):
14
  def __init__(
15
  self,
16
- min_length: int,
17
  max_length: int,
18
  startings_path: str
19
  ) -> None:
20
  super().__init__()
21
- self.min_length = min_length
22
  self.max_length = max_length
23
  self.startings = pd.read_csv(startings_path)
24
 
@@ -26,17 +24,13 @@ class CyberClassic(torch.nn.Module):
26
  self.generator: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained('Roaoch/CyberClassic-Generator')
27
  self.discriminator = DiscriminatorModel.from_pretrained('Roaoch/CyberClassic-Discriminator')
28
 
29
- self.tokenizer.pad_token = self.tokenizer.eos_token
30
  self.generation_config = GenerationConfig(
31
  max_new_tokens=max_length,
32
  num_beams=6,
33
  early_stopping=True,
34
  do_sample=True,
35
- # top_k=60,
36
- # penalty_alpha=0.6,
37
- # top_p=0.95,
38
  eos_token_id=self.tokenizer.eos_token_id,
39
- pad_token=self.tokenizer.pad_token_id
40
  )
41
 
42
  def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
 
13
  class CyberClassic(torch.nn.Module):
14
  def __init__(
15
  self,
 
16
  max_length: int,
17
  startings_path: str
18
  ) -> None:
19
  super().__init__()
 
20
  self.max_length = max_length
21
  self.startings = pd.read_csv(startings_path)
22
 
 
24
  self.generator: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained('Roaoch/CyberClassic-Generator')
25
  self.discriminator = DiscriminatorModel.from_pretrained('Roaoch/CyberClassic-Discriminator')
26
 
 
27
  self.generation_config = GenerationConfig(
28
  max_new_tokens=max_length,
29
  num_beams=6,
30
  early_stopping=True,
31
  do_sample=True,
 
 
 
32
  eos_token_id=self.tokenizer.eos_token_id,
33
+ pad_token_id=self.tokenizer.pad_token_id
34
  )
35
 
36
  def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: