{
"cells": [
{
"cell_type": "markdown",
"id": "963e9ae0-ac68-44be-8c7d-fb9842784362",
"metadata": {},
"source": [
"# 4.7 基于llama的基因大模型指令微调"
]
},
{
"cell_type": "markdown",
"id": "c844103d-4e27-41b9-9bf1-c6a577846ab6",
"metadata": {},
"source": [
"### **大模型的指令微调(Instruction Fine-Tuning)**\n",
"\n",
"指令微调是指通过对大语言模型(如 GPT、T5、LLaMA 等)进行微调,使其能够更好地理解和执行人类以指令形式表达的任务。这种技术是大模型适配实际应用和增强用户交互能力的关键手段。\n",
"\n",
"---\n",
"\n",
"### **1. 指令微调的核心概念**\n",
"\n",
"指令微调的目标是通过在包含指令的专用数据集上进行微调,让模型能够:\n",
"1. 理解用户的任务需求(以自然语言表达的指令形式)。\n",
"2. 根据指令内容生成符合预期的高质量响应。\n",
"3. 适应多任务场景,减少特定任务的单独训练需求。\n",
"\n",
"---\n",
"\n",
"### **2. 指令微调的关键特点**\n",
"\n",
"1. **多任务统一**:\n",
" - 不需要针对每个任务单独微调,而是通过指令微调使模型能适应多种任务。\n",
" \n",
"2. **自然语言交互**:\n",
" - 用户可以用自然语言指令与模型交互,无需提供特定格式的输入。\n",
"\n",
"3. **泛化能力**:\n",
" - 微调后的模型能够对未见过的任务产生合理的推断和响应。\n",
"\n",
"---\n",
"\n",
"### **3. 数据集的构建与使用**\n",
"\n",
"#### **(1)指令微调数据集的特点**\n",
"- 数据通常包含以下三部分:\n",
" 1. **指令(Instruction)**:任务描述或问题,例如“将以下文本翻译为法语”。\n",
" 2. **输入(Input)**:任务相关的上下文或数据,可以为空。\n",
" 3. **输出(Output)**:模型期望生成的结果。\n",
"\n",
"#### **(2)常用指令微调数据集**\n",
"- **FLAN**:包含多个 NLP 任务的指令数据集,用于 T5 等模型的微调。\n",
"- **OpenAI 提供的指令数据**:如 GPT 系列的 ChatGPT 调优数据集。\n",
"- **InstructGPT 数据**:通过人类标注的多任务指令数据,用于模型优化。\n",
"- **Self-Instruct**:通过模型自生成指令和回答,进一步扩展训练数据。\n",
"\n",
"#### **(3)构建自己的数据集**\n",
"- 如果需要特定领域的指令微调,可以自行构建数据集:\n",
" - 收集任务需求和示例。\n",
" - 设计多样化的指令。\n",
" - 使用专家标注或模型辅助生成高质量答案。\n",
"\n",
"---\n",
"\n",
"### **4. 微调的步骤**\n",
"\n",
"#### **(1)加载基础模型**\n",
"从 Hugging Face 或其他框架加载预训练的大语言模型,例如 GPT-2、T5、LLaMA。\n",
"\n",
"#### **(2)准备数据集**\n",
"将指令微调数据集格式化为:\n",
"```python\n",
"{\n",
" \"instruction\": \"Translate the following text to French\",\n",
" \"input\": \"Hello, how are you?\",\n",
" \"output\": \"Bonjour, comment ça va?\"\n",
"}\n",
"```\n",
"\n",
"#### **(3)定义微调方法**\n",
"使用 `Trainer` 或分布式框架(如 DeepSpeed、Accelerate)进行微调。\n",
"\n",
"---\n",
"\n",
"### **5. 示例代码:指令微调实现**\n",
"\n",
"以下是基于 Hugging Face 的指令微调代码示例:\n",
"\n",
"```python\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer\n",
"from datasets import load_dataset\n",
"\n",
"# 1. 加载预训练模型和分词器\n",
"model_name = \"gpt2\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"\n",
"# 2. 加载指令微调数据集\n",
"# 数据格式应包含 instruction, input, output 字段\n",
"dataset = load_dataset(\"path/to/instruction_dataset\")\n",
"\n",
"# 3. 数据预处理\n",
"def preprocess_function(example):\n",
" # 将指令和输入拼接成完整的提示\n",
" prompt = example[\"instruction\"]\n",
" if example[\"input\"]:\n",
" prompt += f\"\\n{example['input']}\"\n",
" labels = example[\"output\"]\n",
" tokenized = tokenizer(prompt, truncation=True, max_length=512, padding=\"max_length\")\n",
" with tokenizer.as_target_tokenizer():\n",
" tokenized_labels = tokenizer(labels, truncation=True, max_length=512, padding=\"max_length\")\n",
" tokenized[\"labels\"] = tokenized_labels[\"input_ids\"]\n",
" return tokenized\n",
"\n",
"tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
"\n",
"# 4. 设置训练参数\n",
"training_args = TrainingArguments(\n",
" output_dir=\"./instruction_finetuned_model\",\n",
" per_device_train_batch_size=4,\n",
" num_train_epochs=3,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" learning_rate=5e-5,\n",
" weight_decay=0.01,\n",
" logging_dir=\"./logs\",\n",
" fp16=True,\n",
")\n",
"\n",
"# 5. 定义 Trainer\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_datasets[\"train\"],\n",
" eval_dataset=tokenized_datasets[\"test\"],\n",
" tokenizer=tokenizer,\n",
")\n",
"\n",
"# 6. 开始训练\n",
"trainer.train()\n",
"\n",
"# 7. 保存模型\n",
"model.save_pretrained(\"./instruction_finetuned_model\")\n",
"tokenizer.save_pretrained(\"./instruction_finetuned_model\")\n",
"```\n",
"\n",
"---\n",
"\n",
"### **6. 指令微调的挑战**\n",
"\n",
"1. **数据质量**:\n",
" - 低质量或噪声数据可能导致模型生成结果不符合指令。\n",
"\n",
"2. **指令覆盖范围**:\n",
" - 数据集指令种类不足会限制模型的泛化能力。\n",
"\n",
"3. **计算资源需求**:\n",
" - 大模型的微调需要高性能 GPU 和大容量存储。\n",
"\n",
"4. **灾难性遗忘**:\n",
" - 微调过程中可能导致模型丧失部分原始能力。\n",
"\n",
"---\n",
"\n",
"### **7. 指令微调的应用场景**\n",
"\n",
"1. **多任务问答**:\n",
" - 适配多任务场景,支持翻译、总结、推理等功能。\n",
"\n",
"2. **特定领域优化**:\n",
" - 在法律、医疗等特定领域的任务指令上进行微调。\n",
"\n",
"3. **用户交互优化**:\n",
" - 提升模型对自然语言指令的理解和响应能力。\n",
"\n",
"4. **开放式对话生成**:\n",
" - 优化模型在对话场景下的表现,例如 ChatGPT 的微调。\n",
"\n",
"---\n",
"\n",
"### **总结**\n",
"\n",
"指令微调通过在特定格式的数据集上进一步训练大模型,使其能够更好地理解和执行用户的自然语言指令。这种方法适合多任务场景,并能提升模型的交互能力和领域适应性。借助高质量的指令数据集和高效的微调技术,大模型在实际应用中的表现可以得到显著提升。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e77f8b39-e75a-4014-a98a-bde5b2534bf1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "7be8b814-42f6-4fb6-bf4b-ae23292030f6",
"metadata": {},
"source": [
"## 持续预训练 VS 指令微调"
]
},
{
"cell_type": "markdown",
"id": "f9bed0ae-337d-49af-85f0-c8e6263d78db",
"metadata": {},
"source": [
"**大模型的持续预训练**和**指令微调**是两种针对大模型的后续优化策略,虽然它们的目标都是提升模型性能,但在应用场景、方法和效果等方面有明显区别。以下是它们的对比分析:\n",
"\n",
"---\n",
"\n",
"### **1. 概念与目标**\n",
"\n",
"| **特性** | **持续预训练** | **指令微调** |\n",
"|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
"| **定义** | 在通用预训练模型上,使用新的大规模语料(通用或领域特定数据)进行进一步预训练。 | 在包含指令任务的数据集上对大模型进行微调,以提升模型对人类指令的理解和执行能力。 |\n",
"| **目标** | 提升模型的通用能力或适应特定领域的语言理解与生成能力。 | 提高模型对多任务指令的泛化能力,让模型更好地理解和执行自然语言表达的具体任务。 |\n",
"| **典型应用** | 领域适配(医学、法律、金融)、性能优化、跨语言适配等。 | 多任务问答、开放式对话生成、翻译、推理等需要用户直接交互的场景。 |\n",
"\n",
"---\n",
"\n",
"### **2. 数据使用**\n",
"\n",
"| **特性** | **持续预训练** | **指令微调** |\n",
"|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
"| **数据类型** | 通用语料(如新闻、社交媒体文本)或领域特定语料(如 PubMed、法律文档、金融报告)。 | 任务指令数据集,包括指令(Instruction)、输入(Input)和输出(Output)。 |\n",
"| **数据构建** | 通常需要清洗和去重大规模语料数据,避免与原始预训练数据重叠。 | 通常由人工标注或模型生成的指令数据构成,例如 FLAN、InstructGPT 数据集。 |\n",
"| **多样性要求** | 数据应覆盖尽可能广的领域或目标领域的多种场景,以提升模型在这些场景的表现。 | 数据需要覆盖多种任务类型(如翻译、分类、摘要)和丰富的指令表达形式,以提高模型对多任务的适配能力。 |\n",
"\n",
"---\n",
"\n",
"### **3. 方法与技术**\n",
"\n",
"| **特性** | **持续预训练** | **指令微调** |\n",
"|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
"| **主要技术** | 继续使用自监督学习目标(如语言建模、掩码预测)进行训练。 | 使用监督学习,通常以任务输入和目标输出对为数据,通过微调适配特定任务需求。 |\n",
"| **模型调整** | - 可选择全量参数更新或冻结部分参数。
- 可结合参数高效微调技术(如 LoRA、Adapter)。 | - 通常使用监督训练方式,可能结合参数高效微调技术(如 LoRA)。 |\n",
"| **学习率** | 通常使用较小的学习率(如 `1e-5` 或更小),以防止破坏原始权重。 | 同样使用较小的学习率,但任务指令微调可能需要更高的关注任务特定的标签对准。 |\n",
"\n",
"---\n",
"\n",
"### **4. 模型能力与效果**\n",
"\n",
"| **特性** | **持续预训练** | **指令微调** |\n",
"|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
"| **提升的能力** | - 对领域特定语言模式和知识的适配性提升显著。
- 对未见过的通用场景生成能力增强(扩展模型知识广度)。 | - 显著提升模型对指令理解的能力,尤其是自然语言表达的任务需求。
- 对多任务和零样本任务的泛化能力有较大提升。 |\n",
"| **局限性** | - 对具体任务的直接适配能力较弱,可能需要额外的任务微调。
- 数据选择不当可能导致灾难性遗忘。 | - 依赖高质量的指令数据集,数据质量不高会导致模型生成结果不稳定。
- 对通用能力的提升有限。 |\n",
"\n",
"---\n",
"\n",
"### **5. 应用场景与示例**\n",
"\n",
"| **特性** | **持续预训练** | **指令微调** |\n",
"|------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|\n",
"| **典型应用场景** | - 医学文献总结(通过 PubMed 语料持续预训练)。
- 法律条文分析(通过法律文档进一步训练)。
- 增强多语言生成能力(跨语言语料)。 | - ChatGPT 的多任务对话生成。
- 翻译、摘要、问答等用户交互任务的泛化处理。 |\n",
"| **实际示例** | - BioBERT:在 BERT 基础上使用生物医学语料持续预训练的模型。
- FinBERT:针对金融领域持续预训练的语言模型。 | - InstructGPT:在 GPT-3 基础上进行指令微调,用于多任务用户交互。
- FLAN-T5:通过 FLAN 数据集进行指令微调。 |\n",
"\n",
"---\n",
"\n",
"### **6. 持续预训练与指令微调的结合**\n",
"\n",
"持续预训练和指令微调可以结合使用,形成一个从领域适配到任务适配的完整流程:\n",
"1. **持续预训练**:\n",
" - 先在领域特定数据(如医学、法律、金融语料)上进行持续预训练,获取领域知识。\n",
"2. **指令微调**:\n",
" - 再利用多任务指令数据集对模型微调,使其能够高效执行领域内的多样化任务。\n",
"\n",
"这种结合方式特别适用于需要领域知识和任务适配的场景,例如医学问答系统或金融文本分析。\n",
"\n",
"---\n",
"\n",
"### **总结**\n",
"\n",
"| **维度** | **持续预训练** | **指令微调** |\n",
"|------------------------|-------------------------------------|----------------------------------|\n",
"| **目标** | 增强通用能力或适配特定领域。 | 提升对任务指令的理解和执行能力。 |\n",
"| **数据集** | 通用或领域语料。 | 指令数据集,包含输入和输出对。 |\n",
"| **方法** | 自监督学习,扩展语言建模能力。 | 监督学习,强化任务适配能力。 |\n",
"| **适用场景** | 领域特定任务(如医学、法律)。 | 多任务交互(如问答、对话生成)。 |\n",
"| **局限性** | 对具体任务适配较弱。 | 通用能力提升有限,依赖数据质量。 |\n",
"\n",
"两者各有侧重,且在许多场景下可以结合使用,形成一个强大的任务和领域适配框架。"
]
},
{
"cell_type": "markdown",
"id": "f97a705a-b946-4dc1-a173-a9df033d6f2b",
"metadata": {},
"source": [
"## 本节任务\n",
"本节任务是基于上一节预训练的llama生物大模型。对一些生物学任务进行微调,包含了多个不同类型的分类问题和多序列交换问题。具体可见sft_data下的数据。"
]
},
{
"cell_type": "markdown",
"id": "9782db62-95bd-40a6-9759-966b9a0b362e",
"metadata": {},
"source": [
"## 代码运行\n",
"\n",
"```\n",
"\n",
"#微调\n",
"./run_sft.sh\n",
"\n",
"运行时间约3小时\n",
"\n",
"#合并模型\n",
"./merge_sft_model.sh\n",
"\n",
"```"
]
},
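{
"cell_type": "markdown",
"id": "c3a9f1e2-8d4b-4c5a-9e6f-0a1b2c3d4e5f",
"metadata": {},
"source": [
"`merge_sft_model.sh` folds the fine-tuned weights back into the base model so the result can be loaded as a plain LLaMA checkpoint. A minimal sketch of what such a merge looks like, assuming the fine-tuning produced a `peft` LoRA adapter (both paths are placeholders, not the script's actual arguments):\n",
"\n",
"```python\n",
"from transformers import AutoModelForCausalLM\n",
"from peft import PeftModel\n",
"\n",
"# Load the base model and attach the trained LoRA adapter\n",
"base = AutoModelForCausalLM.from_pretrained(\"path/to/base_model\")\n",
"model = PeftModel.from_pretrained(base, \"path/to/sft_lora_adapter\")\n",
"\n",
"# Merge the adapter into the base weights and save a standalone checkpoint\n",
"merged = model.merge_and_unload()\n",
"merged.save_pretrained(\"path/to/merged_model\")\n",
"```"
]
},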
{
"cell_type": "markdown",
"id": "182b82c4-d484-4c15-a600-03c3b51367ec",
"metadata": {},
"source": [
"## 模型验证"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5aa3d240-44e1-4811-8f61-d6ff2500a798",
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"# 设置环境变量, autodl一般区域\n",
"result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n",
"output = result.stdout\n",
"for line in output.splitlines():\n",
" if '=' in line:\n",
" var, value = line.split('=', 1)\n",
" os.environ[var] = value"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "054a2956-9045-4ad5-a878-1bfc84ad4ed8",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoConfig,AutoModel\n",
"from transformers import DataCollatorForLanguageModeling\n",
"from transformers import Trainer, TrainingArguments\n",
"from transformers import AutoConfig, AutoModelForCausalLM,LlamaForCausalLM,LlamaTokenizer\n",
"from tokenizers import Tokenizer\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "63c8bf16-9576-41bc-b27c-c92ba4289cf4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['instruction', 'input', 'output'],\n",
" num_rows: 19839\n",
" })\n",
"})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datasets import load_dataset\n",
"dna_ft_dataset = load_dataset('json', data_files='val_data.json')\n",
"dna_ft_dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "95928da3-ca64-4a17-80f4-945da395702c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['instruction', 'input', 'output'],\n",
" num_rows: 1983\n",
" })\n",
" test: Dataset({\n",
" features: ['instruction', 'input', 'output'],\n",
" num_rows: 17856\n",
" })\n",
"})"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = dna_ft_dataset[\"train\"].train_test_split(train_size=0.1, seed=42)\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a3e65bcd-85ce-4261-8ba6-7665c4ec60e2",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #dnagpt/dnahlm-llama-7b-sft-v0\n",
"tokenizer.pad_token = tokenizer.eos_token"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3d3fe49b-f48f-42b2-bc97-028e443111e4",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4f060ff2029447b9bad5e2b2e40b7133",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"LlamaForCausalLM(\n",
" (model): LlamaModel(\n",
" (embed_tokens): Embedding(91644, 4096, padding_idx=0)\n",
" (layers): ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): LlamaSdpaAttention(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (mlp): LlamaMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n",
" (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)\n",
" )\n",
" )\n",
" (norm): LlamaRMSNorm((4096,), eps=1e-06)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (lm_head): Linear(in_features=4096, out_features=91644, bias=False)\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = LlamaForCausalLM.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #continue pretrain\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c54df9fe-86c4-4963-b313-b438894bf9dd",
"metadata": {},
"outputs": [],
"source": [
"#构建提示词\n",
"def format_input(entry):\n",
" instruction_text = (\n",
" f\"Below is an instruction that describes a task. \"\n",
" f\"Write a response that appropriately completes the request.\"\n",
" f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n",
" )\n",
"\n",
" input_text = f\"\\n\\n### Input:\\n{entry['input']}\" if entry[\"input\"] else \"\"\n",
"\n",
" return instruction_text + input_text + \"\\n\\n### Response:\\n\"\n",
"\n",
"#构建提示词\n",
"def build_prompt(entry):\n",
"\n",
" input_data = format_input(entry)\n",
"\n",
" desired_response = entry['output']\n",
"\n",
" return input_data + desired_response\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ee540cfb-1f6e-4e02-a3bc-c814e43685cb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'instruction': 'Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.',\n",
" 'input': 'CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC',\n",
" 'output': 'promoter'}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example = data[\"test\"][0]\n",
"example"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7ee35528-7b3f-4e60-b88b-1bc3e950012b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
"\n",
"### Input:\n",
"CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\n",
"\n",
"### Response:\n",
"promoter\n"
]
}
],
"source": [
"prompt = build_prompt(example)\n",
"print(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8aa6f38f-3bcc-4566-8a66-a541db91e031",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['▁Below',\n",
" '▁is',\n",
" '▁an',\n",
" '▁instruction',\n",
" '▁that',\n",
" '▁describes',\n",
" '▁a',\n",
" '▁task',\n",
" '.',\n",
" '▁Write',\n",
" '▁a',\n",
" '▁response',\n",
" '▁that',\n",
" '▁appropri',\n",
" 'ately',\n",
" '▁comple',\n",
" 'tes',\n",
" '▁the',\n",
" '▁request',\n",
" '.',\n",
" '<0x0A>',\n",
" '<0x0A>',\n",
" '##',\n",
" '#',\n",
" '▁Inst',\n",
" 'ruction',\n",
" ':',\n",
" '<0x0A>',\n",
" 'Det',\n",
" 'erm',\n",
" 'ine',\n",
" '▁core',\n",
" '▁prom',\n",
" 'oter',\n",
" '▁detection',\n",
" '▁of',\n",
" '▁following',\n",
" '▁d',\n",
" 'na',\n",
" '▁sequence',\n",
" ',',\n",
" '▁The',\n",
" '▁result',\n",
" '▁will',\n",
" '▁be',\n",
" '▁one',\n",
" '▁of',\n",
" '▁the',\n",
" '▁following',\n",
" ':',\n",
" '▁Non',\n",
" '-',\n",
" 'prom',\n",
" 'oter',\n",
" ',',\n",
" '▁prom',\n",
" 'oter',\n",
" '.',\n",
" '<0x0A>',\n",
" '<0x0A>',\n",
" '##',\n",
" '#',\n",
" '▁Input',\n",
" ':',\n",
" '<0x0A>',\n",
" 'CCG',\n",
" 'TGCG',\n",
" 'ACCGG',\n",
" 'AAG',\n",
" 'TGGGGC',\n",
" 'GGCG',\n",
" 'ACCCCGG',\n",
" 'AAG',\n",
" 'TCCCC',\n",
" 'GCCGGG',\n",
" 'TGCAGC',\n",
" 'TTGG',\n",
" 'TCGG',\n",
" 'TTCG',\n",
" 'ATCGCC',\n",
" '<0x0A>',\n",
" '<0x0A>',\n",
" '##',\n",
" '#',\n",
" '▁Response',\n",
" ':',\n",
" '<0x0A>',\n",
" 'prom',\n",
" 'oter']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.tokenize(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "11875339-4901-4912-86e5-afe8c74921d9",
"metadata": {},
"outputs": [],
"source": [
"def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=1000):\n",
" # Tokenize\n",
" input_ids = tokenizer.encode(\n",
" text,\n",
" return_tensors=\"pt\",\n",
" truncation=True,\n",
" max_length=max_input_tokens\n",
" # return_attention_mask=True,\n",
" )\n",
"\n",
" # Generate\n",
" device = model.device\n",
" generated_tokens_with_prompt = model.generate(\n",
" input_ids=input_ids.to(device),\n",
" #max_length=max_output_tokens,\n",
" max_new_tokens=8,\n",
" temperature=0.01 # 控制生成的多样性\n",
" )\n",
"\n",
" # Decode\n",
" generated_text_with_prompt = tokenizer.decode(generated_tokens_with_prompt[0], skip_special_tokens=True)\n",
" generated_text_answer = generated_text_with_prompt[len(text):]\n",
"\n",
"\n",
" return generated_text_answer\n",
"\n",
"# 如果需要进一步清理\n",
"def clean_generated_text(text):\n",
" # 去除 'Ġ' 符号并替换为空格\n",
" text = text.replace('Ġ', ' ')\n",
" # 去除多余的空格\n",
" text = ' '.join(text.split())\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1b02644a-8b24-45aa-b22d-0f7ce2270dd9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input (test): Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n",
"\n",
"### Input:\n",
"CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\n",
"\n",
"### Response:\n",
"\n",
"real answer: promoter\n",
"--------------------------\n",
"\n",
"model's answer: \n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/lib/python3.12/site-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.01` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
" warnings.warn(\n",
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"promoter\n"
]
}
],
"source": [
"input_text = format_input(data[\"test\"][0])\n",
"\n",
"print(\"input (test):\", input_text)\n",
"\n",
"print(\"real answer:\", data[\"test\"][0][\"output\"])\n",
"\n",
"print(\"--------------------------\\n\")\n",
"\n",
"print(\"model's answer: \\n\")\n",
"print(inference(input_text, model, tokenizer))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2df1569-7f70-46ee-b93f-cbd879e32e14",
"metadata": {},
"outputs": [],
"source": [
"test_data = data[\"test\"].shuffle(seed=199).select(range(100))\n",
"\n",
"data_list = []\n",
"\n",
"for entry in test_data:\n",
" input_text = format_input(entry)\n",
" #print(input_text)\n",
" response_text = inference(input_text, model, tokenizer)\n",
" #print(response_text)\n",
" data = {\n",
" \"instruction\":entry[\"instruction\"],\n",
" \"input\":entry[\"input\"],\n",
" \"output\":entry[\"output\"],\n",
" \"model_response\":response_text\n",
" }\n",
"\n",
" data_list.append(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c6e47cb-1b64-4690-a51d-f1816b82f15f",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# 定义输出文件路径\n",
"output_file = 'llama-sft-2.json'\n",
"\n",
"# 将 Dataset 对象导出为 JSON 文件\n",
"# test_data.to_json(output_file)\n",
"with open(output_file, \"w\") as file:\n",
" json.dump(data_list, file, indent=4) # \"indent\" for pretty-printing\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "68831e19-5a99-46d8-9f40-e8bf6957dbfc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Donor Sites |||||||||||| Non-Splice Sites\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| promoter\n",
"Donor Sites |||||||||||| Non-Splice Sites\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Donor Sites |||||||||||| Donor Sites\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"Donor Sites |||||||||||| Splice Sites\n",
"Background Sequences |||||||||||| Background Sequences\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-Splice Sites |||||||||||| Non-Splice Sites\n",
"Non-promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Binding Sites |||||||||||| Background Sequences\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-Splice Sites |||||||||||| Non-Splice Sites\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Donor Sites |||||||||||| Donor Sites\n",
"Non-promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"Background Sequences |||||||||||| Background Sequences\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Binding Sites |||||||||||| Binding Sites\n",
"promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Donor Sites |||||||||||| Donor Sites\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Binding Sites |||||||||||| Binding Sites\n",
"promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| promoter\n",
"Background Sequences |||||||||||| Binding Sites\n",
"promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Background Sequences |||||||||||| Background Sequences\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| promoter\n",
"Donor Sites |||||||||||| Non-Splice Sites\n",
"Binding Sites |||||||||||| Binding Sites\n",
"promoter |||||||||||| promoter\n",
"Donor Sites |||||||||||| Donor Sites\n",
"Non-promoter |||||||||||| promoter\n",
"Binding Sites |||||||||||| Binding Sites\n",
"Donor Sites |||||||||||| Donor Sites\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Donor Sites |||||||||||| Donor Sites\n",
"Non-promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Acceptor Sites |||||||||||| Acceptor Sites\n",
"promoter |||||||||||| promoter\n",
"Donor Sites |||||||||||| Donor Sites\n",
"Donor Sites |||||||||||| Acceptor Sites\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| promoter\n",
"promoter |||||||||||| Non-promoter\n",
"Non-promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| promoter\n",
"Background Sequences |||||||||||| Binding Sites\n",
"Acceptor Sites |||||||||||| Splice Sites\n",
"Non-Splice Sites |||||||||||| Non-Splice Sites\n",
"Donor Sites |||||||||||| Non-Splice Sites\n",
"Donor Sites |||||||||||| Donor Sites\n",
"Non-promoter |||||||||||| Non-promoter\n",
"promoter |||||||||||| promoter\n",
"Background Sequences |||||||||||| Binding Sites\n",
"promoter |||||||||||| promoter\n",
"promoter |||||||||||| promoter\n",
"Acceptor Sites |||||||||||| Splice Sites\n",
"promoter |||||||||||| promoter\n",
"presicion 0.73 same 0.3\n"
]
}
],
"source": [
"import json\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"\n",
"with open(output_file, \"r\") as file:\n",
" test_data = json.load(file)\n",
"\n",
"all_num = len(test_data)\n",
"right_sum = 0\n",
"same_sum = 0\n",
"for item in test_data:\n",
" output = item[\"output\"]\n",
" #output = \" \".join(tokenizer.tokenize(output))\n",
" model_response = item[\"model_response\"]\n",
"\n",
" print(output,\"||||||||||||\", model_response)\n",
"\n",
" if model_response == output: #same it\n",
" same_sum = same_sum + 1\n",
" \n",
" if output.find(\"Non\")==-1: # no Non\n",
" if model_response.find(output)!=-1 and model_response.find(\"Non\")==-1: #find it, but no Non\n",
" right_sum = right_sum + 1\n",
" else:\n",
" if model_response.find(output)!=-1: #find it\n",
" right_sum = right_sum + 1\n",
"\n",
"\n",
"print(\"presicion\", right_sum/all_num, \"same\", same_sum/all_num)\n"
]
},
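{
"cell_type": "markdown",
"id": "d4e5f6a7-1b2c-4d3e-8f90-a1b2c3d4e5f6",
"metadata": {},
"source": [
"The score above mixes several task types. Below is a small follow-up sketch that breaks the exact-match rate down per instruction type, reusing the `llama-sft-2.json` file written above (grouping by an instruction prefix is a heuristic, not part of the original pipeline):\n",
"\n",
"```python\n",
"import json\n",
"from collections import defaultdict\n",
"\n",
"with open(\"llama-sft-2.json\") as f:\n",
"    records = json.load(f)\n",
"\n",
"totals, hits = defaultdict(int), defaultdict(int)\n",
"for r in records:\n",
"    task = r[\"instruction\"][:40]  # crude task key: the first 40 chars of the instruction\n",
"    totals[task] += 1\n",
"    hits[task] += int(r[\"model_response\"] == r[\"output\"])\n",
"\n",
"for task, n in totals.items():\n",
"    print(f\"{task}... exact match {hits[task]}/{n}\")\n",
"```"
]
},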
{
"cell_type": "code",
"execution_count": null,
"id": "7bc38f47-4a7d-43eb-abe8-db4310d280e3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}