Josh Nguyen committed
Commit
d81d6d2
1 Parent(s): 6e7d907

First commit

Files changed (2):
  1. app.py +124 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,124 @@
+ from threading import Thread
+ import gradio as gr
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     AutoConfig,
+     TextIteratorStreamer
+ )
+
+ MODEL_ID = "universeTBD/astrollama"
+ WINDOW_SIZE = 4096
+ DEVICE = "cuda"
+
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     pretrained_model_name_or_path=MODEL_ID
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     pretrained_model_name_or_path=MODEL_ID,
+     config=config,
+     device_map="auto",
+     use_safetensors=True,
+     trust_remote_code=True,
+     load_in_4bit=True,
+     torch_dtype=torch.bfloat16
+ )
+
+
+ def generate_text(prompt: str,
+                   temperature: float = 0.5,
+                   top_p: float = 0.95,
+                   top_k: int = 50,
+                   max_new_tokens: int = 512):
+     # NOTE: the parameter order matches the order of the Gradio input
+     # components below, which are passed to this function positionally.
+
+     # Encode the prompt
+     inputs = tokenizer([prompt],
+                        return_tensors="pt",
+                        add_special_tokens=False,
+                        return_token_type_ids=False)
+     inputs = inputs.to(DEVICE)
+
+     # Prepare arguments for generation, clamping them to sane ranges
+     input_length = inputs["input_ids"].shape[-1]
+     max_new_tokens = min(max_new_tokens, WINDOW_SIZE - input_length)
+     if temperature >= 1.0:
+         temperature = 0.99
+     elif temperature <= 0.0:
+         temperature = 0.01
+     if top_p > 1.0 or top_p <= 0.0:
+         top_p = 1.0
+     if top_k <= 0:
+         top_k = 100
+     streamer = TextIteratorStreamer(tokenizer,
+                                     timeout=10.,
+                                     skip_prompt=True,
+                                     skip_special_tokens=True)
+     generation_kwargs = dict(
+         input_ids=inputs["input_ids"],
+         attention_mask=inputs["attention_mask"],
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+     )
+
+     # Generate in a background thread so tokens can be read off the streamer
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     # Accumulate streamed chunks and yield the running text to the interface
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=[
+         # Prompt
+         gr.Textbox(
+             label="Prompt",
+             container=False,
+             show_label=False,
+             placeholder="Enter some text...",
+             scale=10,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.01,
+             maximum=0.99,
+             step=0.01,
+             value=0.5,
+         ),
+         gr.Slider(
+             label="Top-p (for sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.95,
+         ),
+         gr.Slider(
+             label="Top-k (for sampling)",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=100,
+         )
+     ],
+     outputs=[
+         gr.Textbox(
+             container=False,
+             show_label=False,
+             placeholder="Generated output...",
+             scale=10,
+         )
+     ],
+ )
+
+ demo.queue(max_size=20).launch(server_port=7878)
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate==0.21.0
+ bitsandbytes==0.40.2
+ gradio==3.37.0
+ protobuf==3.20.3
+ scipy==1.11.1
+ sentencepiece==0.1.99
+ torch==2.0.1
+ transformers==4.31.0
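
For reference, one way to exercise the deployed interface programmatically is through the gradio_client package. This is only a sketch and not part of the commit: it assumes the app is running as a public Gradio Space, that the Space id below (reused from MODEL_ID as a placeholder) points at the actual deployment, and that the gr.Interface endpoint keeps its default name "/predict".

    # Sketch only: query the running app with gradio_client (installed separately).
    from gradio_client import Client

    # "universeTBD/astrollama" is a placeholder Space id borrowed from MODEL_ID above.
    client = Client("universeTBD/astrollama")
    result = client.predict(
        "The dark matter content of dwarf galaxies",  # Prompt
        0.5,    # Temperature
        0.95,   # Top-p (for sampling)
        100,    # Top-k (for sampling)
        api_name="/predict",
    )
    print(result)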