szope committed
Commit 944847d
1 Parent(s): 61c2687

Create app.py file

Files changed (1)
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
+ from typing import Iterator, Optional, Union
+ from vllm.engine.llm_engine import LLMEngine
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.usage.usage_lib import UsageContext
+ from vllm.utils import Counter
+ from vllm.outputs import RequestOutput
+ from vllm import SamplingParams
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ import gradio as gr
+
+
+ class StreamingLLM:
+     def __init__(
+         self,
+         model: str,
+         dtype: str = "auto",
+         quantization: Optional[str] = None,
+         **kwargs,
+     ) -> None:
+         engine_args = EngineArgs(model=model, quantization=quantization, dtype=dtype, enforce_eager=True)
+         self.llm_engine = LLMEngine.from_engine_args(engine_args, usage_context=UsageContext.LLM_CLASS)
+         self.request_counter = Counter()
+
+     def generate(
+         self,
+         prompt: Optional[str] = None,
+         sampling_params: Optional[SamplingParams] = None,
+     ) -> Iterator[RequestOutput]:
+         # Queue the request, then step the engine until it finishes,
+         # yielding the cumulative output after every engine step.
+         request_id = str(next(self.request_counter))
+         self.llm_engine.add_request(request_id, prompt, sampling_params)
+
+         while self.llm_engine.has_unfinished_requests():
+             step_outputs = self.llm_engine.step()
+             for output in step_outputs:
+                 yield output
+
+
+ class UI:
+     def __init__(
+         self,
+         llm: StreamingLLM,
+         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+         sampling_params: Optional[SamplingParams] = None,
+     ) -> None:
+         self.llm = llm
+         self.tokenizer = tokenizer
+         self.sampling_params = sampling_params
+
+     def _generate(self, message, history):
+         # Rebuild the Gradio (user, assistant) history in chat-template format.
+         history_chat_format = []
+         for human, assistant in history:
+             history_chat_format.append({"role": "user", "content": human})
+             history_chat_format.append({"role": "assistant", "content": assistant})
+         history_chat_format.append({"role": "user", "content": message})
+
+         # add_generation_prompt=True appends the assistant header so the model
+         # answers as the assistant instead of continuing the user turn.
+         prompt = self.tokenizer.apply_chat_template(history_chat_format, tokenize=False, add_generation_prompt=True)
+
+         # Each chunk holds the full text generated so far, so every yield
+         # replaces the partial message shown by gr.ChatInterface.
+         for chunk in self.llm.generate(prompt, self.sampling_params):
+             yield chunk.outputs[0].text
+
+     def launch(self):
+         gr.ChatInterface(self._generate).launch()
+
+
+ if __name__ == "__main__":
+     llm = StreamingLLM(model="casperhansen/llama-3-70b-instruct-awq", quantization="AWQ", dtype="float16")
+     # llm_engine.tokenizer is vLLM's tokenizer wrapper; .tokenizer is the underlying Hugging Face tokenizer.
+     tokenizer = llm.llm_engine.tokenizer.tokenizer
+     sampling_params = SamplingParams(
+         temperature=0.6,
+         top_p=0.9,
+         max_tokens=4096,
+         stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
+     )
+     ui = UI(llm, tokenizer, sampling_params)
+     ui.launch()
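
For reference, a minimal sketch of driving the streaming engine without the Gradio UI (assuming app.py is importable from the working directory; the prompt and max_tokens value below are only illustrative). Each RequestOutput yielded by StreamingLLM.generate carries the full text generated so far, not just the newest tokens:

    from vllm import SamplingParams
    from app import StreamingLLM  # the class added in this commit

    llm = StreamingLLM(model="casperhansen/llama-3-70b-instruct-awq", quantization="AWQ", dtype="float16")
    params = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=64)

    previous = ""
    for output in llm.generate("Write a haiku about GPUs.", params):
        text = output.outputs[0].text                      # cumulative text so far
        print(text[len(previous):], end="", flush=True)    # print only the new piece
        previous = text
    print()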