Spaces:
Running
Running
Upload 2 files
Browse files- dnagpt/class6.ipynb +501 -0
- dnagpt/get_data.ipynb +0 -0
dnagpt/class6.ipynb
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "29888eda-2755-4add-af9c-8927afb07db4",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stderr",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
14 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
15 |
+
]
|
16 |
+
}
|
17 |
+
],
|
18 |
+
"source": [
|
19 |
+
"from transformers import AutoTokenizer, AutoModel\n",
|
20 |
+
"tokenizer = AutoTokenizer.from_pretrained('dnagpt/human_gpt2-v1')\n",
|
21 |
+
"tokenizer.pad_token = tokenizer.eos_token"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 3,
|
27 |
+
"id": "75e3778d-9cbb-48c4-8203-0f71f485a49d",
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [],
|
30 |
+
"source": [
|
31 |
+
"full_model = AutoModel.from_pretrained('dnagpt/human_gpt2-v1')"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": 4,
|
37 |
+
"id": "be845310-8215-4434-bb92-d44c343f274e",
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [
|
40 |
+
{
|
41 |
+
"name": "stdout",
|
42 |
+
"output_type": "stream",
|
43 |
+
"text": [
|
44 |
+
"transformers.models.gpt2.modeling_gpt2\n"
|
45 |
+
]
|
46 |
+
}
|
47 |
+
],
|
48 |
+
"source": [
|
49 |
+
"gena_module_name = full_model.__class__.__module__\n",
|
50 |
+
"print(gena_module_name)"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "code",
|
55 |
+
"execution_count": 5,
|
56 |
+
"id": "bb9ee647-cac2-4475-ac44-2f11aef99735",
|
57 |
+
"metadata": {
|
58 |
+
"scrolled": true
|
59 |
+
},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
"import importlib\n",
|
63 |
+
"myclass = importlib.import_module(gena_module_name)\n",
|
64 |
+
"#dir(myclass)"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": 6,
|
70 |
+
"id": "62ebfcc0-5d46-4c3b-b462-a3842a985e9b",
|
71 |
+
"metadata": {},
|
72 |
+
"outputs": [
|
73 |
+
{
|
74 |
+
"data": {
|
75 |
+
"text/plain": [
|
76 |
+
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
"execution_count": 6,
|
80 |
+
"metadata": {},
|
81 |
+
"output_type": "execute_result"
|
82 |
+
}
|
83 |
+
],
|
84 |
+
"source": [
|
85 |
+
"cls = getattr(importlib.import_module(gena_module_name), 'GPT2ForSequenceClassification')\n",
|
86 |
+
"cls"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": 7,
|
92 |
+
"id": "4523e9b6-3613-41e4-b81d-c139d044e70c",
|
93 |
+
"metadata": {
|
94 |
+
"scrolled": true
|
95 |
+
},
|
96 |
+
"outputs": [
|
97 |
+
{
|
98 |
+
"name": "stderr",
|
99 |
+
"output_type": "stream",
|
100 |
+
"text": [
|
101 |
+
"Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at dnagpt/human_gpt2-v1 and are newly initialized: ['score.weight']\n",
|
102 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
103 |
+
]
|
104 |
+
}
|
105 |
+
],
|
106 |
+
"source": [
|
107 |
+
"model = cls.from_pretrained('dnagpt/human_gpt2-v1', num_labels=2)\n",
|
108 |
+
"model.config.pad_token_id = model.config.eos_token_id"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": 8,
|
114 |
+
"id": "bd024199-16e6-4eb1-b404-c49dc535469b",
|
115 |
+
"metadata": {},
|
116 |
+
"outputs": [],
|
117 |
+
"source": [
|
118 |
+
"from datasets import load_dataset\n",
|
119 |
+
"# load ~12k samples (11,840 rows) from the promoter prediction dataset\n",
|
120 |
+
"dataset = load_dataset(\"yurakuratov/example_promoters_2k\")['train'].train_test_split(test_size=0.1)"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 9,
|
126 |
+
"id": "a3a5213d-66f0-4d17-876c-de4426145412",
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [
|
129 |
+
{
|
130 |
+
"data": {
|
131 |
+
"text/plain": [
|
132 |
+
"DatasetDict({\n",
|
133 |
+
" train: Dataset({\n",
|
134 |
+
" features: ['sequence', 'promoter_presence'],\n",
|
135 |
+
" num_rows: 10656\n",
|
136 |
+
" })\n",
|
137 |
+
" test: Dataset({\n",
|
138 |
+
" features: ['sequence', 'promoter_presence'],\n",
|
139 |
+
" num_rows: 1184\n",
|
140 |
+
" })\n",
|
141 |
+
"})"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
"execution_count": 9,
|
145 |
+
"metadata": {},
|
146 |
+
"output_type": "execute_result"
|
147 |
+
}
|
148 |
+
],
|
149 |
+
"source": [
|
150 |
+
"dataset"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 10,
|
156 |
+
"id": "36a23870-1f56-466e-b2bb-c5047072b8f0",
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [
|
159 |
+
{
|
160 |
+
"data": {
|
161 |
+
"text/plain": [
|
162 |
+
"{'sequence': 'TTGATATGCCCGCAATAAATGTTAGCCCTTCCTCTTTACAAACAATTATTATTTCATTAAGACTTACTCCTGGGAAGAAGTCACACTTTTCTGTGTGCTAACATCATGTACATTCTGATTCTCTGTCCTGTGTGCCATTCCATCATGCTCTCTACCAGCAAAACGCCACCAGTCTAAGGAGGAGACAGAACCGCATAAAAATAGTCACACAGTCTTCAGAGCTGCACAGGGAGTGGCACAGGGGCTTCACAGGCCTTGGGAAGTTTGTAAAGTGAGGGAGAAGCACTTTACTGTCGCCAGCTCCATTGAAAACAGATTGTAGCAAATGACTTCCCAGGAGCTCCTCCTCTCACCTTGGTGAATGTTGGGAGAGAATGAAGACTGCGGCTGAGAGGCAGCGCCTTGAGCCTTGAACGGAAAGAGGGACAGGGCCTGGGGGCTGTCACTGAGGTGCAGAAATGGGTTTTGGGCTGGGCATCAGACAACCCCAGATCCCAACACTGGCTCAGCCAATTCCTGGCTGTGTGACATGGAGCCATTACTGTACTTTCCTTGCCAGGCCTCCATTTTTCCTCTCTCTAGAATGGGGAGGCTGGTAATACCCACTTTGGAGGGCAGTGTGAGGACCGAAGGGGGCTGTTAAGTGCTCGGCGCGCAGCGAGCACTCAGGACGGGTGGCAGCGGTCACCTTTGTCGCCACCGTCCAAGCCTTTGCAGAGGCTCAGGGCCGAGCCCCGGACATGCGCATGCGCACGCGCACCCGGCCGGCCAGGTGAGAAGGAGGCGGCCGCCTCGCACCTGGGCTTTAACCCTTCCCAGTCACGGGAGTGGGCGGAGTGCACGCAGGGACGGAGTCGGGGGCGAGCAGCGAGAGAGGCCGGGCTTGACCCTTGAGGCGGGGTGGGCAGCGTCCGGGGAGCGCGCCGAGCCGCTTCTAGCCTGGGGTCCGCGCGCGCCGCGGGGAGGAGGGGGAGTGCCGAGGGGAGCGAAGTCTCGCGAGATCGCGCGGCGGCGGCGGGAGCGGCGGCGGCGGCGGCCGGGGAGGTGAGCGGCGGGCGGTGGCGGCCGTTGGGGGCTGAGGCCGGGTGAGAGCGGCCGAGACCGAGGGCTGGGTGGGGGAGCGGGCCTGGTGAACATCCCGCGTCCCAGGCAGCCATCCCTGTTCGTCCCCGCAGCAGCTGCGGCGCCGTTTGGTCCGCACCGTCCCCGGGACTCGCGCGGGGCGGGGCGGCCGGGGCCGCGGGCCGGTGCCTCCTTCAACCGGTCCTCGCGCCCGGGCGCGACCCCGGGGCCGCCTCAGTGTCCATCTCGACTGCAGAGTTGGCCCTGCTGACAGTGGCAGCGGCGTTTATGGAGCCTCTGCTGTGTGCCGGGCACTTAAGTTGCGCGCTTGCACGCCGGGGCTGCACAGAGAGGTGGAGCGCCCGGGCCAGAGCCACACTGCGAGGAGGCGGTGGGAGGGGACCGGGACCCCTGCGCCCCCTTCCTCTTCCCACCTCCCACTCCCCTCAGCCCTTTTCTCCTGGGGTCCCGACCAGGCTTCAGAGGGGGTGTCTGGGAGCGCCCTGAAGTTGGATGCAGCTTGGTATATGGGTTTCTTGGGAGAGAGTTTTTCCAGGGCTTTTGTCTCGTTCCTAAAAGGGTCCCCGATCCCCAAGTGGTCAAGAACCATTGTTTGGCCCTGCTTGGAAGCAGAGGTGGGCTGAGAGCAGCGGGCGTGTGCTGGTAGTGGGGTGTCAGGCCCAATGAAGTGGAATATTGGAGCCCAGGGCTTGCCTTAGGCAAAGGAAACTCTTCTCTAAAAGTTCGATTTTTGTTTCAGCACGCGATTCCACTCTTGGGAACAGTGGTTCTCAACCTCGTCAGAATCAAGGTCCCTTCTTCCCCTTTGCAGCAAACATGTATAAAGCCCTCTTTACTACCCTGAAATGAAATTCATAAACAATGTGACCTCCCTACACACATCCTTGAAAAGCAGTATGATG
GTTTGATTGAAGCGT',\n",
|
163 |
+
" 'promoter_presence': 1}"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
"execution_count": 10,
|
167 |
+
"metadata": {},
|
168 |
+
"output_type": "execute_result"
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"source": [
|
172 |
+
"dataset['train'][0]"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": 11,
|
178 |
+
"id": "19e22500-c1e0-444e-9b5c-2a1a2af81997",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [
|
181 |
+
{
|
182 |
+
"name": "stdout",
|
183 |
+
"output_type": "stream",
|
184 |
+
"text": [
|
185 |
+
"# base pairs: 2000\n"
|
186 |
+
]
|
187 |
+
}
|
188 |
+
],
|
189 |
+
"source": [
|
190 |
+
"print('# base pairs: ', len(dataset['train'][0]['sequence']))"
|
191 |
+
]
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "code",
|
195 |
+
"execution_count": 12,
|
196 |
+
"id": "16db0a5d-7aec-4d59-b44b-d22ff43045fd",
|
197 |
+
"metadata": {},
|
198 |
+
"outputs": [
|
199 |
+
{
|
200 |
+
"name": "stdout",
|
201 |
+
"output_type": "stream",
|
202 |
+
"text": [
|
203 |
+
"tokens: TTGATATG CCCGC AATAAATG TTAGCCC TTCCTC TTTACAAAC AATTATT ATTTCATT AAGAC TTAC TCCTGGG AAGAAGTC ACAC TTTTC TGTGTGC TAAC ATCATG TAC ATTCTG ATTC TCTGTCC TGTGTGCC ATTCC ATCATGC TCTC TACCAGC AAAAC GCC ACCAGTC TAAGG AGGAGAC AGAACC GC ATAAAA ATAGTC ACAC AGTCTTC AGAGC TGCAC AGGG AGTGGCAC AGGGGC TTCAC AGGCC TTGGGAAG TTTGTAA AGTGAGGG AGAAGC ACTTTAC TGTCGCC AGCTCC ATTGAAAAC AGATTG TAGCAAATG ACTTCCC AGGAGC TCCTCC TCTCACC TTGGTG AATG TTGGGAG AGAATGAAG ACTGC GGC TGAG AGGCAGC GCC TTGAGCC TTGAAC GG AAAG AGGGAC AGGGCC TGGGGGC TGTCAC TGAGG TGCAGAA ATGGG TTTTGGGC TGGGC ATCAGAC AACCCC AGATCCC AACACTGGC TCAGCC AATTCC TGGC TGTGTGAC ATGGAGCC ATTACTG TACTTTCC TTGCC AGGCC TCCATT TTTCC TCTCTC TAGAA TGGGG AGGCTGG TAA TACCC ACTTTGG AGGGC AGTGTG AGGACCG AAGG GGGCTG TTAAG TGCTC GGC GCGC AGCG AGCACTC AGGAC GGGTGGC AGCGG TCACC TTTGTC GCC ACCGTCC AAGCC TTTGC AGAGGC TC AGGGCCG AGCCCC GGAC ATGCGC ATGCGC ACGC GCACCCGGCC GGCC AGGTG AGAAGG AGGC GGCCGCC TCGC ACCTGGGC TTTAACCC TTCCC AGTCAC GGG AGTGGGC GG AGTGCAC GCAGGG ACGG AGTCGG GGGCG AGCAGCG AGAGAGGCC GGGC TTGACCC TTGAGGC GGGG TGGGC AGCGTCC GGGG AGCGC GCCG AGCCGC TTC TAGCC TGGGGTCC GCGC GCGCC GCGGGG AGGAGGGGG AGTGCCG AGGGG AGCG AAG TCTCGC GAG ATCGC GC GGCGGC GGC GGGAGC GGCGGCGGCGGC GGCC GGGG AGGTG AGCGGC GGGCGG TGGC GGCCG TTGGGGGC TG AGGCCGGG TGAGAGC GGCC GAG ACCG AGGGCTGGG TGGGGG AGC GGGCC TGGTGAAC ATCCC GCG TCCCAGGC AGCC ATCCCTG TTCG TCCCCGC AGCAGC TGCGGC GCCG TTTGG TCCGC ACCG TCCCC GGG ACTCGC GCGGGGC GGGGC GGCC GGGGCC GC GGGCC GGTGCC TCCTTC AACCGG TCCTC GCGCCC GGGC GCG ACCCC GGGGCC GCC TCAGTG TCC ATCTCG ACTGC AGAG TTGGCCC TGCTGAC AGTGGC AGC GGCG TTTATGG AGCCTCTGC TGTGTGCC GGGC ACTT AAGTTGC GCGC TTGCAC GCC GGGGC TGCAC AGAG AGGTGG AGCGCCC GGGCC AGAGCC ACAC TGCG AGGAGGCGG TGGGAGGGG ACCGGG ACCCCTGC GCCCCC TTCCTC TTCCCACC TCCCAC TCCCC TCAGCCC TTTTCTCC TGGGG TCCCG ACC AGGCTTC AGAGGGGG TGTCTGGG AGCGCCC TGAAG TTGGATGC AGCTTGG TATATGGG TTTC TTGGG AGAG AGTTTT TCCAGGGC TTTTG TCTCG TTCCTAA AAGGG TCCCCG ATCCCC AAGTGG TCAAG AACCATTG TTTGGCCC TGCTTGG 
AAGCAGAGG TGGGCTGAG AGCAGC GGGC GTGTGC TGGTAG TGGGG TGTC AGGCCC AATGAAG TGGAA TATTGG AGCCC AGGGC TTGCC TTAGGC AAAGG AAACTC TTCTC TAAAAG TTCG ATTTTTG TTTC AGCAC GCG ATTCCACTC TTGGGAAC AGTGG TTCTCAACC TCG TCAGAA TCAAGG TCCCTTC TTCCCC TTTGC AGCAAAC ATG TATAAAG CCCTC TTTAC TACCC TGAAATGAA ATTCATAA ACAATGTG ACCTCCC TACACAC ATCCTTG AAAAGC AGTATG ATGG TTTGATTG AAGCG T\n"
|
204 |
+
]
|
205 |
+
}
|
206 |
+
],
|
207 |
+
"source": [
|
208 |
+
"print('tokens: ', ' '.join(tokenizer.tokenize(dataset['train'][0]['sequence'])))"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"cell_type": "code",
|
213 |
+
"execution_count": 13,
|
214 |
+
"id": "9876c07a-6332-40cc-aa84-2d26e6a3a5ab",
|
215 |
+
"metadata": {},
|
216 |
+
"outputs": [
|
217 |
+
{
|
218 |
+
"name": "stdout",
|
219 |
+
"output_type": "stream",
|
220 |
+
"text": [
|
221 |
+
"# tokens: 350\n"
|
222 |
+
]
|
223 |
+
}
|
224 |
+
],
|
225 |
+
"source": [
|
226 |
+
"print('# tokens: ', len(tokenizer.tokenize(dataset['train'][0]['sequence'])))"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"cell_type": "code",
|
231 |
+
"execution_count": 14,
|
232 |
+
"id": "7c4b0442-5fdf-4845-8bf2-4c3b6cc86e84",
|
233 |
+
"metadata": {},
|
234 |
+
"outputs": [
|
235 |
+
{
|
236 |
+
"name": "stderr",
|
237 |
+
"output_type": "stream",
|
238 |
+
"text": [
|
239 |
+
"Map: 100%|██████████| 10656/10656 [00:01<00:00, 10066.72 examples/s]\n",
|
240 |
+
"Map: 100%|██████████| 1184/1184 [00:00<00:00, 10784.01 examples/s]\n"
|
241 |
+
]
|
242 |
+
}
|
243 |
+
],
|
244 |
+
"source": [
|
245 |
+
"def preprocess_labels(example):\n",
|
246 |
+
" example['label'] = example['promoter_presence']\n",
|
247 |
+
" return example\n",
|
248 |
+
"\n",
|
249 |
+
"dataset = dataset.map(preprocess_labels)"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "code",
|
254 |
+
"execution_count": 20,
|
255 |
+
"id": "46e912af-1aab-4ed2-8007-0f3f970b2255",
|
256 |
+
"metadata": {},
|
257 |
+
"outputs": [],
|
258 |
+
"source": [
|
259 |
+
"def preprocess_function(examples):\n",
|
260 |
+
" # just truncate right, but for some tasks symmetric truncation from left and right is more reasonable\n",
|
261 |
+
" # max_length is set to 256 below; lower it (e.g. to 128) to make experiments faster\n",
|
262 |
+
" return tokenizer(examples[\"sequence\"], truncation=True, max_length=256) #max_length 128"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "code",
|
267 |
+
"execution_count": 21,
|
268 |
+
"id": "0ef16794-71fd-4404-a51e-b55b5cae9aab",
|
269 |
+
"metadata": {},
|
270 |
+
"outputs": [
|
271 |
+
{
|
272 |
+
"name": "stderr",
|
273 |
+
"output_type": "stream",
|
274 |
+
"text": [
|
275 |
+
"Map: 100%|██████████| 10656/10656 [00:01<00:00, 6044.89 examples/s]\n",
|
276 |
+
"Map: 100%|██████████| 1184/1184 [00:00<00:00, 6994.12 examples/s]\n"
|
277 |
+
]
|
278 |
+
}
|
279 |
+
],
|
280 |
+
"source": [
|
281 |
+
"tokenized_dataset = dataset.map(preprocess_function, batched=True)"
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"cell_type": "code",
|
286 |
+
"execution_count": 22,
|
287 |
+
"id": "765c199b-9d13-4fd2-ac51-1e5d154766fa",
|
288 |
+
"metadata": {},
|
289 |
+
"outputs": [],
|
290 |
+
"source": [
|
291 |
+
"from transformers import DataCollatorWithPadding\n",
|
292 |
+
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
|
293 |
+
]
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"cell_type": "code",
|
297 |
+
"execution_count": 23,
|
298 |
+
"id": "afd6f48f-5524-4740-b8ea-54a79fd79dad",
|
299 |
+
"metadata": {},
|
300 |
+
"outputs": [
|
301 |
+
{
|
302 |
+
"data": {
|
303 |
+
"text/plain": [
|
304 |
+
"DatasetDict({\n",
|
305 |
+
" train: Dataset({\n",
|
306 |
+
" features: ['sequence', 'promoter_presence', 'label', 'input_ids', 'attention_mask'],\n",
|
307 |
+
" num_rows: 10656\n",
|
308 |
+
" })\n",
|
309 |
+
" test: Dataset({\n",
|
310 |
+
" features: ['sequence', 'promoter_presence', 'label', 'input_ids', 'attention_mask'],\n",
|
311 |
+
" num_rows: 1184\n",
|
312 |
+
" })\n",
|
313 |
+
"})"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
"execution_count": 23,
|
317 |
+
"metadata": {},
|
318 |
+
"output_type": "execute_result"
|
319 |
+
}
|
320 |
+
],
|
321 |
+
"source": [
|
322 |
+
"tokenized_dataset"
|
323 |
+
]
|
324 |
+
},
|
325 |
+
{
|
326 |
+
"cell_type": "code",
|
327 |
+
"execution_count": 24,
|
328 |
+
"id": "e1b708bb-21e4-416d-b191-79f302ca4e93",
|
329 |
+
"metadata": {},
|
330 |
+
"outputs": [],
|
331 |
+
"source": [
|
332 |
+
"from transformers import TrainingArguments, Trainer\n",
|
333 |
+
"import numpy as np\n",
|
334 |
+
"\n",
|
335 |
+
"\n",
|
336 |
+
"def compute_metrics(eval_pred):\n",
|
337 |
+
" predictions, labels = eval_pred\n",
|
338 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
339 |
+
" return {'accuracy': (predictions==labels).sum() / len(labels)}\n",
|
340 |
+
"\n",
|
341 |
+
"# change training hyperparameters to achieve better quality\n",
|
342 |
+
"training_args = TrainingArguments(\n",
|
343 |
+
" output_dir=\"test_run\",\n",
|
344 |
+
" learning_rate=1e-4,\n",
|
345 |
+
" lr_scheduler_type=\"constant_with_warmup\",\n",
|
346 |
+
" warmup_ratio=0.1,\n",
|
347 |
+
" optim='adamw_torch',\n",
|
348 |
+
" weight_decay=0.0,\n",
|
349 |
+
" per_device_train_batch_size=32,\n",
|
350 |
+
" per_device_eval_batch_size=32,\n",
|
351 |
+
" num_train_epochs=5,\n",
|
352 |
+
" evaluation_strategy=\"epoch\",\n",
|
353 |
+
" save_strategy=\"epoch\",\n",
|
354 |
+
" logging_strategy=\"epoch\",\n",
|
355 |
+
" load_best_model_at_end=True\n",
|
356 |
+
")\n",
|
357 |
+
"\n",
|
358 |
+
"trainer = Trainer(\n",
|
359 |
+
" model=model,\n",
|
360 |
+
" args=training_args,\n",
|
361 |
+
" train_dataset=tokenized_dataset[\"train\"],\n",
|
362 |
+
" eval_dataset=tokenized_dataset[\"test\"],\n",
|
363 |
+
" tokenizer=tokenizer,\n",
|
364 |
+
" data_collator=data_collator,\n",
|
365 |
+
" compute_metrics=compute_metrics,\n",
|
366 |
+
")"
|
367 |
+
]
|
368 |
+
},
|
369 |
+
{
|
370 |
+
"cell_type": "code",
|
371 |
+
"execution_count": 26,
|
372 |
+
"id": "082754e6-53fc-4a95-9d1e-887c42975d71",
|
373 |
+
"metadata": {},
|
374 |
+
"outputs": [
|
375 |
+
{
|
376 |
+
"name": "stderr",
|
377 |
+
"output_type": "stream",
|
378 |
+
"text": [
|
379 |
+
"/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
380 |
+
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
381 |
+
]
|
382 |
+
},
|
383 |
+
{
|
384 |
+
"data": {
|
385 |
+
"text/html": [
|
386 |
+
"\n",
|
387 |
+
" <div>\n",
|
388 |
+
" \n",
|
389 |
+
" <progress value='835' max='835' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
390 |
+
" [835/835 08:19, Epoch 5/5]\n",
|
391 |
+
" </div>\n",
|
392 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
393 |
+
" <thead>\n",
|
394 |
+
" <tr style=\"text-align: left;\">\n",
|
395 |
+
" <th>Epoch</th>\n",
|
396 |
+
" <th>Training Loss</th>\n",
|
397 |
+
" <th>Validation Loss</th>\n",
|
398 |
+
" <th>Accuracy</th>\n",
|
399 |
+
" </tr>\n",
|
400 |
+
" </thead>\n",
|
401 |
+
" <tbody>\n",
|
402 |
+
" <tr>\n",
|
403 |
+
" <td>1</td>\n",
|
404 |
+
" <td>0.573000</td>\n",
|
405 |
+
" <td>0.462363</td>\n",
|
406 |
+
" <td>0.778716</td>\n",
|
407 |
+
" </tr>\n",
|
408 |
+
" <tr>\n",
|
409 |
+
" <td>2</td>\n",
|
410 |
+
" <td>0.360200</td>\n",
|
411 |
+
" <td>0.504239</td>\n",
|
412 |
+
" <td>0.760135</td>\n",
|
413 |
+
" </tr>\n",
|
414 |
+
" <tr>\n",
|
415 |
+
" <td>3</td>\n",
|
416 |
+
" <td>0.201600</td>\n",
|
417 |
+
" <td>0.529274</td>\n",
|
418 |
+
" <td>0.795608</td>\n",
|
419 |
+
" </tr>\n",
|
420 |
+
" <tr>\n",
|
421 |
+
" <td>4</td>\n",
|
422 |
+
" <td>0.103800</td>\n",
|
423 |
+
" <td>0.946800</td>\n",
|
424 |
+
" <td>0.784628</td>\n",
|
425 |
+
" </tr>\n",
|
426 |
+
" <tr>\n",
|
427 |
+
" <td>5</td>\n",
|
428 |
+
" <td>0.064900</td>\n",
|
429 |
+
" <td>1.108120</td>\n",
|
430 |
+
" <td>0.744932</td>\n",
|
431 |
+
" </tr>\n",
|
432 |
+
" </tbody>\n",
|
433 |
+
"</table><p>"
|
434 |
+
],
|
435 |
+
"text/plain": [
|
436 |
+
"<IPython.core.display.HTML object>"
|
437 |
+
]
|
438 |
+
},
|
439 |
+
"metadata": {},
|
440 |
+
"output_type": "display_data"
|
441 |
+
},
|
442 |
+
{
|
443 |
+
"name": "stderr",
|
444 |
+
"output_type": "stream",
|
445 |
+
"text": [
|
446 |
+
"/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
447 |
+
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
448 |
+
"/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
449 |
+
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
450 |
+
"/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
451 |
+
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
|
452 |
+
"/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
453 |
+
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
454 |
+
]
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"data": {
|
458 |
+
"text/plain": [
|
459 |
+
"TrainOutput(global_step=835, training_loss=0.2607016295016169, metrics={'train_runtime': 499.203, 'train_samples_per_second': 106.73, 'train_steps_per_second': 1.673, 'total_flos': 6960945435770880.0, 'train_loss': 0.2607016295016169, 'epoch': 5.0})"
|
460 |
+
]
|
461 |
+
},
|
462 |
+
"execution_count": 26,
|
463 |
+
"metadata": {},
|
464 |
+
"output_type": "execute_result"
|
465 |
+
}
|
466 |
+
],
|
467 |
+
"source": [
|
468 |
+
"trainer.train()"
|
469 |
+
]
|
470 |
+
},
|
471 |
+
{
|
472 |
+
"cell_type": "code",
|
473 |
+
"execution_count": null,
|
474 |
+
"id": "55c9d1eb-6fc6-451f-b596-870ccaa81d8d",
|
475 |
+
"metadata": {},
|
476 |
+
"outputs": [],
|
477 |
+
"source": []
|
478 |
+
}
|
479 |
+
],
|
480 |
+
"metadata": {
|
481 |
+
"kernelspec": {
|
482 |
+
"display_name": "Python 3 (ipykernel)",
|
483 |
+
"language": "python",
|
484 |
+
"name": "python3"
|
485 |
+
},
|
486 |
+
"language_info": {
|
487 |
+
"codemirror_mode": {
|
488 |
+
"name": "ipython",
|
489 |
+
"version": 3
|
490 |
+
},
|
491 |
+
"file_extension": ".py",
|
492 |
+
"mimetype": "text/x-python",
|
493 |
+
"name": "python",
|
494 |
+
"nbconvert_exporter": "python",
|
495 |
+
"pygments_lexer": "ipython3",
|
496 |
+
"version": "3.10.11"
|
497 |
+
}
|
498 |
+
},
|
499 |
+
"nbformat": 4,
|
500 |
+
"nbformat_minor": 5
|
501 |
+
}
|
dnagpt/get_data.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|