ritikapatri committed
Commit 733d17e
1 Parent(s): ab03a8e

Delete amp_finetune.py

Files changed (1)
  1. amp_finetune.py +0 -114
amp_finetune.py DELETED
@@ -1,114 +0,0 @@
- from Bio import SeqIO
- import pandas as pd
- import ssl
- import io
- from urllib.request import urlopen
- from datasets import Dataset
- from tokenizers import Tokenizer
- import torch
- from progen.progen2.models.progen.modeling_progen import ProGenForCausalLM
- from transformers import PreTrainedTokenizerFast, TrainingArguments, Trainer, DataCollatorForLanguageModeling
- import math
- import os
-
- # parsing data file
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- # source 1: APD (aps.unmc.edu)
- # disable certificate verification for the APD download below
- ssl._create_default_https_context = ssl._create_unverified_context
- url = "https://aps.unmc.edu/assets/sequences/APD_sequence_release_09142020.fasta"
- response = urlopen(url)
- fasta_text = response.read().decode("utf-8", "ignore")
- aps_file = io.StringIO(fasta_text)
-
- sequences = []
- for record in SeqIO.parse(aps_file, "fasta"):
-     header = record.id
-     description = record.description
-     # str(record.seq) avoids touching the private Seq._data attribute
-     sequences.append(str(record.seq))
-
- # source 2: DRAMP
- # http://dramp.cpu-bioinfor.org/
- dramp_file = "amp_datasets/general_amps.fasta"
- for record in SeqIO.parse(dramp_file, "fasta"):
-     header = record.id
-     description = record.description
-     sequences.append(str(record.seq))
-
- # source 3: DBAASP
- # https://dbaasp.org/home
- dbaasp_file = "amp_datasets/peptides-fasta.txt"
- for record in SeqIO.parse(dbaasp_file, "fasta"):
-     header = record.id
-     description = record.description
-     sequences.append(str(record.seq))
-
- # 80/20 train/test split; hold the last 20% out for evaluation
- train_len = int(len(sequences) * 0.8)
- train = sequences[:train_len]
- test = sequences[train_len:]
-
- # model and tokenizer
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = ProGenForCausalLM.from_pretrained('checkpoints/progen2-small', low_cpu_mem_usage=True).to(device)
-
- num_params = sum(p.numel() for p in model.parameters())
-
- tokenizer = PreTrainedTokenizerFast(tokenizer_file="progen/progen2/tokenizer.json", pad_token="[PAD]")
- train = [tokenizer(sequence) for sequence in train]
- # print(train[0])
- test = [tokenizer(sequence) for sequence in test]
- # print(test[0])
-
- training_args = TrainingArguments(
-     output_dir="./amp_model",
-     evaluation_strategy="epoch",
-     save_strategy="epoch",
-     learning_rate=5e-4,
-     per_device_train_batch_size=16,
-     per_device_eval_batch_size=16,
-     num_train_epochs=20,
-     weight_decay=0.01,
-     load_best_model_at_end=True,
- )
-
- # causal LM objective: the collator pads each batch and copies input_ids to labels
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=train,
-     eval_dataset=test,
-     data_collator=data_collator,
- )
-
- trainer.train()
-
- # use for downstream tasks
- # classification AMP / non-AMP
- # generative model for generating new AMP sequences
-
- # sample one sequence from the fine-tuned model
- # '1' is ProGen2's N-terminus token, used here as the generation prompt
- target_input = tokenizer('1', return_tensors='pt').input_ids.to(device)
-
- with torch.no_grad():
-     output = model.generate(target_input, max_length=1024, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
-
- generated_sequences = [tokenizer.decode(s, skip_special_tokens=True) for s in output]
-
- with open('output.txt', 'w') as f:
-     for s in generated_sequences:
-         f.write(s + "\n")
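
For the generative downstream use noted in the script, a minimal sketch of drawing a batch of candidate peptides from the fine-tuned model could look like the following; it reuses the model, tokenizer, and device objects from the script, assumes ProGen2's '1' N-terminus token as the prompt, and the sampling settings and output filename are illustrative rather than taken from the original file:

import torch

# sketch: sample several candidate AMPs; assumes `model`, `tokenizer`, `device` from amp_finetune.py are in scope
prompt = tokenizer("1", return_tensors="pt").input_ids.to(device)  # assumed start-of-sequence prompt

with torch.no_grad():
    samples = model.generate(
        prompt,
        do_sample=True,             # sample instead of greedy decoding so candidates differ
        top_p=0.95,                 # nucleus sampling (illustrative value)
        temperature=0.8,            # illustrative value
        max_length=64,              # most AMPs are short peptides
        num_return_sequences=16,
        pad_token_id=tokenizer.pad_token_id,
    )

candidate_amps = [tokenizer.decode(s, skip_special_tokens=True) for s in samples]
with open("sampled_amps.txt", "w") as f:  # hypothetical output path
    f.write("\n".join(candidate_amps) + "\n")

Sampling with top_p and temperature, rather than the single greedy pass in the script, is what makes repeated calls produce distinct candidates; both values would need tuning against whatever AMP filter is applied downstream.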