from threading import Thread
from typing import Iterator

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TextIteratorStreamer,
)
MODEL_ID = "universeTBD/astrollama"
WINDOW_SIZE = 4096  # model context length, in tokens
DEVICE = "cuda"  # 4-bit quantized loading requires a CUDA GPU
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
)
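
# Note: passing `load_in_4bit=True` directly to `from_pretrained` is the older
# flat form; recent `transformers` releases express the same setup through an
# explicit quantization config. A minimal sketch of the equivalent call
# (assuming `bitsandbytes` is installed; not part of the original app):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.bfloat16,
#   )
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       config=config,
#       device_map="auto",
#       quantization_config=quant_config,
#   )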
def generate_text(prompt: str,
                  max_new_tokens: int = 512,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> Iterator[str]:
    # Encode the prompt
    inputs = tokenizer([prompt],
                       return_tensors='pt',
                       add_special_tokens=False).to(DEVICE)

    # Never request more new tokens than the context window can hold
    # alongside the prompt (and always request at least one)
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = max(1, min(max_new_tokens, WINDOW_SIZE - input_length))

    # Clamp the sampling parameters to safe ranges
    if temperature >= 1.0:
        temperature = 0.99
    elif temperature <= 0.0:
        temperature = 0.01
    if top_p > 1.0 or top_p <= 0.0:
        top_p = 1.0
    if top_k <= 0:
        top_k = 100

    # skip_prompt=True makes the streamer emit only newly generated text
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # Run generation in a background thread and yield the accumulated text
    # as tokens stream in, so the Gradio UI can update incrementally
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()
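
# A minimal usage sketch outside the Gradio UI (the prompt string below is a
# made-up example, not from the original app). Since generate_text is a
# generator, iterate it to consume the stream:
#
#   for partial in generate_text("The interstellar medium consists of",
#                                max_new_tokens=64):
#       print(partial)  # prompt plus everything generated so far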
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        # Prompt
        gr.Textbox(
            label="Prompt",
            container=False,
            show_label=False,
            placeholder="Enter some text...",
            lines=10,
            scale=10,
        ),
        # Sampling controls
        gr.Slider(
            label="Maximum new tokens",
            minimum=1,
            maximum=4096,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.5,
        ),
        gr.Slider(
            label="Top-p (for sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k (for sampling)",
            minimum=1,
            maximum=1000,
            step=1,
            value=100,
        ),
    ],
    outputs=[
        gr.Textbox(
            container=False,
            show_label=False,
            placeholder="Generated output...",
            scale=10,
            lines=10,
        )
    ],
)
# Queue incoming requests (at most 20 waiting); queuing also enables
# streaming output from the generator function above
demo.queue(max_size=20).launch()