---
license: cc-by-sa-4.0
datasets:
- bigcode/the-stack-dedup
- sadiqj/opam-source
tags:
- code
language:
- code
programming_language:
- OCaml
---
# camlcoder
## Model Description
`camlcoder` is a 2.7B-parameter causal language model focused on **code completion** for OCaml. It is a fine-tuned version of [replit-code-v1-3b](https://huggingface.co/replit/replit-code-v1-3b), trained on a subset of the [Stack Dedup v1.2 dataset](https://arxiv.org/abs/2211.15533) and on the most recent version of [all packages in opam that compile on OCaml 5.0](https://huggingface.co/datasets/sadiqj/opam-source).
## License
The model checkpoint and vocabulary file are licensed under the Creative Commons Attribution-ShareAlike 4.0 license (CC BY-SA 4.0).
## Contact
For questions and comments about the model, please post in the community section.
## How to Use
First, install the latest versions of the following dependencies:
```
einops
sentencepiece
safetensors
torch
transformers
```
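If you want to confirm these packages are present before loading the model, the sketch below reports the installed version of each one using only the standard library's `importlib.metadata`. The helper name `check_dependencies` is hypothetical, and it only shows what is installed; it cannot tell you whether that is the latest release.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages listed above; distribution names are assumed to match these import names
REQUIRED = ["einops", "sentencepiece", "safetensors", "torch", "transformers"]


def check_dependencies(packages=REQUIRED):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in check_dependencies().items():
        print(f"{name}: {ver if ver else 'NOT INSTALLED'}")
```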
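You can then use the model as follows. This example loads the model in bfloat16 on a CUDA GPU and completes an OCaml definition from a comment and a function header, stopping once a blank line is generated: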
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import torch

max_length = 256
device = 'cuda:0'
tokenizer = AutoTokenizer.from_pretrained('sadiqj/camlcoder', trust_remote_code=True, max_length=max_length)
model = AutoModelForCausalLM.from_pretrained('sadiqj/camlcoder', trust_remote_code=True, use_safetensors=True).to(device=device, dtype=torch.bfloat16)

# Prompt with an OCaml comment and the start of a definition
input_ids = tokenizer.encode('(* Return the middle element of the list *)\nlet get_middle l =', return_tensors='pt').to(device=device)
prompt_len = input_ids.shape[1]
# Token id for a blank line, used to detect the end of the definition
newline_id = tokenizer.encode('\n\n', return_tensors='pt')[0][0].item()

class StopOnNewlines(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Inspect only the generated tokens so the prompt cannot trigger an early stop
        return bool((input_ids[0, prompt_len:] == newline_id).any())

output = model.generate(input_ids, max_length=max_length, stopping_criteria=StoppingCriteriaList([StopOnNewlines()]), use_cache=True)
# Keep whitespace intact, since it is significant in source code
print(tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```
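The decoded output above still contains the prompt, and when the stopping criterion fires, it ends with the blank line that triggered it. If you only want the generated OCaml body, you can post-process the decoded string. The sketch below uses a hypothetical helper, `extract_completion`. It assumes the decoded text begins with the prompt verbatim; when it does not, the helper leaves the text as-is and only cuts at the first blank line.

```python
def extract_completion(decoded: str, prompt: str) -> str:
    """Drop the prompt prefix (if present) and cut at the first blank line."""
    completion = decoded[len(prompt):] if decoded.startswith(prompt) else decoded
    end = completion.find("\n\n")
    return completion if end == -1 else completion[:end]


prompt = "(* Return the middle element of the list *)\nlet get_middle l ="
# Illustrative decoded text, not real model output
decoded = prompt + " List.nth l (List.length l / 2)\n\nlet other = 1"
print(extract_completion(decoded, prompt))
```

Cutting on the string rather than on token ids avoids depending on how the tokenizer splits newlines. The trade-off is that a definition containing an internal blank line would be truncated early.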