Spaces:
Runtime error
Runtime error
File size: 1,431 Bytes
4962437 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from transformers import AutoTokenizer, AutoModelForCausalLM
class Petals:
    """Petals Bloom models.

    Thin text-generation wrapper: the tokenizer and the causal language model
    are loaded once in the constructor with ``AutoTokenizer.from_pretrained``
    and ``AutoModelForCausalLM.from_pretrained``, and the generation settings
    given to the constructor are forwarded to ``model.generate`` on every
    :meth:`generate` call.
    """

    def __init__(
        self,
        model_name: str = "bigscience/bloom-petals",
        temperature: float = 0.7,
        max_new_tokens: int = 256,
        top_p: float = 0.9,
        top_k: "int | None" = None,
        do_sample: bool = True,
        max_length: "int | None" = None,
    ):
        """Store the generation settings and load the tokenizer and model.

        Args:
            model_name: Identifier passed to both ``from_pretrained`` calls.
            temperature: Forwarded to ``model.generate`` as ``temperature``.
            max_new_tokens: Forwarded as ``max_new_tokens``.
            top_p: Forwarded as ``top_p``.
            top_k: Forwarded as ``top_k``; ``None`` is passed through as-is.
            do_sample: Forwarded as ``do_sample``.
            max_length: Forwarded as ``max_length``; ``None`` is passed
                through as-is.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.do_sample = do_sample
        self.max_length = max_length

        # Loading happens eagerly: constructing the wrapper is the expensive
        # step, so that generate() can reuse the same tokenizer and model.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def _default_params(self) -> dict:
        """Get the default parameters for calling Petals API.

        Returns:
            A new dict mapping each generation keyword argument to the value
            currently stored on the instance. Values that are ``None`` are
            kept in the dict rather than dropped.
        """
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "do_sample": self.do_sample,
            "max_length": self.max_length,
        }

    def generate(self, prompt: str) -> str:
        """Generate text using the Petals API.

        Args:
            prompt: Text handed to the tokenizer.

        Returns:
            The first sequence returned by ``model.generate``, decoded back
            to text by the tokenizer.
        """
        params = self._default_params()
        # Only the token ids are forwarded; the other tokenizer outputs
        # (e.g. the attention mask) are not passed to the model.
        inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        outputs = self.model.generate(inputs, **params)
        # NOTE(review): presumably the decoded text also contains the prompt,
        # as is usual for causal LMs — confirm before relying on it.
        return self.tokenizer.decode(outputs[0])