---
license: mit
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
tags:
- llama-2
- astronomy
- astrophysics
- arxiv
inference: false
---
# AstroLLaMA
## Loading the model
```python
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="AstroMLab/astrollama"
)
# device_map="auto" spreads the weights across the available GPUs/CPU
# (this requires the `accelerate` package)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="AstroMLab/astrollama",
    device_map="auto",
)
```
## Generating text from a prompt
```python
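import torch
from transformers import pipeline

# `model` was already placed on devices via device_map="auto" above,
# so the pipeline does not need its own device_map
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Taken from https://arxiv.org/abs/2308.12823
prompt = (
    "In this letter, we report the discovery of the highest redshift, "
    "heavily obscured, radio-loud QSO candidate selected using JWST NIRCam/MIRI, "
    "mid-IR, sub-mm, and radio imaging in the COSMOS-Web field. "
)

# For reproducibility
torch.manual_seed(42)
# max_length counts the prompt tokens as well as the generated ones
generated_text = generator(
    prompt,
    do_sample=True,
    max_length=512,
)
# The pipeline returns a list of dicts; "generated_text" includes the prompt
print(generated_text[0]["generated_text"])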
```
## Embedding text with AstroLLaMA
```python
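import torch

texts = [
    "Abstract 1",
    "Abstract 2"
]
# LLaMA tokenizers may not define a padding token; reuse EOS if none is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(
    texts,
    return_tensors="pt",
    return_token_type_ids=False,
    padding=True,
    truncation=True,
    max_length=4096
).to(model.device)
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
# Last layer of the hidden states, shape (batch, seq_len, hidden_size)
last_hidden = outputs["hidden_states"][-1]
# Average over real tokens only: exclude padding and the BOS token
mask = inputs["attention_mask"].clone()
mask[inputs["input_ids"] == tokenizer.bos_token_id] = 0
mask = mask.unsqueeze(-1).to(last_hidden.dtype)
embeddings = ((last_hidden * mask).sum(1) / mask.sum(1)).float().cpu().numpy()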
```
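Once the abstracts are embedded, a common next step is to compare them. The snippet below is a minimal sketch of pairwise cosine similarity, assuming `embeddings` is a NumPy array of shape `(n_texts, hidden_size)` like the one produced above. The helper name `cosine_similarity_matrix` is hypothetical, and the small random array is only a placeholder so the example runs without loading the model.

```python
import numpy as np


def cosine_similarity_matrix(embeddings: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of an (n_texts, dim) array."""
    # Normalize each embedding to unit length (clip to avoid division by zero)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    # Dot products of unit vectors are cosine similarities
    return unit @ unit.T


# Placeholder data: in practice, pass the `embeddings` array computed above
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(2, 8)).astype(np.float32)

similarity = cosine_similarity_matrix(embeddings)
print(similarity.shape)  # (2, 2)
```

Entry `similarity[i, j]` scores how close abstracts `i` and `j` are in embedding space; the diagonal is 1 by construction.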