sebaweis commited on
Commit
1cb5d1d
1 Parent(s): 9f8a65e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +54 -0
README.md CHANGED
@@ -46,4 +46,58 @@ python finetune.py \
46
  --batch-size 128 --micro-batch-size 8 --eval-limit 45 \
47
  --eval-file code_eval.jsonl --wandb-project jerboa --wandb-log-model \
48
  --wandb-watch gradients --num-epochs 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  ```
 
46
  --batch-size 128 --micro-batch-size 8 --eval-limit 45 \
47
  --eval-file code_eval.jsonl --wandb-project jerboa --wandb-log-model \
48
  --wandb-watch gradients --num-epochs 6
49
+ ```
50
+
51
+ ```Python
52
+ import torch
53
+ from transformers import AutoTokenizer, AutoModelForCausalLM
54
+
55
+
56
+ TOKENIZER_SOURCE = 'tiiuae/falcon-7b'
57
+ BASE_MODEL = 'jinaai/falcon-7b-code-alpaca'
58
+ DEVICE = "cuda"
59
+
60
+ PROMPT = """
61
+ Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
62
+
63
+ ### Instruction:
64
+ Write a for loop in python
65
+
66
+ ### Input:
67
+
68
+ ### Response:
69
+ """
70
+ model = AutoModelForCausalLM.from_pretrained(
71
+ pretrained_model_name_or_path=BASE_MODEL,
72
+ torch_dtype=torch.float16,
73
+ trust_remote_code=True,
74
+ device_map='auto',
75
+ )
76
+
77
+ model.eval()
78
+
79
+ tokenizer = AutoTokenizer.from_pretrained(
80
+ TOKENIZER_SOURCE,
81
+ trust_remote_code=True,
82
+ padding_side='left',
83
+ )
84
+ tokenizer.pad_token = tokenizer.eos_token
85
+
86
+ inputs = tokenizer(PROMPT, return_tensors="pt")
87
+ input_ids = inputs["input_ids"].to(DEVICE)
88
+ input_attention_mask = inputs["attention_mask"].to(DEVICE)
89
+
90
+ with torch.no_grad():
91
+ generation_output = model.generate(
92
+ input_ids=input_ids,
93
+ attention_mask=input_attention_mask,
94
+ return_dict_in_generate=True,
95
+ max_new_tokens=32,
96
+ eos_token_id=tokenizer.eos_token_id,
97
+ )
98
+ generation_output = generation_output.sequences[0]
99
+ output = tokenizer.decode(generation_output, skip_special_tokens=True)
100
+
101
+ print(output)
102
+
103
  ```