{ "cells": [ { "cell_type": "markdown", "id": "963e9ae0-ac68-44be-8c7d-fb9842784362", "metadata": {}, "source": [ "# 4.7 基于llama的基因大模型指令微调" ] }, { "cell_type": "markdown", "id": "c844103d-4e27-41b9-9bf1-c6a577846ab6", "metadata": {}, "source": [ "### **大模型的指令微调(Instruction Fine-Tuning)**\n", "\n", "指令微调是指通过对大语言模型(如 GPT、T5、LLaMA 等)进行微调,使其能够更好地理解和执行人类以指令形式表达的任务。这种技术是大模型适配实际应用和增强用户交互能力的关键手段。\n", "\n", "---\n", "\n", "### **1. 指令微调的核心概念**\n", "\n", "指令微调的目标是通过在包含指令的专用数据集上进行微调,让模型能够:\n", "1. 理解用户的任务需求(以自然语言表达的指令形式)。\n", "2. 根据指令内容生成符合预期的高质量响应。\n", "3. 适应多任务场景,减少特定任务的单独训练需求。\n", "\n", "---\n", "\n", "### **2. 指令微调的关键特点**\n", "\n", "1. **多任务统一**:\n", " - 不需要针对每个任务单独微调,而是通过指令微调使模型能适应多种任务。\n", " \n", "2. **自然语言交互**:\n", " - 用户可以用自然语言指令与模型交互,无需提供特定格式的输入。\n", "\n", "3. **泛化能力**:\n", " - 微调后的模型能够对未见过的任务产生合理的推断和响应。\n", "\n", "---\n", "\n", "### **3. 数据集的构建与使用**\n", "\n", "#### **(1)指令微调数据集的特点**\n", "- 数据通常包含以下三部分:\n", " 1. **指令(Instruction)**:任务描述或问题,例如“将以下文本翻译为法语”。\n", " 2. **输入(Input)**:任务相关的上下文或数据,可以为空。\n", " 3. **输出(Output)**:模型期望生成的结果。\n", "\n", "#### **(2)常用指令微调数据集**\n", "- **FLAN**:包含多个 NLP 任务的指令数据集,用于 T5 等模型的微调。\n", "- **OpenAI 提供的指令数据**:如 GPT 系列的 ChatGPT 调优数据集。\n", "- **InstructGPT 数据**:通过人类标注的多任务指令数据,用于模型优化。\n", "- **Self-Instruct**:通过模型自生成指令和回答,进一步扩展训练数据。\n", "\n", "#### **(3)构建自己的数据集**\n", "- 如果需要特定领域的指令微调,可以自行构建数据集:\n", " - 收集任务需求和示例。\n", " - 设计多样化的指令。\n", " - 使用专家标注或模型辅助生成高质量答案。\n", "\n", "---\n", "\n", "### **4. 微调的步骤**\n", "\n", "#### **(1)加载基础模型**\n", "从 Hugging Face 或其他框架加载预训练的大语言模型,例如 GPT-2、T5、LLaMA。\n", "\n", "#### **(2)准备数据集**\n", "将指令微调数据集格式化为:\n", "```python\n", "{\n", " \"instruction\": \"Translate the following text to French\",\n", " \"input\": \"Hello, how are you?\",\n", " \"output\": \"Bonjour, comment ça va?\"\n", "}\n", "```\n", "\n", "#### **(3)定义微调方法**\n", "使用 `Trainer` 或分布式框架(如 DeepSpeed、Accelerate)进行微调。\n", "\n", "---\n", "\n", "### **5. 示例代码:指令微调实现**\n", "\n", "以下是基于 Hugging Face 的指令微调代码示例:\n", "\n", "```python\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer\n", "from datasets import load_dataset\n", "\n", "# 1. 加载预训练模型和分词器\n", "model_name = \"gpt2\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = AutoModelForCausalLM.from_pretrained(model_name)\n", "\n", "# 2. 加载指令微调数据集\n", "# 数据格式应包含 instruction, input, output 字段\n", "dataset = load_dataset(\"path/to/instruction_dataset\")\n", "\n", "# 3. 数据预处理\n", "def preprocess_function(example):\n", " # 将指令和输入拼接成完整的提示\n", " prompt = example[\"instruction\"]\n", " if example[\"input\"]:\n", " prompt += f\"\\n{example['input']}\"\n", " labels = example[\"output\"]\n", " tokenized = tokenizer(prompt, truncation=True, max_length=512, padding=\"max_length\")\n", " with tokenizer.as_target_tokenizer():\n", " tokenized_labels = tokenizer(labels, truncation=True, max_length=512, padding=\"max_length\")\n", " tokenized[\"labels\"] = tokenized_labels[\"input_ids\"]\n", " return tokenized\n", "\n", "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n", "\n", "# 4. 设置训练参数\n", "training_args = TrainingArguments(\n", " output_dir=\"./instruction_finetuned_model\",\n", " per_device_train_batch_size=4,\n", " num_train_epochs=3,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " learning_rate=5e-5,\n", " weight_decay=0.01,\n", " logging_dir=\"./logs\",\n", " fp16=True,\n", ")\n", "\n", "# 5. 
"\n",
"#### **(3) Define the fine-tuning method**\n",
"Fine-tune with `Trainer` or a distributed framework such as DeepSpeed or Accelerate.\n",
"\n",
"---\n",
"\n",
"### **5. Example Code: Implementing Instruction Fine-Tuning**\n",
"\n",
"Below is an instruction fine-tuning example based on Hugging Face:\n",
"\n",
"```python\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer\n",
"from datasets import load_dataset\n",
"\n",
"# 1. Load the pre-trained model and tokenizer\n",
"model_name = \"gpt2\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default\n",
"\n",
"# 2. Load the instruction dataset\n",
"# Records should contain instruction, input, and output fields\n",
"dataset = load_dataset(\"path/to/instruction_dataset\")\n",
"\n",
"# 3. Preprocessing\n",
"def preprocess_function(example):\n",
"    # Join the instruction and optional input into a prompt, then append the\n",
"    # expected output: a causal LM is trained on the concatenated sequence\n",
"    prompt = example[\"instruction\"]\n",
"    if example[\"input\"]:\n",
"        prompt += f\"\\n{example['input']}\"\n",
"    full_text = prompt + f\"\\n{example['output']}\"\n",
"    tokenized = tokenizer(full_text, truncation=True, max_length=512, padding=\"max_length\")\n",
"    # For causal-LM fine-tuning the labels are the input ids themselves;\n",
"    # the model shifts them one position to the right internally\n",
"    tokenized[\"labels\"] = tokenized[\"input_ids\"].copy()\n",
"    return tokenized\n",
"\n",
"tokenized_datasets = dataset.map(preprocess_function)\n",
"\n",
"# 4. Training arguments\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"./instruction_finetuned_model\",\n",
"    per_device_train_batch_size=4,\n",
"    num_train_epochs=3,\n",
"    evaluation_strategy=\"epoch\",\n",
"    save_strategy=\"epoch\",\n",
"    learning_rate=5e-5,\n",
"    weight_decay=0.01,\n",
"    logging_dir=\"./logs\",\n",
"    fp16=True,\n",
")\n",
"\n",
"# 5. Define the Trainer\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=tokenized_datasets[\"train\"],\n",
"    eval_dataset=tokenized_datasets[\"test\"],\n",
"    tokenizer=tokenizer,\n",
")\n",
"\n",
"# 6. Train\n",
"trainer.train()\n",
"\n",
"# 7. Save the model\n",
"model.save_pretrained(\"./instruction_finetuned_model\")\n",
"tokenizer.save_pretrained(\"./instruction_finetuned_model\")\n",
"```\n",
"\n",
"---\n",
"\n",
"### **6. Challenges of Instruction Fine-Tuning**\n",
"\n",
"1. **Data quality**:\n",
"   - Low-quality or noisy data can make the model's outputs deviate from the instructions.\n",
"\n",
"2. **Instruction coverage**:\n",
"   - Too few instruction types in the dataset limit the model's ability to generalize.\n",
"\n",
"3. **Compute requirements**:\n",
"   - Fine-tuning a large model demands high-performance GPUs and substantial storage.\n",
"\n",
"4. **Catastrophic forgetting**:\n",
"   - Fine-tuning can erode part of the model's original capabilities.\n",
"\n",
"---\n",
"\n",
"### **7. Application Scenarios**\n",
"\n",
"1. **Multi-task question answering**:\n",
"   - One model serves translation, summarization, reasoning, and more.\n",
"\n",
"2. **Domain-specific optimization**:\n",
"   - Fine-tune on task instructions from specific fields such as law or medicine.\n",
"\n",
"3. **Better user interaction**:\n",
"   - Improve the model's understanding of and responses to natural-language instructions.\n",
"\n",
"4. **Open-ended dialogue**:\n",
"   - Optimize performance in conversational settings, as in the tuning behind ChatGPT.\n",
"\n",
"---\n",
"\n",
"### **Summary**\n",
"\n",
"Instruction fine-tuning continues training a large model on data in a specific instruction format so that it better understands and executes users' natural-language instructions. The approach suits multi-task settings and improves both interactivity and domain adaptability. With a high-quality instruction dataset and efficient fine-tuning techniques, a large model's real-world performance can improve markedly." ] }, { "cell_type": "code", "execution_count": null, "id": "e77f8b39-e75a-4014-a98a-bde5b2534bf1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "7be8b814-42f6-4fb6-bf4b-ae23292030f6", "metadata": {}, "source": [ "## Continued Pre-Training vs. Instruction Fine-Tuning" ] }, { "cell_type": "markdown", "id": "f9bed0ae-337d-49af-85f0-c8e6263d78db", "metadata": {}, "source": [
"**Continued pre-training** and **instruction fine-tuning** are two follow-up optimization strategies for large models. Both aim to improve performance, but they differ clearly in application scenarios, methods, and effects. A comparison:\n",
"\n",
"---\n",
"\n",
"### **1. Concept and Goal**\n",
"\n",
"| **Aspect** | **Continued pre-training** | **Instruction fine-tuning** |\n",
"|------------|----------------------------|-----------------------------|\n",
"| **Definition** | Further pre-train a general model on new large-scale corpora (general or domain-specific). | Fine-tune the model on instruction-task datasets to improve its understanding and execution of human instructions. |\n",
"| **Goal** | Strengthen general capability, or adapt language understanding and generation to a specific domain. | Improve generalization across multi-task instructions so the model better understands and executes tasks phrased in natural language. |\n",
"| **Typical uses** | Domain adaptation (medicine, law, finance), performance optimization, cross-lingual adaptation. | Multi-task QA, open-ended dialogue, translation, reasoning, and other directly interactive scenarios. |\n",
"\n",
"---\n",
"\n",
"### **2. Data Usage**\n",
"\n",
"| **Aspect** | **Continued pre-training** | **Instruction fine-tuning** |\n",
"|------------|----------------------------|-----------------------------|\n",
"| **Data type** | General corpora (news, social-media text) or domain corpora (PubMed, legal documents, financial reports). | Task instruction datasets with Instruction, Input, and Output fields. |\n",
"| **Data construction** | Usually requires cleaning and deduplicating large corpora, avoiding overlap with the original pre-training data. | Usually human-annotated or model-generated instruction data, e.g. the FLAN or InstructGPT datasets. |\n",
"| **Diversity requirements** | Data should cover the target domain (or as many domains as possible) broadly to improve performance in those settings. | Data must cover many task types (translation, classification, summarization) and varied instruction phrasings to improve multi-task adaptability. |\n",
"\n",
"---\n",
"\n",
"### **3. Methods and Techniques**\n",
"\n",
"| **Aspect** | **Continued pre-training** | **Instruction fine-tuning** |\n",
"|------------|----------------------------|-----------------------------|\n",
"| **Main technique** | Continue training with self-supervised objectives (language modeling, masked prediction). | Supervised learning on task input–output pairs, adapting the model to specific task requirements. |\n",
"| **Model updates** | - Full-parameter updates, or freezing part of the network.<br>- Can be combined with parameter-efficient methods (LoRA, Adapters). | - Usually supervised training, possibly combined with parameter-efficient methods such as LoRA. |\n",
"| **Learning rate** | Usually small (e.g. `1e-5` or below) to avoid destroying the original weights. | Also small, but instruction tuning may need closer attention to aligning outputs with task-specific labels. |\n",
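"\n",
"Both columns above mention parameter-efficient fine-tuning. As a concrete illustration, here is a minimal LoRA sketch using the Hugging Face `peft` library (the target module names below fit LLaMA-style attention layers; other architectures name them differently):\n",
"\n",
"```python\n",
"from peft import LoraConfig, get_peft_model\n",
"\n",
"lora_config = LoraConfig(\n",
"    r=8,                                  # rank of the low-rank update matrices\n",
"    lora_alpha=16,                        # scaling factor applied to the update\n",
"    target_modules=[\"q_proj\", \"v_proj\"],  # attention projections in LLaMA-style models\n",
"    lora_dropout=0.05,\n",
"    task_type=\"CAUSAL_LM\",\n",
")\n",
"\n",
"model = get_peft_model(model, lora_config)  # wrap a previously loaded base model\n",
"model.print_trainable_parameters()          # typically well under 1% of all weights\n",
"```\n",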
"\n",
"---\n",
"\n",
"### **4. Capabilities and Effects**\n",
"\n",
"| **Aspect** | **Continued pre-training** | **Instruction fine-tuning** |\n",
"|------------|----------------------------|-----------------------------|\n",
"| **Capabilities gained** | - Markedly better fit to domain-specific language patterns and knowledge.<br>- Stronger generation in unseen general settings (broader model knowledge). | - Markedly better understanding of instructions, especially task requirements phrased in natural language.<br>- Substantially better generalization to multi-task and zero-shot settings. |\n",
"| **Limitations** | - Weak direct adaptation to concrete tasks; extra task fine-tuning may be needed.<br>- Poor data selection can cause catastrophic forgetting. | - Depends on high-quality instruction data; low quality makes outputs unstable.<br>- Limited gains in general capability. |\n",
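"\n",
"Catastrophic forgetting, listed above, is often mitigated by replaying a slice of general-domain text alongside the new corpus during continued pre-training. A minimal sketch with the Hugging Face `datasets` library (the two file names are placeholders):\n",
"\n",
"```python\n",
"from datasets import interleave_datasets, load_dataset\n",
"\n",
"# Placeholder corpora: the new domain corpus plus a general-domain replay buffer.\n",
"domain = load_dataset(\"json\", data_files=\"domain_corpus.json\", split=\"train\")\n",
"general = load_dataset(\"json\", data_files=\"general_corpus.json\", split=\"train\")\n",
"\n",
"# Sample ~90% domain / 10% general so the model gains domain knowledge\n",
"# without drifting too far from its original distribution.\n",
"mixed = interleave_datasets([domain, general], probabilities=[0.9, 0.1], seed=42)\n",
"```\n",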
"\n",
"---\n",
"\n",
"### **5. Application Scenarios and Examples**\n",
"\n",
"| **Aspect** | **Continued pre-training** | **Instruction fine-tuning** |\n",
"|------------|----------------------------|-----------------------------|\n",
"| **Typical scenarios** | - Summarizing medical literature (continued pre-training on PubMed).<br>- Analyzing legal texts (further training on legal documents).<br>- Strengthening multilingual generation (cross-lingual corpora). | - ChatGPT-style multi-task dialogue generation.<br>- Generalized handling of interactive tasks such as translation, summarization, and QA. |\n",
"| **Concrete examples** | - BioBERT: BERT continued-pre-trained on biomedical corpora.<br>- FinBERT: a language model continued-pre-trained for finance. | - InstructGPT: GPT-3 instruction-tuned for multi-task user interaction.<br>- FLAN-T5: instruction-tuned on the FLAN datasets. |\n",
"\n",
"---\n",
"\n",
"### **6. Combining Continued Pre-Training and Instruction Fine-Tuning**\n",
"\n",
"The two can be combined into a full pipeline from domain adaptation to task adaptation:\n",
"1. **Continued pre-training**:\n",
"   - First continue pre-training on domain-specific data (medical, legal, financial corpora) to acquire domain knowledge.\n",
"2. **Instruction fine-tuning**:\n",
"   - Then fine-tune on a multi-task instruction dataset so the model executes the domain's varied tasks effectively.\n",
"\n",
"This combination particularly suits scenarios that need both domain knowledge and task adaptation, such as medical QA systems or financial text analysis.\n",
"\n",
"---\n",
"\n",
"### **Summary**\n",
"\n",
"| **Dimension** | **Continued pre-training** | **Instruction fine-tuning** |\n",
"|---------------|----------------------------|-----------------------------|\n",
"| **Goal** | Strengthen general ability or adapt to a domain. | Improve understanding and execution of task instructions. |\n",
"| **Dataset** | General or domain corpora. | Instruction datasets of input–output pairs. |\n",
"| **Method** | Self-supervised learning; extends language-modeling ability. | Supervised learning; strengthens task adaptation. |\n",
"| **Best for** | Domain-specific tasks (medicine, law). | Multi-task interaction (QA, dialogue generation). |\n",
"| **Limitations** | Weak adaptation to concrete tasks. | Limited general-capability gains; depends on data quality. |\n",
"\n",
"The two emphasize different things and, in many scenarios, can be combined into a powerful framework for both task and domain adaptation." ] }, { "cell_type": "markdown", "id": "f97a705a-b946-4dc1-a173-a9df033d6f2b", "metadata": {}, "source": [ "## This Section's Task\n", "This section builds on the LLaMA genomic model pre-trained in the previous section: fine-tune it on a set of biology tasks, covering several different types of classification problems as well as multi-sequence exchange problems. See the data under sft_data for details." ] }, { "cell_type": "markdown", "id": "9782db62-95bd-40a6-9759-966b9a0b362e", "metadata": {}, "source": [ "## Running the Code\n", "\n", "```\n", "\n", "# fine-tune\n", "./run_sft.sh\n", "\n", "# running time: about 3 hours\n", "\n", "# merge the model\n", "./merge_sft_model.sh\n", "\n", "```" ] }, { "cell_type": "markdown", "id": "182b82c4-d484-4c15-a600-03c3b51367ec", "metadata": {}, "source": [ "## Model Validation" ] }, { "cell_type": "code", "execution_count": 1, "id": "5aa3d240-44e1-4811-8f61-d6ff2500a798", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "import os\n", "# Set proxy environment variables (AutoDL standard regions)\n", "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n", "output = result.stdout\n", "for line in output.splitlines():\n", "    if '=' in line:\n", "        var, value = line.split('=', 1)\n", "        os.environ[var] = value" ] }, { "cell_type": "code", "execution_count": 2, "id": "054a2956-9045-4ad5-a878-1bfc84ad4ed8", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoConfig, AutoModel\n", "from transformers import DataCollatorForLanguageModeling\n", "from transformers import Trainer, TrainingArguments\n", "from transformers import AutoConfig, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer\n", "from tokenizers import Tokenizer\n", "from datasets import load_dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "63c8bf16-9576-41bc-b27c-c92ba4289cf4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", "    train: Dataset({\n", "        features: ['instruction', 'input', 'output'],\n", "        num_rows: 19839\n", "    })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "dna_ft_dataset = load_dataset('json', data_files='val_data.json')\n", "dna_ft_dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "95928da3-ca64-4a17-80f4-945da395702c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", "    train: Dataset({\n", "        features: ['instruction', 'input', 'output'],\n", "        num_rows: 1983\n", "    })\n", "    test: Dataset({\n", "        features: ['instruction', 'input', 'output'],\n", "        num_rows: 17856\n", "    })\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = dna_ft_dataset[\"train\"].train_test_split(train_size=0.1, seed=42)\n", "data" ] },
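{ "cell_type": "markdown", "id": "aa11bb22-cc33-4d44-8e55-ff6677889900", "metadata": {}, "source": [ "Each record follows the instruction / input / output schema described earlier. A quick, illustrative sanity check on one record (sketch only, not executed here):\n", "\n", "```python\n", "# Inspect one training record to confirm the three-field schema.\n", "sample = data[\"train\"][0]\n", "for key in (\"instruction\", \"input\", \"output\"):\n", "    print(f\"{key}: {sample[key]}\")\n", "```" ] },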
"tokenizer = LlamaTokenizer.from_pretrained(\"dnahlm-llama-7b-sft-v0\") #dnagpt/dnahlm-llama-7b-sft-v0\n", "tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "code", "execution_count": 6, "id": "3d3fe49b-f48f-42b2-bc97-028e443111e4", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4f060ff2029447b9bad5e2b2e40b7133", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/3 [00:00',\n", " '<0x0A>',\n", " '##',\n", " '#',\n", " '▁Inst',\n", " 'ruction',\n", " ':',\n", " '<0x0A>',\n", " 'Det',\n", " 'erm',\n", " 'ine',\n", " '▁core',\n", " '▁prom',\n", " 'oter',\n", " '▁detection',\n", " '▁of',\n", " '▁following',\n", " '▁d',\n", " 'na',\n", " '▁sequence',\n", " ',',\n", " '▁The',\n", " '▁result',\n", " '▁will',\n", " '▁be',\n", " '▁one',\n", " '▁of',\n", " '▁the',\n", " '▁following',\n", " ':',\n", " '▁Non',\n", " '-',\n", " 'prom',\n", " 'oter',\n", " ',',\n", " '▁prom',\n", " 'oter',\n", " '.',\n", " '<0x0A>',\n", " '<0x0A>',\n", " '##',\n", " '#',\n", " '▁Input',\n", " ':',\n", " '<0x0A>',\n", " 'CCG',\n", " 'TGCG',\n", " 'ACCGG',\n", " 'AAG',\n", " 'TGGGGC',\n", " 'GGCG',\n", " 'ACCCCGG',\n", " 'AAG',\n", " 'TCCCC',\n", " 'GCCGGG',\n", " 'TGCAGC',\n", " 'TTGG',\n", " 'TCGG',\n", " 'TTCG',\n", " 'ATCGCC',\n", " '<0x0A>',\n", " '<0x0A>',\n", " '##',\n", " '#',\n", " '▁Response',\n", " ':',\n", " '<0x0A>',\n", " 'prom',\n", " 'oter']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.tokenize(prompt)" ] }, { "cell_type": "code", "execution_count": 11, "id": "11875339-4901-4912-86e5-afe8c74921d9", "metadata": {}, "outputs": [], "source": [ "def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=1000):\n", " # Tokenize\n", " input_ids = tokenizer.encode(\n", " text,\n", " return_tensors=\"pt\",\n", " truncation=True,\n", " max_length=max_input_tokens\n", " # return_attention_mask=True,\n", " )\n", "\n", " # Generate\n", " device = model.device\n", " generated_tokens_with_prompt = model.generate(\n", " input_ids=input_ids.to(device),\n", " #max_length=max_output_tokens,\n", " max_new_tokens=8,\n", " temperature=0.01 # 控制生成的多样性\n", " )\n", "\n", " # Decode\n", " generated_text_with_prompt = tokenizer.decode(generated_tokens_with_prompt[0], skip_special_tokens=True)\n", " generated_text_answer = generated_text_with_prompt[len(text):]\n", "\n", "\n", " return generated_text_answer\n", "\n", "# 如果需要进一步清理\n", "def clean_generated_text(text):\n", " # 去除 'Ġ' 符号并替换为空格\n", " text = text.replace('Ġ', ' ')\n", " # 去除多余的空格\n", " text = ' '.join(text.split())\n", " return text" ] }, { "cell_type": "code", "execution_count": 12, "id": "1b02644a-8b24-45aa-b22d-0f7ce2270dd9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input (test): Below is an instruction that describes a task. 
{ "cell_type": "code", "execution_count": 12, "id": "1b02644a-8b24-45aa-b22d-0f7ce2270dd9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input (test): Below is an instruction that describes a task. Write a response that appropriately completes the request.\n", "\n", "### Instruction:\n", "Determine core promoter detection of following dna sequence, The result will be one of the following: Non-promoter, promoter.\n", "\n", "### Input:\n", "CCGTGCGACCGGAAGTGGGGCGGCGACCCCGGAAGTCCCCGCCGGGTGCAGCTTGGTCGGTTCGATCGCC\n", "\n", "### Response:\n", "\n", "real answer: promoter\n", "--------------------------\n", "\n", "model's answer: \n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.01` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", "  warnings.warn(\n", "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "promoter\n" ] } ], "source": [ "input_text = format_input(data[\"test\"][0])\n", "\n", "print(\"input (test):\", input_text)\n", "\n", "print(\"real answer:\", data[\"test\"][0][\"output\"])\n", "\n", "print(\"--------------------------\\n\")\n", "\n", "print(\"model's answer: \\n\")\n", "print(inference(input_text, model, tokenizer))" ] }, { "cell_type": "code", "execution_count": null, "id": "e2df1569-7f70-46ee-b93f-cbd879e32e14", "metadata": {}, "outputs": [], "source": [ "test_data = data[\"test\"].shuffle(seed=199).select(range(100))\n", "\n", "data_list = []\n", "\n", "for entry in test_data:\n", "    input_text = format_input(entry)\n", "    #print(input_text)\n", "    response_text = inference(input_text, model, tokenizer)\n", "    #print(response_text)\n", "    record = {  # renamed from `data` to avoid shadowing the dataset above\n", "        \"instruction\": entry[\"instruction\"],\n", "        \"input\": entry[\"input\"],\n", "        \"output\": entry[\"output\"],\n", "        \"model_response\": response_text\n", "    }\n", "\n", "    data_list.append(record)" ] }, { "cell_type": "code", "execution_count": null, "id": "0c6e47cb-1b64-4690-a51d-f1816b82f15f", "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "# Output file path\n", "output_file = 'llama-sft-2.json'\n", "\n", "# Export the collected records to a JSON file\n", "# test_data.to_json(output_file)\n", "with open(output_file, \"w\") as file:\n", "    json.dump(data_list, file, indent=4)  # \"indent\" for pretty-printing\n" ] },
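{ "cell_type": "markdown", "id": "dd44ee55-ff66-4777-8888-99990000aaaa", "metadata": {}, "source": [ "Scoring 100 examples one at a time is slow. A batched variant of `inference` can cut the wall-clock time considerably; here is a minimal sketch (a hypothetical helper reusing the same `model` and `tokenizer`; decoder-only models need left padding so generation starts right after each prompt):\n", "\n", "```python\n", "tokenizer.padding_side = \"left\"  # required for batched decoder-only generation\n", "\n", "def batch_inference(texts, model, tokenizer, max_new_tokens=8):\n", "    enc = tokenizer(texts, return_tensors=\"pt\", padding=True,\n", "                    truncation=True, max_length=1000)\n", "    out = model.generate(\n", "        input_ids=enc[\"input_ids\"].to(model.device),\n", "        attention_mask=enc[\"attention_mask\"].to(model.device),\n", "        max_new_tokens=max_new_tokens,\n", "    )\n", "    # Drop the prompt tokens from each row before decoding.\n", "    gen = out[:, enc[\"input_ids\"].shape[1]:]\n", "    return tokenizer.batch_decode(gen, skip_special_tokens=True)\n", "```" ] },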
"promoter |||||||||||| Non-promoter\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Non-Splice Sites |||||||||||| Non-Splice Sites\n", "Non-promoter |||||||||||| Non-promoter\n", "promoter |||||||||||| Non-promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Binding Sites |||||||||||| Background Sequences\n", "Non-promoter |||||||||||| Non-promoter\n", "Non-Splice Sites |||||||||||| Non-Splice Sites\n", "Non-promoter |||||||||||| Non-promoter\n", "Non-promoter |||||||||||| promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Donor Sites |||||||||||| Donor Sites\n", "Non-promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "Background Sequences |||||||||||| Background Sequences\n", "Non-promoter |||||||||||| Non-promoter\n", "Binding Sites |||||||||||| Binding Sites\n", "promoter |||||||||||| promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Donor Sites |||||||||||| Donor Sites\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Binding Sites |||||||||||| Binding Sites\n", "promoter |||||||||||| Non-promoter\n", "promoter |||||||||||| promoter\n", "Background Sequences |||||||||||| Binding Sites\n", "promoter |||||||||||| promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Background Sequences |||||||||||| Background Sequences\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| Non-promoter\n", "promoter |||||||||||| promoter\n", "Donor Sites |||||||||||| Non-Splice Sites\n", "Binding Sites |||||||||||| Binding Sites\n", "promoter |||||||||||| promoter\n", "Donor Sites |||||||||||| Donor Sites\n", "Non-promoter |||||||||||| promoter\n", "Binding Sites |||||||||||| Binding Sites\n", "Donor Sites |||||||||||| Donor Sites\n", "Non-promoter |||||||||||| Non-promoter\n", "Donor Sites |||||||||||| Donor Sites\n", "Non-promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Acceptor Sites |||||||||||| Acceptor Sites\n", "promoter |||||||||||| promoter\n", "Donor Sites |||||||||||| Donor Sites\n", "Donor Sites |||||||||||| Acceptor Sites\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "Non-promoter |||||||||||| promoter\n", "promoter |||||||||||| Non-promoter\n", "Non-promoter |||||||||||| Non-promoter\n", "promoter |||||||||||| promoter\n", "Background Sequences |||||||||||| Binding Sites\n", "Acceptor Sites |||||||||||| Splice Sites\n", "Non-Splice Sites |||||||||||| Non-Splice Sites\n", "Donor Sites |||||||||||| Non-Splice Sites\n", "Donor Sites |||||||||||| Donor Sites\n", "Non-promoter |||||||||||| Non-promoter\n", "promoter |||||||||||| promoter\n", "Background Sequences |||||||||||| Binding Sites\n", "promoter |||||||||||| promoter\n", "promoter |||||||||||| promoter\n", "Acceptor Sites |||||||||||| Splice Sites\n", "promoter |||||||||||| promoter\n", "presicion 0.73 same 0.3\n" ] } ], "source": [ "import json\n", "from tqdm import tqdm\n", "\n", "\n", "\n", "with open(output_file, \"r\") as file:\n", " test_data = json.load(file)\n", "\n", "all_num = len(test_data)\n", "right_sum = 0\n", "same_sum = 0\n", "for item in test_data:\n", " 
"    output = item[\"output\"]\n", "    #output = \" \".join(tokenizer.tokenize(output))\n", "    model_response = item[\"model_response\"]\n", "\n", "    print(output, \"||||||||||||\", model_response)\n", "\n", "    if model_response == output:  # exact match\n", "        same_sum = same_sum + 1\n", "\n", "    if output.find(\"Non\") == -1:  # positive label (no \"Non\" prefix)\n", "        # correct only if the response contains the label but not its \"Non-\" variant\n", "        if model_response.find(output) != -1 and model_response.find(\"Non\") == -1:\n", "            right_sum = right_sum + 1\n", "    else:\n", "        if model_response.find(output) != -1:  # the \"Non-\" label itself is found\n", "            right_sum = right_sum + 1\n", "\n", "\n", "print(\"precision\", right_sum/all_num, \"same\", same_sum/all_num)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7bc38f47-4a7d-43eb-abe8-db4310d280e3", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }