prelington committed on
Commit
bea054a
·
verified ·
1 Parent(s): 0d23760

Create cohelp_full.py

Browse files
Files changed (1) hide show
  1. cohelp_full.py +81 -0
cohelp_full.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ cohelp_full.py — Minimal single-file script to: create tiny example dataset, fine-tune a causal LM (GPT-2 style), run a CLI chat, launch a Gradio demo, and upload to HF Hub.
3
+
4
+
5
+ Features (all in one file):
6
+ - small example dataset (jsonl) generated when needed
7
+ - Trainer-based fine-tuning
8
+ - a naive chat-friendly prompt formatting
9
+ - simple loss-masking so only assistant tokens produce loss (naive implementation)
10
+ - lightweight Gradio demo for interactive testing
11
+ - upload_to_hub function to push model + tokenizer + model_card
12
+
13
+
14
+ Caveats:
15
+ - This is an educational minimal repo. For production, use larger datasets, handle tokenization / padding carefully, prefer LoRA/PEFT, and add safety filters.
16
+
17
+
18
+ Usage examples:
19
+ - Train: python cohelp_full.py --do_train --output_dir outputs/cohelp
20
+ - Demo (local): python cohelp_full.py --do_demo --model outputs/cohelp
21
+ - CLI chat: python cohelp_full.py --do_chat --model outputs/cohelp
22
+ - Upload: python cohelp_full.py --upload --repo_id your-user/cohelp
23
+
24
+
25
+ """
26
+
27
+
28
import argparse
import json
import os
from pathlib import Path
from typing import List, Optional

import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
44
+
45
+
46
# -------- CONFIGURATION --------
BASE_MODEL = "gpt2"  # change to distilgpt2 or other causal model
SPECIAL_TOKENS = {
    "bos_token": "<|bos|>",
    "eos_token": "<|eos|>",
    "pad_token": "<|pad|>",
    "additional_special_tokens": ["<|user|>", "<|assistant|>"],
}
DEFAULT_MAX_LENGTH = 512

# Role markers derived from SPECIAL_TOKENS so the prompt format and the
# tokens registered with the tokenizer cannot silently drift apart
# (previously the literals were duplicated inside build_prompt).
USER_TOKEN, ASSISTANT_TOKEN = SPECIAL_TOKENS["additional_special_tokens"]


# -------- Helpers: prompt formatting --------
def build_prompt(history: List[dict], user_input: Optional[str] = None) -> str:
    """Build a chat-style prompt string from history and an optional new turn.

    Parameters:
        history: list of dicts like ``{"role": "user"/"assistant", "text": ...}``.
            Any role other than ``"user"`` is treated as the assistant.
        user_input: optional new user message. When given, it is appended as a
            user turn and the prompt ends with the assistant marker so a causal
            LM continues generation as the assistant.

    Returns:
        The BOS token, role markers, and turn texts joined with ``" \\n"``.
    """
    parts = [SPECIAL_TOKENS["bos_token"]]
    for turn in history:
        # Select the marker from the single source of truth above.
        parts.append(USER_TOKEN if turn["role"] == "user" else ASSISTANT_TOKEN)
        parts.append(turn["text"])
    if user_input is not None:
        parts.append(USER_TOKEN)
        parts.append(user_input)
        # Trailing assistant marker cues the model to answer next.
        parts.append(ASSISTANT_TOKEN)
    return " \n".join(parts)
75
+
76
+
77
+ # -------- Tiny example dataset generator --------
78
+ EXAMPLE_JSONL = "cohelp_example.jsonl"
79
+ EXAMPLE_LINES = [
80
+ {"role": "user", "text": "Hi, who are you?"},
81
+ histor