Spaces:
Sleeping
Sleeping
NotShrirang
commited on
Commit
·
f4e648b
1
Parent(s):
1510533
feat: add application file
Browse files- .gitignore +162 -0
- README.md +191 -14
- app.py +100 -0
- config/config.json +287 -0
- config/example-config.json +12 -0
- config/harpoon_config.json +287 -0
- config/shakespearean_config.json +147 -0
- core/__init__.py +0 -0
- core/layers/__init__.py +1 -0
- core/layers/layers.py +207 -0
- core/models/__init__.py +1 -0
- core/models/gpt.py +64 -0
- core/models/llama.py +51 -0
- core/tokenizers/__init__.py +1 -0
- core/tokenizers/tokenizer.py +115 -0
- core/utils/__init__.py +1 -0
- core/utils/gptutils.py +71 -0
- core/utils/preprocessing.py +75 -0
- requirements.txt +96 -0
.gitignore
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
weights/
|
7 |
+
|
8 |
+
# C extensions
|
9 |
+
*.so
|
10 |
+
|
11 |
+
# Distribution / packaging
|
12 |
+
.Python
|
13 |
+
build/
|
14 |
+
develop-eggs/
|
15 |
+
dist/
|
16 |
+
downloads/
|
17 |
+
eggs/
|
18 |
+
.eggs/
|
19 |
+
lib/
|
20 |
+
lib64/
|
21 |
+
parts/
|
22 |
+
sdist/
|
23 |
+
var/
|
24 |
+
wheels/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.nox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
cover/
|
55 |
+
|
56 |
+
# Translations
|
57 |
+
*.mo
|
58 |
+
*.pot
|
59 |
+
|
60 |
+
# Django stuff:
|
61 |
+
*.log
|
62 |
+
local_settings.py
|
63 |
+
db.sqlite3
|
64 |
+
db.sqlite3-journal
|
65 |
+
|
66 |
+
# Flask stuff:
|
67 |
+
instance/
|
68 |
+
.webassets-cache
|
69 |
+
|
70 |
+
# Scrapy stuff:
|
71 |
+
.scrapy
|
72 |
+
|
73 |
+
# Sphinx documentation
|
74 |
+
docs/_build/
|
75 |
+
|
76 |
+
# PyBuilder
|
77 |
+
.pybuilder/
|
78 |
+
target/
|
79 |
+
|
80 |
+
# Jupyter Notebook
|
81 |
+
.ipynb_checkpoints
|
82 |
+
|
83 |
+
# IPython
|
84 |
+
profile_default/
|
85 |
+
ipython_config.py
|
86 |
+
|
87 |
+
# pyenv
|
88 |
+
# For a library or package, you might want to ignore these files since the code is
|
89 |
+
# intended to run in multiple environments; otherwise, check them in:
|
90 |
+
# .python-version
|
91 |
+
|
92 |
+
# pipenv
|
93 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
94 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
95 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
96 |
+
# install all needed dependencies.
|
97 |
+
#Pipfile.lock
|
98 |
+
|
99 |
+
# poetry
|
100 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
101 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
102 |
+
# commonly ignored for libraries.
|
103 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
104 |
+
#poetry.lock
|
105 |
+
|
106 |
+
# pdm
|
107 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
108 |
+
#pdm.lock
|
109 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
110 |
+
# in version control.
|
111 |
+
# https://pdm.fming.dev/#use-with-ide
|
112 |
+
.pdm.toml
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
README.md
CHANGED
@@ -1,14 +1,191 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
license
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
![QuillGPT-cropped-removebg-preview](https://github.com/NotShrirang/QuillGPT/assets/85283622/2e63d8ce-24f8-4bf0-835a-0c621f1d7400)
|
2 |
+
|
3 |
+
# QuillGPT
|
4 |
+
|
5 |
+
![GitHub stars](https://img.shields.io/github/stars/NotShrirang/GPT-From-Scratch?style=social)
|
6 |
+
![GitHub forks](https://img.shields.io/github/forks/NotShrirang/GPT-From-Scratch?style=social)
|
7 |
+
![GitHub commits](https://img.shields.io/github/commit-activity/t/NotShrirang/QuillGPT)
|
8 |
+
![GitHub issues](https://img.shields.io/github/issues/NotShrirang/GPT-From-Scratch)
|
9 |
+
![GitHub pull requests](https://img.shields.io/github/issues-pr/NotShrirang/GPT-From-Scratch)
|
10 |
+
![GitHub](https://img.shields.io/github/license/NotShrirang/GPT-From-Scratch)
|
11 |
+
![GitHub last commit](https://img.shields.io/github/last-commit/NotShrirang/GPT-From-Scratch)
|
12 |
+
![GitHub repo size](https://img.shields.io/github/repo-size/NotShrirang/GPT-From-Scratch)
|
13 |
+
![Streamlit Playground](https://img.shields.io/badge/Streamlit%20App-red?style=flat-rounded-square&logo=streamlit&labelColor=white)
|
14 |
+
![Docker Container](https://img.shields.io/badge/docker-blue?style=flat-rounded-square&logo=docker&labelColor=white)
|
15 |
+
|
16 |
+
QuillGPT is an implementation of the GPT decoder block based on the architecture from [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper by Vaswani et. al. implemented in PyTorch. Additionally, this repository contains two pre-trained models—Shakespearean GPT and Harpoon GPT—along with their trained weights. For ease of experimentation and deployment, a Streamlit Playground is provided for interactive exploration of these models and FastAPI microservice implemented with Docker containerization for scalable deployment. You'll also find Python scripts for training new GPT models and performing inference on them, along with notebooks showcasing trained models. To facilitate text encoding and decoding, a simple tokenizer is implemented. Explore QuillGPT to utilize these tools and enhance your natural language processing projects!
|
17 |
+
|
18 |
+
## Table of Contents
|
19 |
+
|
20 |
+
- [Models](#models)
|
21 |
+
- [Getting Started](#getting-started)
|
22 |
+
- [Installation](#installation)
|
23 |
+
- [Streamlit Playground](#streamlit-playground)
|
24 |
+
- [FastAPI Microservice](#for-running-fastapi-microservice)
|
25 |
+
- [Running Docker Container](#for-using-containerized-version)
|
26 |
+
- [Usage](#usage)
|
27 |
+
- [Training the GPT Model](#training-the-gpt-model)
|
28 |
+
- [Using the Trained Model for Inference](#for-inference)
|
29 |
+
- [Explanation](#explanation)
|
30 |
+
- [Decoder Block](#the-decoder-block)
|
31 |
+
- [Input Embeddings](#input-embeddings)
|
32 |
+
- [Positional Embeddings](#positional-embeddings)
|
33 |
+
- [Self-Attention](#self-attention)
|
34 |
+
- [License](#license)
|
35 |
+
- [Contributing](#contributing)
|
36 |
+
- [Support](#support)
|
37 |
+
|
38 |
+
## <div align="center">Models</div>
|
39 |
+
|
40 |
+
There are two pre-trained models and weights included in this repository.
|
41 |
+
|
42 |
+
| Feature | Shakespearean GPT | Harpoon GPT |
|
43 |
+
| ------------------------------ | --------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
|
44 |
+
| **Parameters** | 10.7 M | 226 M |
|
45 |
+
| **Weights** | [Weights](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/weights/GPT_model_char.pt) | [Weights](https://www.dropbox.com/scl/fi/vi5z3s17otn0jf7sr40po/Harpoon_Corpus_GPT_model.pt?rlkey=r7oppeslusv736fzmi908le95&st=wak0uf2t&dl=0) |
|
46 |
+
| **Model Config** | [Config](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/config/shakespearean_config.json) | [Config](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/config/harpoon_config.json) |
|
47 |
+
| **Training Data** | Text from Shakespearean plays ([input.txt](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/data/input.txt)) | Random text from books ([corpus.txt](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/data/corpus.txt)) |
|
48 |
+
| **Embedding Type** | Character embeddings | Character embeddings |
|
49 |
+
| **Training Notebook** | [Notebook](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/notebooks/GPT_From_Scratch_CharEmbeddings.ipynb) | [Notebook](https://github.com/NotShrirang/GPT-From-Scratch/blob/main/notebooks/GPT_From_Scratch_with_1024_char_embd.ipynb) |
|
50 |
+
| **Hardware** | NVIDIA T4 | NVIDIA A100 |
|
51 |
+
| **Training & Validation Loss** | ![loss](https://github.com/user-attachments/assets/df89c1f6-d89a-4a3a-8340-edcf7416878c) | ![loss](https://github.com/user-attachments/assets/76c5e0d1-a53c-4d0d-ac8f-5529ec3a5008) |
|
52 |
+
|
53 |
+
## Getting Started:
|
54 |
+
|
55 |
+
### Installation:
|
56 |
+
|
57 |
+
To run the training and inference scripts, follow these steps:
|
58 |
+
|
59 |
+
1. Clone the repository:
|
60 |
+
|
61 |
+
```sh
|
62 |
+
git clone https://github.com/NotShrirang/GPT-From-Scratch.git
|
63 |
+
cd GPT-From-Scratch
|
64 |
+
```
|
65 |
+
|
66 |
+
2. Install the required packages:
|
67 |
+
|
68 |
+
```sh
|
69 |
+
pip install -r requirements.txt
|
70 |
+
```
|
71 |
+
|
72 |
+
Make sure you download the weights for Harpoon GPT from [here](https://www.dropbox.com/scl/fi/vi5z3s17otn0jf7sr40po/Harpoon_Corpus_GPT_model.pt?rlkey=r7oppeslusv736fzmi908le95&st=wak0uf2t&dl=0) before proceeding!
|
73 |
+
|
74 |
+
### Streamlit Playground:
|
75 |
+
|
76 |
+
It is hosted on Streamlit Cloud Service. You can visit it through the link [here](https://quillgpt.streamlit.app/).
|
77 |
+
|
78 |
+
[![Streamlit Demo](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/fa888670-2c44-4f97-a07d-c58473d847d0)](https://quillgpt.streamlit.app/)
|
79 |
+
|
80 |
+
```sh
|
81 |
+
streamlit run app.py
|
82 |
+
```
|
83 |
+
|
84 |
+
### For running FastAPI Microservice:
|
85 |
+
|
86 |
+
```sh
|
87 |
+
python main.py
|
88 |
+
```
|
89 |
+
|
90 |
+
### For using Containerized Version:
|
91 |
+
|
92 |
+
#### Build and Run the Docker Container with bash:
|
93 |
+
|
94 |
+
```sh
|
95 |
+
./run.sh start-dev
|
96 |
+
```
|
97 |
+
|
98 |
+
#### To stop the Docker Container, run the following command:
|
99 |
+
|
100 |
+
```sh
|
101 |
+
./run.sh stop-dev
|
102 |
+
```
|
103 |
+
|
104 |
+
## Usage
|
105 |
+
|
106 |
+
### Training the GPT Model:
|
107 |
+
|
108 |
+
To train the GPT model, follow these steps:
|
109 |
+
|
110 |
+
1. Prepare data. Put the whole text data into single .txt file and save it.
|
111 |
+
2. Write the configurations for transformer and save the file.
|
112 |
+
<br>For example:
|
113 |
+
`json
|
114 |
+
{
|
115 |
+
"data_path": "data/corpus.txt",
|
116 |
+
"vocab_size": 135,
|
117 |
+
"batch_size": 32,
|
118 |
+
"block_size": 256,
|
119 |
+
"max_iters": 3000,
|
120 |
+
"eval_interval": 300,
|
121 |
+
"learning_rate": 3e-5,
|
122 |
+
"eval_iters": 50,
|
123 |
+
"n_embd": 1024,
|
124 |
+
"n_head": 12,
|
125 |
+
"n_layer": 18,
|
126 |
+
"dropout": 0.3,
|
127 |
+
}
|
128 |
+
`
|
129 |
+
|
130 |
+
3. Train model using script `scripts/train_gpt.py`
|
131 |
+
|
132 |
+
```bash
|
133 |
+
python scripts/train_gpt.py \
|
134 |
+
--config_path config/config.json \
|
135 |
+
--data_path data/corpus.txt \
|
136 |
+
--output_dir trained_models
|
137 |
+
```
|
138 |
+
|
139 |
+
(You can change the `config_path`, `data_path` and `output_dir` as per your requirements.)
|
140 |
+
|
141 |
+
4. The trained model will be saved in the `output_dir` specified in the command.
|
142 |
+
|
143 |
+
### For Inference:
|
144 |
+
|
145 |
+
After training, you can use the trained GPT model for text generation. Here's an example of using the trained model for inference:
|
146 |
+
|
147 |
+
```bash
|
148 |
+
python scripts/inference_gpt.py \
|
149 |
+
--config_path config/shakespearean_config.json \
|
150 |
+
--weights_path weights/GPT_model_char.pt \
|
151 |
+
--max_length 500 \
|
152 |
+
--prompt "Once upon a time"
|
153 |
+
```
|
154 |
+
|
155 |
+
## <div align="center">Explanation</div>
|
156 |
+
|
157 |
+
### The Decoder Block:
|
158 |
+
|
159 |
+
<div align="center"><img src="https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/397049a3-10cc-49b5-8696-f19806b2668e" width=350 alt="Decoder Architecture"/></div>
|
160 |
+
|
161 |
+
The decoder block is a crucial component of the GPT (Generative Pre-trained Transformer) model, it is where GPT actually generates the text. It leverages the self-attention mechanism to process input sequences and generate coherent outputs. Each decoder block consists of multiple layers, including self-attention layers, feed-forward neural networks, and layer normalization. The self-attention layers allow the model to weigh the importance of different words in a sequence, capturing context and dependencies regardless of their positions. This enables the GPT model to generate contextually relevant text.
|
162 |
+
|
163 |
+
### Input Embeddings:
|
164 |
+
|
165 |
+
<div align="center">![vector embeddings](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/29b4c375-c9f0-47b9-9d34-2a21dfdf0be8)</div>
|
166 |
+
|
167 |
+
Input embeddings play a crucial role in transformer-based models like GPT by transforming input tokens into meaningful numerical representations. These embeddings serve as the initial input for the model, capturing semantic information about the words in the sequence. The process involves mapping each token in the input sequence to a high-dimensional vector space, where similar tokens are positioned closer together. This enables the model to understand the relationships between different words and effectively learn from the input data. The input embeddings are then fed into the subsequent layers of the model for further processing.
|
168 |
+
|
169 |
+
### Positional Embeddings:
|
170 |
+
|
171 |
+
![positional_encoding](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/90293fb0-8f20-4dc0-adba-8c31a54ef4f4)
|
172 |
+
|
173 |
+
In addition to input embeddings, positional embeddings are another vital component of transformer architectures such as GPT. Since transformers lack inherent information about the order of tokens in a sequence, positional embeddings are introduced to provide the model with positional information. These embeddings encode the position of each token within the sequence, allowing the model to distinguish between tokens based on their positions. By incorporating positional embeddings, transformers like GPT can effectively capture the sequential nature of data and generate coherent outputs that maintain the correct order of words in the generated text.
|
174 |
+
|
175 |
+
### Self-Attention:
|
176 |
+
|
177 |
+
![self attention](https://github.com/NotShrirang/GPT-From-Scratch/assets/85283622/a6d785e4-ab00-4da0-a072-791f680d2bb8)
|
178 |
+
|
179 |
+
Self-attention, a fundamental mechanism in transformer-based models like GPT, operates by assigning importance scores to different words in a sequence. This process involves three key steps: calculating attention scores, applying softmax to obtain attention weights, and finally combining these weights with the input embeddings to generate contextually informed representations. At its core, self-attention allows the model to focus more on relevant words while de-emphasizing less important ones, facilitating effective learning of contextual dependencies within the input data. This mechanism is pivotal in capturing long-range dependencies and contextual nuances, enabling transformer models to generate long sequences of text.
|
180 |
+
|
181 |
+
## License
|
182 |
+
|
183 |
+
MIT © [Shrirang Mahajan](https://github.com/NotShrirang)
|
184 |
+
|
185 |
+
## Contributing
|
186 |
+
|
187 |
+
Feel free to submit pull requests, create issues, or spread the word!
|
188 |
+
|
189 |
+
## Support
|
190 |
+
|
191 |
+
Support me by simply starring this repository! ⭐
|
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import streamlit as st
|
3 |
+
from colorama import Fore
|
4 |
+
from core.models.gpt import GPTLanguageModel
|
5 |
+
from core.tokenizers.tokenizer import Tokenizer
|
6 |
+
from core.utils.gptutils import hyperparameters, load_data
|
7 |
+
|
8 |
+
st.set_page_config(layout='wide',
|
9 |
+
page_title='QuillGPT',
|
10 |
+
page_icon='🪶',
|
11 |
+
initial_sidebar_state='expanded'
|
12 |
+
)
|
13 |
+
|
14 |
+
def decode_text(input, model: GPTLanguageModel, max_tokens, temperature):
|
15 |
+
for idx in model.generate(idx=input, max_new_tokens=max_tokens, max_seq_length=50, temperature=temperature):
|
16 |
+
text = tokenizer.decode(idx[0].tolist())[-1]
|
17 |
+
yield text
|
18 |
+
|
19 |
+
models = {
|
20 |
+
"Shakespearean GPT": './weights/GPT_model_char.pt',
|
21 |
+
"GPT": './weights/Harpoon_Corpus_GPT_model_word2.pt',
|
22 |
+
}
|
23 |
+
|
24 |
+
st.sidebar.header('QuillGPT')
|
25 |
+
|
26 |
+
st.sidebar.write("This app generates text using a GPT model trained on either the Harpoon corpus or Shakespearean plays.")
|
27 |
+
|
28 |
+
# Select one of the two model
|
29 |
+
model_name = st.sidebar.selectbox('Select a model:', list(models.keys()))
|
30 |
+
if model_name == "GPT":
|
31 |
+
st.title('GPT From Scratch')
|
32 |
+
st.write("This model was trained on the Harpoon corpus.")
|
33 |
+
else:
|
34 |
+
st.title('Shakespearean GPT')
|
35 |
+
st.write("This model was trained on Shakespearean plays.")
|
36 |
+
|
37 |
+
path = models[model_name]
|
38 |
+
|
39 |
+
if model_name == "GPT":
|
40 |
+
config_path = './config/harpoon_config.json'
|
41 |
+
data_path = './data/corpus.txt'
|
42 |
+
name = "Harpoon GPT"
|
43 |
+
tokenizer: Tokenizer = Tokenizer()
|
44 |
+
tokenizer.from_pretrained(config_path)
|
45 |
+
vocab_size = tokenizer.vocab_size
|
46 |
+
(batch_size, block_size, max_iters, eval_interval, learning_rate, device,
|
47 |
+
eval_iters, n_embd, n_head, n_layer, dropout) = hyperparameters(config_path=config_path)
|
48 |
+
|
49 |
+
elif model_name == "Shakespearean GPT":
|
50 |
+
config_path = './config/shakespearean_config.json'
|
51 |
+
data_path = './data/input.txt'
|
52 |
+
name = "Shakespearean GPT"
|
53 |
+
tokenizer: Tokenizer = Tokenizer()
|
54 |
+
tokenizer.from_pretrained(config_path)
|
55 |
+
vocab_size = tokenizer.vocab_size
|
56 |
+
(batch_size, block_size, max_iters, eval_interval, learning_rate, device,
|
57 |
+
eval_iters, n_embd, n_head, n_layer, dropout) = hyperparameters(config_path=config_path)
|
58 |
+
|
59 |
+
|
60 |
+
if model_name == "GPT":
|
61 |
+
input_text = st.text_area(
|
62 |
+
'Enter a prompt:', 'And then Ted said, "'
|
63 |
+
)
|
64 |
+
else:
|
65 |
+
input_text = st.text_area(
|
66 |
+
'Enter a prompt:', 'Write a scene about ROMEO arguing with JULIET. \nROMEO:'
|
67 |
+
)
|
68 |
+
|
69 |
+
temperature = st.sidebar.slider('Temperature:', 0.1, 1.0, 0.5, 0.1)
|
70 |
+
max_tokens = st.sidebar.slider('Max Tokens:', 250, 1000, 500, 50)
|
71 |
+
|
72 |
+
@st.cache_resource
|
73 |
+
def load_model(path):
|
74 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
75 |
+
|
76 |
+
try:
|
77 |
+
model = GPTLanguageModel(
|
78 |
+
vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name=name
|
79 |
+
).to(device)
|
80 |
+
state_dict = torch.load(
|
81 |
+
path, map_location=device)
|
82 |
+
|
83 |
+
model.load_state_dict(state_dict)
|
84 |
+
return model, device
|
85 |
+
except FileNotFoundError as e:
|
86 |
+
st.error(f"Don't forget to download the model weights from the link in the README.md file.")
|
87 |
+
return None, None
|
88 |
+
|
89 |
+
|
90 |
+
model, device = load_model(path)
|
91 |
+
|
92 |
+
|
93 |
+
if model:
|
94 |
+
if st.button('Generate Text'):
|
95 |
+
prompt = input_text
|
96 |
+
st.subheader(model.name)
|
97 |
+
input = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
|
98 |
+
generated_text = []
|
99 |
+
st.write(f":green[{prompt}]")
|
100 |
+
st.write_stream(decode_text(input, model, max_tokens, temperature))
|
config/config.json
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocab_size": 135,
|
3 |
+
"batch_size": 32,
|
4 |
+
"block_size": 256,
|
5 |
+
"max_iters": 3000,
|
6 |
+
"eval_interval": 300,
|
7 |
+
"learning_rate": 3e-5,
|
8 |
+
"eval_iters": 50,
|
9 |
+
"n_embd": 1024,
|
10 |
+
"n_head": 8,
|
11 |
+
"n_layer": 8,
|
12 |
+
"dropout": 0.3,
|
13 |
+
"encode": {
|
14 |
+
"\n": 0,
|
15 |
+
" ": 1,
|
16 |
+
"!": 2,
|
17 |
+
"\"": 3,
|
18 |
+
"#": 4,
|
19 |
+
"$": 5,
|
20 |
+
"%": 6,
|
21 |
+
"&": 7,
|
22 |
+
"'": 8,
|
23 |
+
"(": 9,
|
24 |
+
")": 10,
|
25 |
+
"*": 11,
|
26 |
+
"+": 12,
|
27 |
+
",": 13,
|
28 |
+
"-": 14,
|
29 |
+
".": 15,
|
30 |
+
"/": 16,
|
31 |
+
"0": 17,
|
32 |
+
"1": 18,
|
33 |
+
"2": 19,
|
34 |
+
"3": 20,
|
35 |
+
"4": 21,
|
36 |
+
"5": 22,
|
37 |
+
"6": 23,
|
38 |
+
"7": 24,
|
39 |
+
"8": 25,
|
40 |
+
"9": 26,
|
41 |
+
":": 27,
|
42 |
+
";": 28,
|
43 |
+
"<": 29,
|
44 |
+
"=": 30,
|
45 |
+
">": 31,
|
46 |
+
"?": 32,
|
47 |
+
"@": 33,
|
48 |
+
"A": 34,
|
49 |
+
"B": 35,
|
50 |
+
"C": 36,
|
51 |
+
"D": 37,
|
52 |
+
"E": 38,
|
53 |
+
"F": 39,
|
54 |
+
"G": 40,
|
55 |
+
"H": 41,
|
56 |
+
"I": 42,
|
57 |
+
"J": 43,
|
58 |
+
"K": 44,
|
59 |
+
"L": 45,
|
60 |
+
"M": 46,
|
61 |
+
"N": 47,
|
62 |
+
"O": 48,
|
63 |
+
"P": 49,
|
64 |
+
"Q": 50,
|
65 |
+
"R": 51,
|
66 |
+
"S": 52,
|
67 |
+
"T": 53,
|
68 |
+
"U": 54,
|
69 |
+
"V": 55,
|
70 |
+
"W": 56,
|
71 |
+
"X": 57,
|
72 |
+
"Y": 58,
|
73 |
+
"Z": 59,
|
74 |
+
"[": 60,
|
75 |
+
"\\": 61,
|
76 |
+
"]": 62,
|
77 |
+
"^": 63,
|
78 |
+
"_": 64,
|
79 |
+
"`": 65,
|
80 |
+
"a": 66,
|
81 |
+
"b": 67,
|
82 |
+
"c": 68,
|
83 |
+
"d": 69,
|
84 |
+
"e": 70,
|
85 |
+
"f": 71,
|
86 |
+
"g": 72,
|
87 |
+
"h": 73,
|
88 |
+
"i": 74,
|
89 |
+
"j": 75,
|
90 |
+
"k": 76,
|
91 |
+
"l": 77,
|
92 |
+
"m": 78,
|
93 |
+
"n": 79,
|
94 |
+
"o": 80,
|
95 |
+
"p": 81,
|
96 |
+
"q": 82,
|
97 |
+
"r": 83,
|
98 |
+
"s": 84,
|
99 |
+
"t": 85,
|
100 |
+
"u": 86,
|
101 |
+
"v": 87,
|
102 |
+
"w": 88,
|
103 |
+
"x": 89,
|
104 |
+
"y": 90,
|
105 |
+
"z": 91,
|
106 |
+
"{": 92,
|
107 |
+
"|": 93,
|
108 |
+
"}": 94,
|
109 |
+
"\u00a0": 95,
|
110 |
+
"\u00a3": 96,
|
111 |
+
"\u00b0": 97,
|
112 |
+
"\u00b2": 98,
|
113 |
+
"\u00b3": 99,
|
114 |
+
"\u00bc": 100,
|
115 |
+
"\u00bd": 101,
|
116 |
+
"\u00be": 102,
|
117 |
+
"\u00c6": 103,
|
118 |
+
"\u00c7": 104,
|
119 |
+
"\u00c8": 105,
|
120 |
+
"\u00c9": 106,
|
121 |
+
"\u00d7": 107,
|
122 |
+
"\u00dc": 108,
|
123 |
+
"\u00e0": 109,
|
124 |
+
"\u00e1": 110,
|
125 |
+
"\u00e2": 111,
|
126 |
+
"\u00e6": 112,
|
127 |
+
"\u00e7": 113,
|
128 |
+
"\u00e8": 114,
|
129 |
+
"\u00e9": 115,
|
130 |
+
"\u00ea": 116,
|
131 |
+
"\u00eb": 117,
|
132 |
+
"\u00ee": 118,
|
133 |
+
"\u00ef": 119,
|
134 |
+
"\u00f1": 120,
|
135 |
+
"\u00f2": 121,
|
136 |
+
"\u00f4": 122,
|
137 |
+
"\u00f6": 123,
|
138 |
+
"\u00f7": 124,
|
139 |
+
"\u00f9": 125,
|
140 |
+
"\u00fb": 126,
|
141 |
+
"\u00fc": 127,
|
142 |
+
"\u2013": 128,
|
143 |
+
"\u2014": 129,
|
144 |
+
"\u2018": 130,
|
145 |
+
"\u2019": 131,
|
146 |
+
"\u201c": 132,
|
147 |
+
"\u201d": 133,
|
148 |
+
"\ufeff": 134
|
149 |
+
},
|
150 |
+
"decode": {
|
151 |
+
"0": "\n",
|
152 |
+
"1": " ",
|
153 |
+
"2": "!",
|
154 |
+
"3": "\"",
|
155 |
+
"4": "#",
|
156 |
+
"5": "$",
|
157 |
+
"6": "%",
|
158 |
+
"7": "&",
|
159 |
+
"8": "'",
|
160 |
+
"9": "(",
|
161 |
+
"10": ")",
|
162 |
+
"11": "*",
|
163 |
+
"12": "+",
|
164 |
+
"13": ",",
|
165 |
+
"14": "-",
|
166 |
+
"15": ".",
|
167 |
+
"16": "/",
|
168 |
+
"17": "0",
|
169 |
+
"18": "1",
|
170 |
+
"19": "2",
|
171 |
+
"20": "3",
|
172 |
+
"21": "4",
|
173 |
+
"22": "5",
|
174 |
+
"23": "6",
|
175 |
+
"24": "7",
|
176 |
+
"25": "8",
|
177 |
+
"26": "9",
|
178 |
+
"27": ":",
|
179 |
+
"28": ";",
|
180 |
+
"29": "<",
|
181 |
+
"30": "=",
|
182 |
+
"31": ">",
|
183 |
+
"32": "?",
|
184 |
+
"33": "@",
|
185 |
+
"34": "A",
|
186 |
+
"35": "B",
|
187 |
+
"36": "C",
|
188 |
+
"37": "D",
|
189 |
+
"38": "E",
|
190 |
+
"39": "F",
|
191 |
+
"40": "G",
|
192 |
+
"41": "H",
|
193 |
+
"42": "I",
|
194 |
+
"43": "J",
|
195 |
+
"44": "K",
|
196 |
+
"45": "L",
|
197 |
+
"46": "M",
|
198 |
+
"47": "N",
|
199 |
+
"48": "O",
|
200 |
+
"49": "P",
|
201 |
+
"50": "Q",
|
202 |
+
"51": "R",
|
203 |
+
"52": "S",
|
204 |
+
"53": "T",
|
205 |
+
"54": "U",
|
206 |
+
"55": "V",
|
207 |
+
"56": "W",
|
208 |
+
"57": "X",
|
209 |
+
"58": "Y",
|
210 |
+
"59": "Z",
|
211 |
+
"60": "[",
|
212 |
+
"61": "\\",
|
213 |
+
"62": "]",
|
214 |
+
"63": "^",
|
215 |
+
"64": "_",
|
216 |
+
"65": "`",
|
217 |
+
"66": "a",
|
218 |
+
"67": "b",
|
219 |
+
"68": "c",
|
220 |
+
"69": "d",
|
221 |
+
"70": "e",
|
222 |
+
"71": "f",
|
223 |
+
"72": "g",
|
224 |
+
"73": "h",
|
225 |
+
"74": "i",
|
226 |
+
"75": "j",
|
227 |
+
"76": "k",
|
228 |
+
"77": "l",
|
229 |
+
"78": "m",
|
230 |
+
"79": "n",
|
231 |
+
"80": "o",
|
232 |
+
"81": "p",
|
233 |
+
"82": "q",
|
234 |
+
"83": "r",
|
235 |
+
"84": "s",
|
236 |
+
"85": "t",
|
237 |
+
"86": "u",
|
238 |
+
"87": "v",
|
239 |
+
"88": "w",
|
240 |
+
"89": "x",
|
241 |
+
"90": "y",
|
242 |
+
"91": "z",
|
243 |
+
"92": "{",
|
244 |
+
"93": "|",
|
245 |
+
"94": "}",
|
246 |
+
"95": "\u00a0",
|
247 |
+
"96": "\u00a3",
|
248 |
+
"97": "\u00b0",
|
249 |
+
"98": "\u00b2",
|
250 |
+
"99": "\u00b3",
|
251 |
+
"100": "\u00bc",
|
252 |
+
"101": "\u00bd",
|
253 |
+
"102": "\u00be",
|
254 |
+
"103": "\u00c6",
|
255 |
+
"104": "\u00c7",
|
256 |
+
"105": "\u00c8",
|
257 |
+
"106": "\u00c9",
|
258 |
+
"107": "\u00d7",
|
259 |
+
"108": "\u00dc",
|
260 |
+
"109": "\u00e0",
|
261 |
+
"110": "\u00e1",
|
262 |
+
"111": "\u00e2",
|
263 |
+
"112": "\u00e6",
|
264 |
+
"113": "\u00e7",
|
265 |
+
"114": "\u00e8",
|
266 |
+
"115": "\u00e9",
|
267 |
+
"116": "\u00ea",
|
268 |
+
"117": "\u00eb",
|
269 |
+
"118": "\u00ee",
|
270 |
+
"119": "\u00ef",
|
271 |
+
"120": "\u00f1",
|
272 |
+
"121": "\u00f2",
|
273 |
+
"122": "\u00f4",
|
274 |
+
"123": "\u00f6",
|
275 |
+
"124": "\u00f7",
|
276 |
+
"125": "\u00f9",
|
277 |
+
"126": "\u00fb",
|
278 |
+
"127": "\u00fc",
|
279 |
+
"128": "\u2013",
|
280 |
+
"129": "\u2014",
|
281 |
+
"130": "\u2018",
|
282 |
+
"131": "\u2019",
|
283 |
+
"132": "\u201c",
|
284 |
+
"133": "\u201d",
|
285 |
+
"134": "\ufeff"
|
286 |
+
}
|
287 |
+
}
|
config/example-config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"batch_size": 32,
|
3 |
+
"block_size": 256,
|
4 |
+
"max_iters": 3000,
|
5 |
+
"eval_interval": 300,
|
6 |
+
"learning_rate": 3e-5,
|
7 |
+
"eval_iters": 50,
|
8 |
+
"n_embd": 1024,
|
9 |
+
"n_head": 12,
|
10 |
+
"n_layer": 18,
|
11 |
+
"dropout": 0.3
|
12 |
+
}
|
config/harpoon_config.json
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocab_size": 135,
|
3 |
+
"batch_size": 32,
|
4 |
+
"block_size": 256,
|
5 |
+
"max_iters": 3000,
|
6 |
+
"eval_interval": 300,
|
7 |
+
"learning_rate": 3e-5,
|
8 |
+
"eval_iters": 50,
|
9 |
+
"n_embd": 1024,
|
10 |
+
"n_head": 12,
|
11 |
+
"n_layer": 18,
|
12 |
+
"dropout": 0.3,
|
13 |
+
"encode": {
|
14 |
+
"\n": 0,
|
15 |
+
" ": 1,
|
16 |
+
"!": 2,
|
17 |
+
"\"": 3,
|
18 |
+
"#": 4,
|
19 |
+
"$": 5,
|
20 |
+
"%": 6,
|
21 |
+
"&": 7,
|
22 |
+
"'": 8,
|
23 |
+
"(": 9,
|
24 |
+
")": 10,
|
25 |
+
"*": 11,
|
26 |
+
"+": 12,
|
27 |
+
",": 13,
|
28 |
+
"-": 14,
|
29 |
+
".": 15,
|
30 |
+
"/": 16,
|
31 |
+
"0": 17,
|
32 |
+
"1": 18,
|
33 |
+
"2": 19,
|
34 |
+
"3": 20,
|
35 |
+
"4": 21,
|
36 |
+
"5": 22,
|
37 |
+
"6": 23,
|
38 |
+
"7": 24,
|
39 |
+
"8": 25,
|
40 |
+
"9": 26,
|
41 |
+
":": 27,
|
42 |
+
";": 28,
|
43 |
+
"<": 29,
|
44 |
+
"=": 30,
|
45 |
+
">": 31,
|
46 |
+
"?": 32,
|
47 |
+
"@": 33,
|
48 |
+
"A": 34,
|
49 |
+
"B": 35,
|
50 |
+
"C": 36,
|
51 |
+
"D": 37,
|
52 |
+
"E": 38,
|
53 |
+
"F": 39,
|
54 |
+
"G": 40,
|
55 |
+
"H": 41,
|
56 |
+
"I": 42,
|
57 |
+
"J": 43,
|
58 |
+
"K": 44,
|
59 |
+
"L": 45,
|
60 |
+
"M": 46,
|
61 |
+
"N": 47,
|
62 |
+
"O": 48,
|
63 |
+
"P": 49,
|
64 |
+
"Q": 50,
|
65 |
+
"R": 51,
|
66 |
+
"S": 52,
|
67 |
+
"T": 53,
|
68 |
+
"U": 54,
|
69 |
+
"V": 55,
|
70 |
+
"W": 56,
|
71 |
+
"X": 57,
|
72 |
+
"Y": 58,
|
73 |
+
"Z": 59,
|
74 |
+
"[": 60,
|
75 |
+
"\\": 61,
|
76 |
+
"]": 62,
|
77 |
+
"^": 63,
|
78 |
+
"_": 64,
|
79 |
+
"`": 65,
|
80 |
+
"a": 66,
|
81 |
+
"b": 67,
|
82 |
+
"c": 68,
|
83 |
+
"d": 69,
|
84 |
+
"e": 70,
|
85 |
+
"f": 71,
|
86 |
+
"g": 72,
|
87 |
+
"h": 73,
|
88 |
+
"i": 74,
|
89 |
+
"j": 75,
|
90 |
+
"k": 76,
|
91 |
+
"l": 77,
|
92 |
+
"m": 78,
|
93 |
+
"n": 79,
|
94 |
+
"o": 80,
|
95 |
+
"p": 81,
|
96 |
+
"q": 82,
|
97 |
+
"r": 83,
|
98 |
+
"s": 84,
|
99 |
+
"t": 85,
|
100 |
+
"u": 86,
|
101 |
+
"v": 87,
|
102 |
+
"w": 88,
|
103 |
+
"x": 89,
|
104 |
+
"y": 90,
|
105 |
+
"z": 91,
|
106 |
+
"{": 92,
|
107 |
+
"|": 93,
|
108 |
+
"}": 94,
|
109 |
+
"\u00a0": 95,
|
110 |
+
"\u00a3": 96,
|
111 |
+
"\u00b0": 97,
|
112 |
+
"\u00b2": 98,
|
113 |
+
"\u00b3": 99,
|
114 |
+
"\u00bc": 100,
|
115 |
+
"\u00bd": 101,
|
116 |
+
"\u00be": 102,
|
117 |
+
"\u00c6": 103,
|
118 |
+
"\u00c7": 104,
|
119 |
+
"\u00c8": 105,
|
120 |
+
"\u00c9": 106,
|
121 |
+
"\u00d7": 107,
|
122 |
+
"\u00dc": 108,
|
123 |
+
"\u00e0": 109,
|
124 |
+
"\u00e1": 110,
|
125 |
+
"\u00e2": 111,
|
126 |
+
"\u00e6": 112,
|
127 |
+
"\u00e7": 113,
|
128 |
+
"\u00e8": 114,
|
129 |
+
"\u00e9": 115,
|
130 |
+
"\u00ea": 116,
|
131 |
+
"\u00eb": 117,
|
132 |
+
"\u00ee": 118,
|
133 |
+
"\u00ef": 119,
|
134 |
+
"\u00f1": 120,
|
135 |
+
"\u00f2": 121,
|
136 |
+
"\u00f4": 122,
|
137 |
+
"\u00f6": 123,
|
138 |
+
"\u00f7": 124,
|
139 |
+
"\u00f9": 125,
|
140 |
+
"\u00fb": 126,
|
141 |
+
"\u00fc": 127,
|
142 |
+
"\u2013": 128,
|
143 |
+
"\u2014": 129,
|
144 |
+
"\u2018": 130,
|
145 |
+
"\u2019": 131,
|
146 |
+
"\u201c": 132,
|
147 |
+
"\u201d": 133,
|
148 |
+
"\ufeff": 134
|
149 |
+
},
|
150 |
+
"decode": {
|
151 |
+
"0": "\n",
|
152 |
+
"1": " ",
|
153 |
+
"2": "!",
|
154 |
+
"3": "\"",
|
155 |
+
"4": "#",
|
156 |
+
"5": "$",
|
157 |
+
"6": "%",
|
158 |
+
"7": "&",
|
159 |
+
"8": "'",
|
160 |
+
"9": "(",
|
161 |
+
"10": ")",
|
162 |
+
"11": "*",
|
163 |
+
"12": "+",
|
164 |
+
"13": ",",
|
165 |
+
"14": "-",
|
166 |
+
"15": ".",
|
167 |
+
"16": "/",
|
168 |
+
"17": "0",
|
169 |
+
"18": "1",
|
170 |
+
"19": "2",
|
171 |
+
"20": "3",
|
172 |
+
"21": "4",
|
173 |
+
"22": "5",
|
174 |
+
"23": "6",
|
175 |
+
"24": "7",
|
176 |
+
"25": "8",
|
177 |
+
"26": "9",
|
178 |
+
"27": ":",
|
179 |
+
"28": ";",
|
180 |
+
"29": "<",
|
181 |
+
"30": "=",
|
182 |
+
"31": ">",
|
183 |
+
"32": "?",
|
184 |
+
"33": "@",
|
185 |
+
"34": "A",
|
186 |
+
"35": "B",
|
187 |
+
"36": "C",
|
188 |
+
"37": "D",
|
189 |
+
"38": "E",
|
190 |
+
"39": "F",
|
191 |
+
"40": "G",
|
192 |
+
"41": "H",
|
193 |
+
"42": "I",
|
194 |
+
"43": "J",
|
195 |
+
"44": "K",
|
196 |
+
"45": "L",
|
197 |
+
"46": "M",
|
198 |
+
"47": "N",
|
199 |
+
"48": "O",
|
200 |
+
"49": "P",
|
201 |
+
"50": "Q",
|
202 |
+
"51": "R",
|
203 |
+
"52": "S",
|
204 |
+
"53": "T",
|
205 |
+
"54": "U",
|
206 |
+
"55": "V",
|
207 |
+
"56": "W",
|
208 |
+
"57": "X",
|
209 |
+
"58": "Y",
|
210 |
+
"59": "Z",
|
211 |
+
"60": "[",
|
212 |
+
"61": "\\",
|
213 |
+
"62": "]",
|
214 |
+
"63": "^",
|
215 |
+
"64": "_",
|
216 |
+
"65": "`",
|
217 |
+
"66": "a",
|
218 |
+
"67": "b",
|
219 |
+
"68": "c",
|
220 |
+
"69": "d",
|
221 |
+
"70": "e",
|
222 |
+
"71": "f",
|
223 |
+
"72": "g",
|
224 |
+
"73": "h",
|
225 |
+
"74": "i",
|
226 |
+
"75": "j",
|
227 |
+
"76": "k",
|
228 |
+
"77": "l",
|
229 |
+
"78": "m",
|
230 |
+
"79": "n",
|
231 |
+
"80": "o",
|
232 |
+
"81": "p",
|
233 |
+
"82": "q",
|
234 |
+
"83": "r",
|
235 |
+
"84": "s",
|
236 |
+
"85": "t",
|
237 |
+
"86": "u",
|
238 |
+
"87": "v",
|
239 |
+
"88": "w",
|
240 |
+
"89": "x",
|
241 |
+
"90": "y",
|
242 |
+
"91": "z",
|
243 |
+
"92": "{",
|
244 |
+
"93": "|",
|
245 |
+
"94": "}",
|
246 |
+
"95": "\u00a0",
|
247 |
+
"96": "\u00a3",
|
248 |
+
"97": "\u00b0",
|
249 |
+
"98": "\u00b2",
|
250 |
+
"99": "\u00b3",
|
251 |
+
"100": "\u00bc",
|
252 |
+
"101": "\u00bd",
|
253 |
+
"102": "\u00be",
|
254 |
+
"103": "\u00c6",
|
255 |
+
"104": "\u00c7",
|
256 |
+
"105": "\u00c8",
|
257 |
+
"106": "\u00c9",
|
258 |
+
"107": "\u00d7",
|
259 |
+
"108": "\u00dc",
|
260 |
+
"109": "\u00e0",
|
261 |
+
"110": "\u00e1",
|
262 |
+
"111": "\u00e2",
|
263 |
+
"112": "\u00e6",
|
264 |
+
"113": "\u00e7",
|
265 |
+
"114": "\u00e8",
|
266 |
+
"115": "\u00e9",
|
267 |
+
"116": "\u00ea",
|
268 |
+
"117": "\u00eb",
|
269 |
+
"118": "\u00ee",
|
270 |
+
"119": "\u00ef",
|
271 |
+
"120": "\u00f1",
|
272 |
+
"121": "\u00f2",
|
273 |
+
"122": "\u00f4",
|
274 |
+
"123": "\u00f6",
|
275 |
+
"124": "\u00f7",
|
276 |
+
"125": "\u00f9",
|
277 |
+
"126": "\u00fb",
|
278 |
+
"127": "\u00fc",
|
279 |
+
"128": "\u2013",
|
280 |
+
"129": "\u2014",
|
281 |
+
"130": "\u2018",
|
282 |
+
"131": "\u2019",
|
283 |
+
"132": "\u201c",
|
284 |
+
"133": "\u201d",
|
285 |
+
"134": "\ufeff"
|
286 |
+
}
|
287 |
+
}
|
config/shakespearean_config.json
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocab_size": 65,
|
3 |
+
"batch_size": 32,
|
4 |
+
"block_size": 256,
|
5 |
+
"max_iters": 3000,
|
6 |
+
"eval_interval": 300,
|
7 |
+
"learning_rate": 3e-5,
|
8 |
+
"eval_iters": 50,
|
9 |
+
"n_embd": 384,
|
10 |
+
"n_head": 6,
|
11 |
+
"n_layer": 6,
|
12 |
+
"dropout": 0.3,
|
13 |
+
"encode": {
|
14 |
+
"\n": 0,
|
15 |
+
" ": 1,
|
16 |
+
"!": 2,
|
17 |
+
"$": 3,
|
18 |
+
"&": 4,
|
19 |
+
"'": 5,
|
20 |
+
",": 6,
|
21 |
+
"-": 7,
|
22 |
+
".": 8,
|
23 |
+
"3": 9,
|
24 |
+
":": 10,
|
25 |
+
";": 11,
|
26 |
+
"?": 12,
|
27 |
+
"A": 13,
|
28 |
+
"B": 14,
|
29 |
+
"C": 15,
|
30 |
+
"D": 16,
|
31 |
+
"E": 17,
|
32 |
+
"F": 18,
|
33 |
+
"G": 19,
|
34 |
+
"H": 20,
|
35 |
+
"I": 21,
|
36 |
+
"J": 22,
|
37 |
+
"K": 23,
|
38 |
+
"L": 24,
|
39 |
+
"M": 25,
|
40 |
+
"N": 26,
|
41 |
+
"O": 27,
|
42 |
+
"P": 28,
|
43 |
+
"Q": 29,
|
44 |
+
"R": 30,
|
45 |
+
"S": 31,
|
46 |
+
"T": 32,
|
47 |
+
"U": 33,
|
48 |
+
"V": 34,
|
49 |
+
"W": 35,
|
50 |
+
"X": 36,
|
51 |
+
"Y": 37,
|
52 |
+
"Z": 38,
|
53 |
+
"a": 39,
|
54 |
+
"b": 40,
|
55 |
+
"c": 41,
|
56 |
+
"d": 42,
|
57 |
+
"e": 43,
|
58 |
+
"f": 44,
|
59 |
+
"g": 45,
|
60 |
+
"h": 46,
|
61 |
+
"i": 47,
|
62 |
+
"j": 48,
|
63 |
+
"k": 49,
|
64 |
+
"l": 50,
|
65 |
+
"m": 51,
|
66 |
+
"n": 52,
|
67 |
+
"o": 53,
|
68 |
+
"p": 54,
|
69 |
+
"q": 55,
|
70 |
+
"r": 56,
|
71 |
+
"s": 57,
|
72 |
+
"t": 58,
|
73 |
+
"u": 59,
|
74 |
+
"v": 60,
|
75 |
+
"w": 61,
|
76 |
+
"x": 62,
|
77 |
+
"y": 63,
|
78 |
+
"z": 64
|
79 |
+
},
|
80 |
+
"decode": {
|
81 |
+
"0": "\n",
|
82 |
+
"1": " ",
|
83 |
+
"2": "!",
|
84 |
+
"3": "$",
|
85 |
+
"4": "&",
|
86 |
+
"5": "'",
|
87 |
+
"6": ",",
|
88 |
+
"7": "-",
|
89 |
+
"8": ".",
|
90 |
+
"9": "3",
|
91 |
+
"10": ":",
|
92 |
+
"11": ";",
|
93 |
+
"12": "?",
|
94 |
+
"13": "A",
|
95 |
+
"14": "B",
|
96 |
+
"15": "C",
|
97 |
+
"16": "D",
|
98 |
+
"17": "E",
|
99 |
+
"18": "F",
|
100 |
+
"19": "G",
|
101 |
+
"20": "H",
|
102 |
+
"21": "I",
|
103 |
+
"22": "J",
|
104 |
+
"23": "K",
|
105 |
+
"24": "L",
|
106 |
+
"25": "M",
|
107 |
+
"26": "N",
|
108 |
+
"27": "O",
|
109 |
+
"28": "P",
|
110 |
+
"29": "Q",
|
111 |
+
"30": "R",
|
112 |
+
"31": "S",
|
113 |
+
"32": "T",
|
114 |
+
"33": "U",
|
115 |
+
"34": "V",
|
116 |
+
"35": "W",
|
117 |
+
"36": "X",
|
118 |
+
"37": "Y",
|
119 |
+
"38": "Z",
|
120 |
+
"39": "a",
|
121 |
+
"40": "b",
|
122 |
+
"41": "c",
|
123 |
+
"42": "d",
|
124 |
+
"43": "e",
|
125 |
+
"44": "f",
|
126 |
+
"45": "g",
|
127 |
+
"46": "h",
|
128 |
+
"47": "i",
|
129 |
+
"48": "j",
|
130 |
+
"49": "k",
|
131 |
+
"50": "l",
|
132 |
+
"51": "m",
|
133 |
+
"52": "n",
|
134 |
+
"53": "o",
|
135 |
+
"54": "p",
|
136 |
+
"55": "q",
|
137 |
+
"56": "r",
|
138 |
+
"57": "s",
|
139 |
+
"58": "t",
|
140 |
+
"59": "u",
|
141 |
+
"60": "v",
|
142 |
+
"61": "w",
|
143 |
+
"62": "x",
|
144 |
+
"63": "y",
|
145 |
+
"64": "z"
|
146 |
+
}
|
147 |
+
}
|
core/__init__.py
ADDED
File without changes
|
core/layers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .layers import Block, FeedForward, MultiHeadAttention, Head, RoPE, LlamaBlock, RMSNorm
|
core/layers/layers.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
class Head(nn.Module):
|
8 |
+
"""One head of self-attention."""
|
9 |
+
|
10 |
+
def __init__(self, n_embd, head_size, block_size, dropout):
|
11 |
+
super().__init__()
|
12 |
+
self.key = nn.Linear(n_embd, head_size, bias=False)
|
13 |
+
self.query = nn.Linear(n_embd, head_size, bias=False)
|
14 |
+
self.value = nn.Linear(n_embd, head_size, bias=False)
|
15 |
+
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
|
16 |
+
self.dropout = nn.Dropout(dropout)
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
B, T, C = x.shape
|
20 |
+
k = self.key(x)
|
21 |
+
q = self.query(x)
|
22 |
+
wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
|
23 |
+
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
|
24 |
+
wei = F.softmax(wei, dim=-1)
|
25 |
+
wei = self.dropout(wei)
|
26 |
+
v = self.value(x)
|
27 |
+
out = wei @ v
|
28 |
+
return out
|
29 |
+
|
30 |
+
class MultiHeadAttention(nn.Module):
|
31 |
+
"""Multiple heads of self-attention in parallel."""
|
32 |
+
|
33 |
+
def __init__(self, n_embd, n_head, block_size, dropout):
|
34 |
+
super().__init__()
|
35 |
+
assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
|
36 |
+
|
37 |
+
self.n_embd = n_embd
|
38 |
+
self.n_head = n_head
|
39 |
+
self.head_size = n_embd // n_head
|
40 |
+
|
41 |
+
self.heads = nn.ModuleList([Head(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)])
|
42 |
+
self.proj = nn.Linear(n_embd, n_embd)
|
43 |
+
self.dropout = nn.Dropout(dropout)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
out = torch.cat([h(x) for h in self.heads], dim=-1)
|
47 |
+
out = self.dropout(self.proj(out))
|
48 |
+
return out
|
49 |
+
|
50 |
+
class FeedForward(nn.Module):
|
51 |
+
"""A simple linear layer followed by a non-linearity."""
|
52 |
+
|
53 |
+
def __init__(self, n_embd, dropout):
|
54 |
+
super().__init__()
|
55 |
+
self.net = nn.Sequential(
|
56 |
+
nn.Linear(n_embd, 4 * n_embd),
|
57 |
+
nn.ReLU(),
|
58 |
+
nn.Linear(4 * n_embd, n_embd),
|
59 |
+
nn.Dropout(dropout),
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return self.net(x)
|
64 |
+
|
65 |
+
class Block(nn.Module):
|
66 |
+
"""Transformer block: communication followed by computation."""
|
67 |
+
|
68 |
+
def __init__(self, n_embd, n_head, block_size, dropout):
|
69 |
+
super().__init__()
|
70 |
+
self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
|
71 |
+
self.ffwd = FeedForward(n_embd, dropout)
|
72 |
+
self.ln1 = nn.LayerNorm(n_embd)
|
73 |
+
self.ln2 = nn.LayerNorm(n_embd)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
x = x + self.sa(self.ln1(x))
|
77 |
+
x = x + self.ffwd(self.ln2(x))
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class RoPE(nn.Module):
|
82 |
+
"""Rotary Positional Encoding (RoPE) layer."""
|
83 |
+
|
84 |
+
def __init__(self, embd_dim, max_freq=10):
|
85 |
+
super().__init__()
|
86 |
+
self.embd_dim = embd_dim
|
87 |
+
self.max_freq = max_freq
|
88 |
+
self.freqs = 2 ** torch.linspace(0, max_freq - 1, embd_dim // 2) * torch.pi
|
89 |
+
self.inv_freqs = 1. / self.freqs
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
x = x + torch.sin(x @ self.freqs) * self.inv_freqs
|
93 |
+
x = x + torch.cos(x @ self.freqs) * self.inv_freqs
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
class RMSNorm(nn.Module):
|
98 |
+
"""Root Mean Square Layer Normalization (RMSNorm)."""
|
99 |
+
|
100 |
+
def __init__(self, embd_dim, epsilon=1e-8):
|
101 |
+
super().__init__()
|
102 |
+
self.embd_dim = embd_dim
|
103 |
+
self.epsilon = epsilon
|
104 |
+
self.gamma = nn.Parameter(torch.ones(embd_dim))
|
105 |
+
self.beta = nn.Parameter(torch.zeros(embd_dim))
|
106 |
+
|
107 |
+
def forward(self, x: torch.Tensor):
|
108 |
+
mean = x.mean(-1, keepdim=True)
|
109 |
+
variance = x.var(-1, keepdim=True)
|
110 |
+
x = x - mean
|
111 |
+
x = x / torch.sqrt(variance + self.epsilon)
|
112 |
+
x = x * self.gamma + self.beta
|
113 |
+
return x
|
114 |
+
|
115 |
+
|
116 |
+
class LlamaFFN(nn.Module):
|
117 |
+
"""Feed-forward network of the LLAMA model with SwiGLU activation."""
|
118 |
+
|
119 |
+
def __init__(self, n_embd, dropout):
|
120 |
+
super().__init__()
|
121 |
+
self.linear1 = nn.Linear(n_embd, 4 * n_embd)
|
122 |
+
self.linear2 = nn.Linear(4 * n_embd, n_embd)
|
123 |
+
self.dropout = nn.Dropout(dropout)
|
124 |
+
|
125 |
+
def swiglu(self, x):
|
126 |
+
"""Applies SwiGLU activation."""
|
127 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
128 |
+
return x1 * F.silu(x2)
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
x = self.linear1(x)
|
132 |
+
x = self.swiglu(x)
|
133 |
+
x = self.dropout(x)
|
134 |
+
x = self.linear2(x)
|
135 |
+
return x
|
136 |
+
|
137 |
+
|
138 |
+
class AttentionHeadWithKVCacheAndRoPE(nn.Module):
|
139 |
+
"""One head of self-attention with key and value cache and RoPE."""
|
140 |
+
|
141 |
+
def __init__(self, n_embd, head_size, block_size, dropout):
|
142 |
+
super().__init__()
|
143 |
+
self.key = nn.Linear(n_embd, head_size, bias=False)
|
144 |
+
self.query = nn.Linear(n_embd, head_size, bias=False)
|
145 |
+
self.value = nn.Linear(n_embd, head_size, bias=False)
|
146 |
+
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
|
147 |
+
self.dropout = nn.Dropout(dropout)
|
148 |
+
self.pe = RoPE(head_size)
|
149 |
+
self.ln = RMSNorm(n_embd)
|
150 |
+
|
151 |
+
def forward(self, x, kv_cache):
|
152 |
+
B, T, C = x.shape
|
153 |
+
k = self.key(x)
|
154 |
+
q = self.query(x)
|
155 |
+
v = self.value(x)
|
156 |
+
if kv_cache is not None:
|
157 |
+
k = torch.cat([kv_cache['k'], k], dim=1)
|
158 |
+
v = torch.cat([kv_cache['v'], v], dim=1)
|
159 |
+
wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
|
160 |
+
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
|
161 |
+
wei = F.softmax(wei, dim=-1)
|
162 |
+
wei = self.dropout(wei)
|
163 |
+
out = wei @ v
|
164 |
+
if kv_cache is None:
|
165 |
+
kv_cache = {'k': k, 'q': q, 'v': v}
|
166 |
+
else:
|
167 |
+
kv_cache['k'] = k
|
168 |
+
kv_cache['q'] = q
|
169 |
+
kv_cache['v'] = v
|
170 |
+
return self.pe(out) + x
|
171 |
+
|
172 |
+
|
173 |
+
class MultiHeadAttentionWithKVCacheAndRoPE(nn.Module):
|
174 |
+
"""Multiple heads of self-attention in parallel."""
|
175 |
+
|
176 |
+
def __init__(self, n_embd, n_head, block_size, dropout):
|
177 |
+
super().__init__()
|
178 |
+
assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
|
179 |
+
|
180 |
+
self.n_embd = n_embd
|
181 |
+
self.n_head = n_head
|
182 |
+
self.head_size = n_embd // n_head
|
183 |
+
|
184 |
+
self.heads = nn.ModuleList([AttentionHeadWithKVCacheAndRoPE(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)])
|
185 |
+
self.proj = nn.Linear(n_embd, n_embd)
|
186 |
+
self.dropout = nn.Dropout(dropout)
|
187 |
+
|
188 |
+
def forward(self, x, kv_cache):
|
189 |
+
out = torch.cat([h(x, kv_cache) for h in self.heads], dim=-1)
|
190 |
+
out = self.dropout(self.proj(out))
|
191 |
+
return out
|
192 |
+
|
193 |
+
|
194 |
+
class LlamaBlock(nn.Module):
|
195 |
+
"""LLAMA block: communication followed by computation."""
|
196 |
+
|
197 |
+
def __init__(self, n_embd, n_head, block_size, dropout):
|
198 |
+
super().__init__()
|
199 |
+
self.ln1 = RMSNorm(n_embd)
|
200 |
+
self.sa = MultiHeadAttentionWithKVCacheAndRoPE(n_embd, n_head, block_size, dropout)
|
201 |
+
self.ln2 = RMSNorm(n_embd)
|
202 |
+
self.ffwd = LlamaFFN(n_embd, dropout)
|
203 |
+
|
204 |
+
def forward(self, x, kv_cache):
|
205 |
+
x = x + self.sa(self.ln1(x), kv_cache)
|
206 |
+
x = x + self.ffwd(self.ln2(x))
|
207 |
+
return x
|
core/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import gpt
|
core/models/gpt.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import tqdm
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from core.layers import Block
|
6 |
+
|
7 |
+
class GPTLanguageModel(nn.Module):
|
8 |
+
|
9 |
+
def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name = "GPT"):
|
10 |
+
super().__init__()
|
11 |
+
self.name = name
|
12 |
+
self.block_size = block_size
|
13 |
+
self.device = device
|
14 |
+
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
|
15 |
+
self.position_embedding_table = nn.Embedding(block_size, n_embd)
|
16 |
+
self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
|
17 |
+
self.ln_f = nn.LayerNorm(n_embd)
|
18 |
+
self.lm_head = nn.Linear(n_embd, vocab_size)
|
19 |
+
self.apply(self._init_weights)
|
20 |
+
self.history = {}
|
21 |
+
self.vocab_size = vocab_size
|
22 |
+
|
23 |
+
def _init_weights(self, module):
|
24 |
+
if isinstance(module, nn.Linear):
|
25 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
26 |
+
if module.bias is not None:
|
27 |
+
nn.init.zeros_(module.bias)
|
28 |
+
elif isinstance(module, nn.Embedding):
|
29 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
30 |
+
|
31 |
+
def forward(self, idx, targets=None):
|
32 |
+
B, T = idx.shape
|
33 |
+
|
34 |
+
assert torch.all(idx < self.vocab_size), f"Input indices must be less than vocab_size ({self.vocab_size})"
|
35 |
+
assert T <= self.block_size, f"Input sequence length ({T}) must be <= block_size ({self.block_size})"
|
36 |
+
|
37 |
+
tok_emb = self.token_embedding_table(idx)
|
38 |
+
pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
|
39 |
+
x = tok_emb + pos_emb
|
40 |
+
x = self.blocks(x)
|
41 |
+
x = self.ln_f(x)
|
42 |
+
logits = self.lm_head(x)
|
43 |
+
|
44 |
+
if targets is None:
|
45 |
+
loss = None
|
46 |
+
else:
|
47 |
+
B, T, C = logits.shape
|
48 |
+
logits = logits.view(B * T, C)
|
49 |
+
targets = targets.view(B * T)
|
50 |
+
loss = F.cross_entropy(logits, targets)
|
51 |
+
|
52 |
+
return logits, loss
|
53 |
+
|
54 |
+
def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
|
55 |
+
for _ in range(max_new_tokens):
|
56 |
+
if idx.size(1) > max_seq_length:
|
57 |
+
idx = idx[:, -max_seq_length:]
|
58 |
+
idx_cond = idx[:, -self.block_size:]
|
59 |
+
logits, _ = self(idx_cond)
|
60 |
+
logits = logits[:, -1, :] / temperature
|
61 |
+
probs = F.softmax(logits, dim=-1)
|
62 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
63 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
64 |
+
yield idx
|
core/models/llama.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import tqdm
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from core.layers import LlamaBlock, RMSNorm
|
6 |
+
|
7 |
+
class LlamaLanguageModel(nn.Module):
|
8 |
+
|
9 |
+
def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name = "llama"):
|
10 |
+
super().__init__()
|
11 |
+
self.name = name
|
12 |
+
self.block_size = block_size
|
13 |
+
self.device = device
|
14 |
+
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
|
15 |
+
self.blocks = nn.Sequential(*[LlamaBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
|
16 |
+
self.ln_f = RMSNorm(n_embd)
|
17 |
+
self.lm_head = nn.Linear(n_embd, vocab_size)
|
18 |
+
self.apply(self._init_weights)
|
19 |
+
self.history = {}
|
20 |
+
self.vocab_size = vocab_size
|
21 |
+
|
22 |
+
def _init_weights(self, module):
|
23 |
+
if isinstance(module, nn.Linear):
|
24 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
25 |
+
if module.bias is not None:
|
26 |
+
nn.init.zeros_(module.bias)
|
27 |
+
elif isinstance(module, nn.Embedding):
|
28 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
29 |
+
|
30 |
+
def forward(self, idx):
|
31 |
+
B, T = idx.shape
|
32 |
+
kv_cache = None
|
33 |
+
token_embeddings = self.token_embedding_table(idx)
|
34 |
+
for block in self.blocks:
|
35 |
+
token_embeddings = block(token_embeddings, kv_cache)
|
36 |
+
token_embeddings = self.ln_f(token_embeddings)
|
37 |
+
logits = self.lm_head(token_embeddings)
|
38 |
+
return logits, token_embeddings
|
39 |
+
|
40 |
+
|
41 |
+
def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
|
42 |
+
for _ in range(max_new_tokens):
|
43 |
+
if idx.size(1) > max_seq_length:
|
44 |
+
idx = idx[:, -max_seq_length:]
|
45 |
+
idx_cond = idx[:, -self.block_size:]
|
46 |
+
logits, _ = self(idx_cond)
|
47 |
+
logits = logits[:, -1, :] / temperature
|
48 |
+
probs = F.softmax(logits, dim=-1)
|
49 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
50 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
51 |
+
yield idx
|
core/tokenizers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import tokenizer
|
core/tokenizers/tokenizer.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from typing import Iterable
|
4 |
+
import torch
|
5 |
+
|
6 |
+
class Tokenizer:
|
7 |
+
def __init__(self, data_path: str = None):
|
8 |
+
self.config = None
|
9 |
+
self.stoi = None
|
10 |
+
self.itos = None
|
11 |
+
self.vocab_size = None
|
12 |
+
if data_path:
|
13 |
+
self.data = self.load_data(data_path)
|
14 |
+
else:
|
15 |
+
self.data = None
|
16 |
+
|
17 |
+
def from_pretrained(self, config_path: str):
|
18 |
+
with open(config_path) as f:
|
19 |
+
config = json.load(f)
|
20 |
+
self.config = config
|
21 |
+
if 'encode' not in config:
|
22 |
+
raise ValueError("Config file must contain an 'encode' key.")
|
23 |
+
if 'decode' not in config:
|
24 |
+
raise ValueError("Config file must contain a 'decode' key.")
|
25 |
+
if 'vocab_size' not in config:
|
26 |
+
raise ValueError("Config file must contain a 'vocab_size' key.")
|
27 |
+
stoi = config['encode']
|
28 |
+
self.stoi = {k: int(v) for k, v in stoi.items()}
|
29 |
+
itos = config['decode']
|
30 |
+
self.itos = {int(k): v for k, v in itos.items()}
|
31 |
+
self.vocab_size = config['vocab_size']
|
32 |
+
return self
|
33 |
+
|
34 |
+
def load_data(self, path: str) -> str:
|
35 |
+
if not os.path.exists(path):
|
36 |
+
raise FileNotFoundError("File not found.")
|
37 |
+
if not path.endswith('.txt'):
|
38 |
+
raise ValueError("File must be a text file.")
|
39 |
+
with open(path, 'r', encoding='utf-8') as f:
|
40 |
+
text = f.read()
|
41 |
+
chars = sorted(list(set(text)))
|
42 |
+
vocab_size = len(chars)
|
43 |
+
stoi = {ch: i for i, ch in enumerate(chars)}
|
44 |
+
itos = {i: ch for i, ch in enumerate(chars)}
|
45 |
+
self.config = {"vocab_size": vocab_size, "encode": stoi, "decode": itos}
|
46 |
+
self.stoi = stoi
|
47 |
+
self.itos = itos
|
48 |
+
data = torch.tensor(self(text), dtype=torch.long)
|
49 |
+
n = int(0.9*len(data))
|
50 |
+
train_data = data[:n]
|
51 |
+
val_data = data[n:]
|
52 |
+
self.train_data = train_data
|
53 |
+
self.val_data = val_data
|
54 |
+
self.vocab_size = vocab_size
|
55 |
+
return text
|
56 |
+
|
57 |
+
def __repr__(self) -> str:
|
58 |
+
if self.config:
|
59 |
+
return f"Tokenizer(config={self.config})"
|
60 |
+
else:
|
61 |
+
return f"Tokenizer()"
|
62 |
+
|
63 |
+
def __str__(self) -> str:
|
64 |
+
if self.config:
|
65 |
+
return f"Tokenizer(config_path={self.config})"
|
66 |
+
else:
|
67 |
+
return f"Tokenizer()"
|
68 |
+
|
69 |
+
def __len__(self) -> int:
|
70 |
+
return len(self.stoi)
|
71 |
+
|
72 |
+
def __getitem__(self, key: str) -> int:
|
73 |
+
return self.stoi[key]
|
74 |
+
|
75 |
+
def __contains__(self, key: str) -> bool:
|
76 |
+
return key in self.stoi
|
77 |
+
|
78 |
+
def __iter__(self):
|
79 |
+
return iter(self.stoi)
|
80 |
+
|
81 |
+
def __reversed__(self):
|
82 |
+
return reversed(self.stoi)
|
83 |
+
|
84 |
+
def keys(self):
|
85 |
+
return self.stoi.keys()
|
86 |
+
|
87 |
+
def values(self):
|
88 |
+
return self.stoi.values()
|
89 |
+
|
90 |
+
def items(self):
|
91 |
+
return self.stoi.items()
|
92 |
+
|
93 |
+
def __call__(self, *args, **kwds) -> list[int]:
|
94 |
+
return self.encode(*args, **kwds)
|
95 |
+
|
96 |
+
def encode(self, s: str | list[str]) -> list[int]:
|
97 |
+
if isinstance(s, str):
|
98 |
+
return [self.stoi[c] for c in s]
|
99 |
+
elif isinstance(s, list):
|
100 |
+
return [[self.stoi[i] for i in c] for c in s]
|
101 |
+
else:
|
102 |
+
raise ValueError("Input must be a string or a list of strings.")
|
103 |
+
|
104 |
+
def decode(self, l: list[int]) -> str:
|
105 |
+
if isinstance(l[0], int):
|
106 |
+
return ''.join([self.itos[i] for i in l])
|
107 |
+
elif isinstance(l[0], Iterable):
|
108 |
+
return [''.join([self.itos[i] for i in c]) for c in l]
|
109 |
+
else:
|
110 |
+
raise ValueError("Input must be a list of integers or a list of list of integers.")
|
111 |
+
|
112 |
+
def save_pretrained(self, path: str) -> str:
|
113 |
+
with open(path + 'vocab.json', 'w') as f:
|
114 |
+
json.dump(self.config, f)
|
115 |
+
return "Tokenizer saved at {}.".format(path)
|
core/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import gptutils, preprocessing
|
core/utils/gptutils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import json
|
3 |
+
|
4 |
+
# ------------ Hyperparameters ------------
|
5 |
+
def hyperparameters(config_path: str):
|
6 |
+
with open(config_path) as f:
|
7 |
+
config = json.load(f)
|
8 |
+
|
9 |
+
batch_size = config['batch_size']
|
10 |
+
block_size = config['block_size']
|
11 |
+
max_iters = config['max_iters']
|
12 |
+
eval_interval = config['eval_interval']
|
13 |
+
learning_rate = config['learning_rate']
|
14 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
+
eval_iters = config['eval_iters']
|
16 |
+
n_embd = config['n_embd']
|
17 |
+
n_head = config['n_head']
|
18 |
+
n_layer = config['n_layer']
|
19 |
+
dropout = config['dropout']
|
20 |
+
return (batch_size, block_size, max_iters, eval_interval, learning_rate,
|
21 |
+
device, eval_iters, n_embd, n_head, n_layer, dropout)
|
22 |
+
# ----------------------------------------
|
23 |
+
|
24 |
+
def load_data(path) -> tuple[torch.Tensor, torch.Tensor, int, callable, callable]:
|
25 |
+
with open(path, 'r', encoding='utf-8') as f:
|
26 |
+
text = f.read()
|
27 |
+
|
28 |
+
# words = text.split()
|
29 |
+
# vocab_size = len(words)
|
30 |
+
# stoi = {word: i for i, word in enumerate(words)}
|
31 |
+
# itos = {i: word for i, word in enumerate(words)}
|
32 |
+
# def encode(s): return [stoi[w] for w in s.split()]
|
33 |
+
# def decode(ids): return ' '.join([itos[i] for i in ids])
|
34 |
+
|
35 |
+
chars = sorted(list(set(text)))
|
36 |
+
vocab_size = len(chars)
|
37 |
+
stoi = {ch: i for i, ch in enumerate(chars)}
|
38 |
+
itos = {i: ch for i, ch in enumerate(chars)}
|
39 |
+
def encode(s): return [stoi[c] for c in s]
|
40 |
+
def decode(l): return ''.join([itos[i] for i in l])
|
41 |
+
data = torch.tensor(encode(text), dtype=torch.long)
|
42 |
+
n = int(0.9*len(data))
|
43 |
+
train_data = data[:n]
|
44 |
+
val_data = data[n:]
|
45 |
+
|
46 |
+
|
47 |
+
return train_data, val_data, vocab_size, encode, decode
|
48 |
+
|
49 |
+
|
50 |
+
def get_batch(split, train_data, val_data, device, block_size, batch_size):
|
51 |
+
data = train_data if split == 'train' else val_data
|
52 |
+
ix = torch.randint(len(data) - block_size, (batch_size,))
|
53 |
+
x = torch.stack([data[i:i+block_size] for i in ix])
|
54 |
+
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
|
55 |
+
x, y = x.to(device), y.to(device)
|
56 |
+
return x, y
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def estimate_loss(model, get_batch, eval_iters, train_data, val_data, device, block_size, batch_size):
|
61 |
+
out = {}
|
62 |
+
model.eval()
|
63 |
+
for split in ['train', 'val']:
|
64 |
+
losses = torch.zeros(eval_iters)
|
65 |
+
for k in range(eval_iters):
|
66 |
+
X, Y = get_batch(split, train_data, val_data, device, block_size, batch_size)
|
67 |
+
logits, loss = model(X, Y)
|
68 |
+
losses[k] = loss.item()
|
69 |
+
out[split] = losses.mean()
|
70 |
+
model.train()
|
71 |
+
return out
|
core/utils/preprocessing.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class DataLoader:
|
6 |
+
def __init__(self, data_path):
|
7 |
+
self.data_path = data_path
|
8 |
+
self.batch_size = None
|
9 |
+
self.block_size = None
|
10 |
+
self.data = None
|
11 |
+
self.train_data = None
|
12 |
+
self.val_data = None
|
13 |
+
|
14 |
+
def load_data(self, block_size=128, split=0.8, batch_size=64, device='cpu'):
|
15 |
+
with open(self.data_path, 'r') as f:
|
16 |
+
data = f.read()
|
17 |
+
self.block_size = block_size
|
18 |
+
self.batch_size = batch_size
|
19 |
+
self.device = device
|
20 |
+
self.data = data
|
21 |
+
|
22 |
+
def __len__(self):
|
23 |
+
return int(np.ceil(len(self.data) / self.batch_size))
|
24 |
+
|
25 |
+
def __getitem__(self, index):
|
26 |
+
indexes = self.indexes[index *
|
27 |
+
self.batch_size:(index + 1) * self.batch_size]
|
28 |
+
batch = [self.data[i] for i in indexes]
|
29 |
+
batch = np.array(batch)
|
30 |
+
return batch
|
31 |
+
|
32 |
+
def get_batch(self, split, device='cpu'):
|
33 |
+
if self.data is None:
|
34 |
+
raise ValueError('Data not loaded')
|
35 |
+
data = self.train_data if split == 'train' else self.val_data
|
36 |
+
ix = torch.randint(len(data) - self.block_size, (self.batch_size,))
|
37 |
+
x = torch.stack([data[i:i+self.block_size] for i in ix])
|
38 |
+
y = torch.stack([data[i+1:i+self.block_size+1] for i in ix])
|
39 |
+
x, y = x.to(device), y.to(device)
|
40 |
+
return x, y
|
41 |
+
|
42 |
+
|
43 |
+
class Encoder:
|
44 |
+
def __init__(self, data, type='char'):
|
45 |
+
self.data = data
|
46 |
+
self.type = type
|
47 |
+
self.vocab_size = None
|
48 |
+
if type == 'char':
|
49 |
+
self.chars = sorted(list(set(data)))
|
50 |
+
self.stoi = {ch: i for i, ch in enumerate(self.chars)}
|
51 |
+
self.itos = {i: ch for i, ch in enumerate(self.chars)}
|
52 |
+
self.vocab_size = len(self.chars)
|
53 |
+
elif type == 'word':
|
54 |
+
self.words = data.split()
|
55 |
+
self.stoi = {word: i for i, word in enumerate(self.words)}
|
56 |
+
self.itos = {i: word for i, word in enumerate(self.words)}
|
57 |
+
self.vocab_size = len(self.words)
|
58 |
+
else:
|
59 |
+
raise ValueError('Type must be either "char" or "word"')
|
60 |
+
|
61 |
+
def encode(self, string: str):
|
62 |
+
if self.type == 'char':
|
63 |
+
return torch.tensor([self.stoi[c] for c in string])
|
64 |
+
elif self.type == 'word':
|
65 |
+
return torch.tensor([self.stoi[w] for w in string.split()])
|
66 |
+
else:
|
67 |
+
raise ValueError('Type must be either "char" or "word"')
|
68 |
+
|
69 |
+
def decode(self, ids: list):
|
70 |
+
if self.type == 'char':
|
71 |
+
return ''.join([self.itos[i] for i in ids])
|
72 |
+
elif self.type == 'word':
|
73 |
+
return ' '.join([self.itos[i] for i in ids])
|
74 |
+
else:
|
75 |
+
raise ValueError('Type must be either "char" or "word"')
|
requirements.txt
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==5.2.0
|
2 |
+
annotated-types==0.6.0
|
3 |
+
anyio==4.3.0
|
4 |
+
asttokens==2.4.1
|
5 |
+
attrs==23.2.0
|
6 |
+
blinker==1.7.0
|
7 |
+
cachetools==5.3.3
|
8 |
+
certifi==2024.2.2
|
9 |
+
charset-normalizer==3.3.2
|
10 |
+
click==8.1.7
|
11 |
+
colorama==0.4.6
|
12 |
+
comm==0.2.2
|
13 |
+
debugpy==1.8.1
|
14 |
+
decorator==5.1.1
|
15 |
+
dnspython==2.6.1
|
16 |
+
email_validator==2.1.1
|
17 |
+
executing==2.0.1
|
18 |
+
fastapi==0.110.3
|
19 |
+
filelock==3.13.1
|
20 |
+
fsspec==2024.2.0
|
21 |
+
gitdb==4.0.11
|
22 |
+
GitPython==3.1.42
|
23 |
+
h11==0.14.0
|
24 |
+
httpcore==1.0.5
|
25 |
+
httptools==0.6.1
|
26 |
+
httpx==0.27.0
|
27 |
+
idna==3.6
|
28 |
+
ipykernel==6.29.4
|
29 |
+
ipython==8.24.0
|
30 |
+
itsdangerous==2.2.0
|
31 |
+
jedi==0.19.1
|
32 |
+
Jinja2==3.1.3
|
33 |
+
joblib==1.4.0
|
34 |
+
jsonschema==4.21.1
|
35 |
+
jsonschema-specifications==2023.12.1
|
36 |
+
jupyter_client==8.6.1
|
37 |
+
jupyter_core==5.7.2
|
38 |
+
markdown-it-py==3.0.0
|
39 |
+
MarkupSafe==2.1.5
|
40 |
+
matplotlib-inline==0.1.7
|
41 |
+
mdurl==0.1.2
|
42 |
+
mpmath==1.3.0
|
43 |
+
nest-asyncio==1.6.0
|
44 |
+
networkx==3.2.1
|
45 |
+
numpy==1.26.4
|
46 |
+
orjson==3.10.2
|
47 |
+
packaging==23.2
|
48 |
+
pandas==2.2.1
|
49 |
+
parso==0.8.4
|
50 |
+
pillow==10.2.0
|
51 |
+
platformdirs==4.2.1
|
52 |
+
prompt-toolkit==3.0.43
|
53 |
+
protobuf==4.25.3
|
54 |
+
psutil==5.9.8
|
55 |
+
pure-eval==0.2.2
|
56 |
+
pyarrow==15.0.2
|
57 |
+
pydantic==2.7.1
|
58 |
+
pydantic-extra-types==2.7.0
|
59 |
+
pydantic-settings==2.2.1
|
60 |
+
pydantic_core==2.18.2
|
61 |
+
pydeck==0.8.1b0
|
62 |
+
Pygments==2.17.2
|
63 |
+
python-dateutil==2.9.0.post0
|
64 |
+
python-dotenv==1.0.1
|
65 |
+
python-multipart==0.0.9
|
66 |
+
pytz==2024.1
|
67 |
+
PyYAML==6.0.1
|
68 |
+
pyzmq==26.0.2
|
69 |
+
referencing==0.34.0
|
70 |
+
regex==2024.4.28
|
71 |
+
requests==2.31.0
|
72 |
+
rich==13.7.1
|
73 |
+
rpds-py==0.18.0
|
74 |
+
six==1.16.0
|
75 |
+
smmap==5.0.1
|
76 |
+
sniffio==1.3.1
|
77 |
+
stack-data==0.6.3
|
78 |
+
starlette==0.37.2
|
79 |
+
streamlit==1.32.2
|
80 |
+
sympy==1.12
|
81 |
+
tenacity==8.2.3
|
82 |
+
toml==0.10.2
|
83 |
+
toolz==0.12.1
|
84 |
+
torch==2.2.1
|
85 |
+
tornado==6.4
|
86 |
+
tqdm==4.66.2
|
87 |
+
traitlets==5.14.3
|
88 |
+
typing_extensions==4.10.0
|
89 |
+
tzdata==2024.1
|
90 |
+
ujson==5.9.0
|
91 |
+
urllib3==2.2.1
|
92 |
+
uvicorn==0.29.0
|
93 |
+
watchdog==4.0.0
|
94 |
+
watchfiles==0.21.0
|
95 |
+
wcwidth==0.2.13
|
96 |
+
websockets==12.0
|