Spaces:
Sleeping
Sleeping
Roaoch
commited on
Commit
·
5340fbe
1
Parent(s):
ee98053
Config Changes
Browse files- main.py +1 -2
- src/cyberclaasic.py +1 -7
main.py
CHANGED
@@ -8,8 +8,7 @@ warnings.simplefilter("ignore", UserWarning)
|
|
8 |
app = FastAPI()
|
9 |
|
10 |
text_generator = CyberClassic(
|
11 |
-
|
12 |
-
max_length=50,
|
13 |
startings_path='./startings.csv'
|
14 |
)
|
15 |
|
|
|
8 |
app = FastAPI()
|
9 |
|
10 |
text_generator = CyberClassic(
|
11 |
+
max_length=60,
|
|
|
12 |
startings_path='./startings.csv'
|
13 |
)
|
14 |
|
src/cyberclaasic.py
CHANGED
@@ -13,12 +13,10 @@ import numpy as np
|
|
13 |
class CyberClassic(torch.nn.Module):
|
14 |
def __init__(
|
15 |
self,
|
16 |
-
min_length: int,
|
17 |
max_length: int,
|
18 |
startings_path: str
|
19 |
) -> None:
|
20 |
super().__init__()
|
21 |
-
self.min_length = min_length
|
22 |
self.max_length = max_length
|
23 |
self.startings = pd.read_csv(startings_path)
|
24 |
|
@@ -26,17 +24,13 @@ class CyberClassic(torch.nn.Module):
|
|
26 |
self.generator: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained('Roaoch/CyberClassic-Generator')
|
27 |
self.discriminator = DiscriminatorModel.from_pretrained('Roaoch/CyberClassic-Discriminator')
|
28 |
|
29 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
30 |
self.generation_config = GenerationConfig(
|
31 |
max_new_tokens=max_length,
|
32 |
num_beams=6,
|
33 |
early_stopping=True,
|
34 |
do_sample=True,
|
35 |
-
# top_k=60,
|
36 |
-
# penalty_alpha=0.6,
|
37 |
-
# top_p=0.95,
|
38 |
eos_token_id=self.tokenizer.eos_token_id,
|
39 |
-
|
40 |
)
|
41 |
|
42 |
def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
|
13 |
class CyberClassic(torch.nn.Module):
|
14 |
def __init__(
|
15 |
self,
|
|
|
16 |
max_length: int,
|
17 |
startings_path: str
|
18 |
) -> None:
|
19 |
super().__init__()
|
|
|
20 |
self.max_length = max_length
|
21 |
self.startings = pd.read_csv(startings_path)
|
22 |
|
|
|
24 |
self.generator: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained('Roaoch/CyberClassic-Generator')
|
25 |
self.discriminator = DiscriminatorModel.from_pretrained('Roaoch/CyberClassic-Discriminator')
|
26 |
|
|
|
27 |
self.generation_config = GenerationConfig(
|
28 |
max_new_tokens=max_length,
|
29 |
num_beams=6,
|
30 |
early_stopping=True,
|
31 |
do_sample=True,
|
|
|
|
|
|
|
32 |
eos_token_id=self.tokenizer.eos_token_id,
|
33 |
+
pad_token_id=self.tokenizer.pad_token_id
|
34 |
)
|
35 |
|
36 |
def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|