---
license: cc-by-sa-4.0
datasets:
- bigcode/the-stack-dedup
- sadiqj/opam-source
tags:
- code
language:
- code
programming_language:
- OCaml
---

# camlcoder

## Model Description

`camlcoder` is a 2.7B-parameter causal language model focused on **code completion** for OCaml. It is a fine-tuned version of [replit-code-v1-3b](https://www.huggingface.co/replit/replit-code-v1-3b), trained on a subset of the [Stack Dedup v1.2 dataset](https://arxiv.org/abs/2211.15533) and on the most recent version of [all packages in Opam that compile on OCaml 5.0](https://huggingface.co/datasets/sadiqj/opam-source).

## License

The model checkpoint and vocabulary file are licensed under the Creative Commons Attribution-ShareAlike 4.0 license (CC BY-SA 4.0).

## Contact

For questions and comments about the model, please post in the Community tab of the model's Hugging Face page.

## How to Use

First, install the latest versions of the following dependencies:

```
einops
sentencepiece
safetensors
torch
transformers
```
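To confirm your environment before loading the model, you can run a quick import check. This is a minimal sketch: `missing_dependencies` is a hypothetical helper, not part of camlcoder or `transformers`. It only verifies that each package's top-level module can be found (all five packages above import under the same name they are installed under) and does not check versions.

```python
# Hypothetical helper (not part of camlcoder): report which of the
# dependencies listed above cannot be imported in the current environment.
import importlib.util

REQUIRED = ["einops", "sentencepiece", "safetensors", "torch", "transformers"]


def missing_dependencies(names):
    """Return the names whose top-level module cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies(REQUIRED)
    if missing:
        print("Missing: " + ", ".join(missing))
    else:
        print("All dependencies are importable.")
```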

You can then use the model as follows. The example completes an OCaml function from a comment and its header, and stops at the first blank line:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import torch

max_length = 256  # maximum total length (prompt + completion) in tokens

tokenizer = AutoTokenizer.from_pretrained('sadiqj/camlcoder', trust_remote_code=True, max_length=max_length, use_safetensors=True)
model = AutoModelForCausalLM.from_pretrained('sadiqj/camlcoder', trust_remote_code=True, use_safetensors=True).to(device='cuda:0', dtype=torch.bfloat16)

# The prompt: an OCaml comment describing the function, followed by its header.
input_ids = tokenizer.encode('(* Return the middle element of the list *)\nlet get_middle l =', return_tensors='pt').to(device='cuda:0')
prompt_len = input_ids.shape[1]

# First token of the encoding of a blank line ('\n\n'), which typically
# follows a complete top-level definition.
newline_id = tokenizer.encode('\n\n', return_tensors='pt')[0][0].item()

class StopOnNewlines(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only inspect newly generated tokens, so that a matching token
        # in the prompt does not stop generation immediately.
        return newline_id in input_ids[0, prompt_len:]

output = model.generate(input_ids, max_length=max_length, stopping_criteria=StoppingCriteriaList([StopOnNewlines()]), use_cache=True)

# The decoded output contains the prompt followed by the completion.
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
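`generate` returns the prompt together with the completion, and because a stopping criterion is only checked after a token has been produced, the decoded text still contains the blank line that triggered the stop (plus anything that happened to follow it). If you only want the generated body, a small post-processing step can help. The sketch below uses a hypothetical helper, `extract_completion`, that is not part of the model or of `transformers`; it assumes the decoded string begins with the prompt verbatim. The `decoded` value in the usage lines is an illustrative string, not real model output.

```python
def extract_completion(decoded: str, prompt: str) -> str:
    """Return the text generated after `prompt`, cut at the first blank line.

    Hypothetical helper: assumes `decoded` starts with `prompt` verbatim.
    """
    if not decoded.startswith(prompt):
        raise ValueError("decoded text does not start with the prompt")
    completion = decoded[len(prompt):]
    # Keep everything before the first blank line, if there is one.
    end = completion.find("\n\n")
    if end != -1:
        completion = completion[:end]
    return completion


# Illustrative input only; in practice `decoded` would come from tokenizer.decode(...).
prompt = "(* Return the middle element of the list *)\nlet get_middle l ="
decoded = prompt + "\n  List.nth l (List.length l / 2)\n\nlet other = 1"
print(extract_completion(decoded, prompt))
```

Working on the decoded string keeps the helper independent of the tokenizer; if you prefer to work on tokens, you can instead decode only `output[0, input_ids.shape[1]:]`.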