---
license: mit
datasets:
- universeTBD/arxiv-astro-abstracts-all
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
tags:
- llama-2
- astronomy
- astrophysics
- arxiv
inference: false
---
# AstroLLaMA
**Play with the model in our Hugging Face space!** https://huggingface.co/spaces/universeTBD/astrollama
<p align="center">
<img src="https://huggingface.co/universeTBD/astrollama/resolve/main/images/astrollama-logo.png" alt="AstroLLaMA" width="500px"/>
</p>
## Loading the model
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer and model weights from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama"
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama",
    device_map="auto",  # place the weights on the available device(s)
)
```
## Generating text from a prompt
```python
import torch
from transformers import pipeline

# Reuse the model and tokenizer loaded above
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto"
)

# Taken from https://arxiv.org/abs/2308.12823
prompt = (
    "In this letter, we report the discovery of the highest redshift, "
    "heavily obscured, radio-loud QSO candidate selected using JWST NIRCam/MIRI, "
    "mid-IR, sub-mm, and radio imaging in the COSMOS-Web field. "
)

# For reproducibility
torch.manual_seed(42)

# Sample a continuation; max_length counts the prompt tokens as well
generated_text = generator(
    prompt,
    do_sample=True,
    max_length=512
)
```
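The text-generation pipeline returns a list of dictionaries, one per generated sequence, each with a `generated_text` key that by default contains the prompt followed by the continuation. The sketch below shows one way to pull out just the continuation. The helper name `extract_continuations` and the example strings are hypothetical and hand-written for illustration; they are not real model output.

```python
def extract_continuations(outputs, prompt):
    """Return each generated string with the leading prompt removed."""
    continuations = []
    for item in outputs:
        text = item["generated_text"]
        # The pipeline echoes the prompt by default, so strip it if present
        if text.startswith(prompt):
            text = text[len(prompt):]
        continuations.append(text.strip())
    return continuations


# Hand-written stand-in for the pipeline's return value (not real model output)
example_outputs = [
    {"generated_text": "We report a discovery. The source lies at high redshift."}
]
example_prompt = "We report a discovery. "

print(extract_continuations(example_outputs, example_prompt))
# → ['The source lies at high redshift.']
```

With the objects defined above, the equivalent call would be `extract_continuations(generated_text, prompt)`.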
## Embedding text with AstroLLaMA
```python
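import torch

texts = [
    "Abstract 1",
    "Abstract 2"
]
inputs = tokenizer(
    texts,
    return_tensors="pt",
    return_token_type_ids=False,
    padding=True,
    truncation=True,
    max_length=4096
)
# Move the input tensors to the device that holds the model
inputs = inputs.to(model.device)

# Inference only, so skip gradient tracking
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# Take the last layer of hidden states, drop the first token position,
# and average over the remaining positions to get one vector per text
embeddings = outputs["hidden_states"][-1][:, 1:, ...].mean(1).detach().cpu().numpy()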
```
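The plain `.mean(1)` above averages over every position, so when texts in a batch have different lengths, the padded positions contribute to the average too. As an optional variant (not the method used in this README), the sketch below weights the average by the attention mask so that padding is ignored. The function name `masked_mean_pool` is hypothetical, and the tensors are toy values standing in for real hidden states.

```python
import torch


def masked_mean_pool(hidden, attention_mask):
    """Average token vectors per sequence, ignoring masked (padding) positions.

    hidden:         float tensor of shape (batch, seq_len, dim)
    attention_mask: 0/1 tensor of shape (batch, seq_len)
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                   # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1)                 # avoid division by zero
    return summed / counts


# Toy tensors standing in for model outputs; the last position of the
# first sequence is padding and must not affect its average
hidden = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
])
attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

pooled = masked_mean_pool(hidden, attention_mask)
print(pooled)
# → tensor([[2., 3.],
#           [2., 2.]])
```

With the objects from the snippet above, this would be called as `masked_mean_pool(outputs["hidden_states"][-1], inputs["attention_mask"])`. Unlike the `[:, 1:, ...]` slice, this variant keeps the first token in the average; zero out that position in the mask if it should be excluded.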