harsh-manvar committed on
Commit
8e06205
1 Parent(s): 7bb24c5

Create model.py

Files changed (1)
  model.py +81 -0
model.py ADDED
@@ -0,0 +1,81 @@
+ from threading import Thread
+ from typing import Iterator
+
+ # torch is never referenced by name, but it must be installed for the
+ # 'pt' tensors produced in run() below.
+ from transformers.utils import logging
+ from ctransformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import TextIteratorStreamer
+
+ logging.set_verbosity_info()
+ logger = logging.get_logger("transformers")
+
+ # Default ctransformers generation settings, kept for reference; run() below
+ # passes its own sampling parameters explicitly.
+ config = {'max_new_tokens': 256, 'repetition_penalty': 1.1,
+           'temperature': 0.1, 'stream': True}
+ model_id = 'TheBloke/Llama-2-7B-Chat-GGML'
+ device = "cpu"
+
+
+ # hf=True wraps the GGML model in a transformers-compatible interface so it
+ # works with TextIteratorStreamer and generate() below; lib='avx2' selects
+ # the AVX2 CPU build of the ctransformers backend.
+ model = AutoModelForCausalLM.from_pretrained(model_id, model_type="llama", lib='avx2', hf=True)
+ tokenizer = AutoTokenizer.from_pretrained(model)
+
+
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
+                system_prompt: str) -> str:
+     """Assemble a Llama-2 chat prompt from the system prompt, history and new message."""
+     logger.info("get_prompt chat_history=%s", chat_history)
+     logger.info("get_prompt system_prompt=%s", system_prompt)
+     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+     logger.info("texts=%s", texts)
+     # The first user input is _not_ stripped.
+     do_strip = False
+     for user_input, response in chat_history:
+         user_input = user_input.strip() if do_strip else user_input
+         do_strip = True
+         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+     message = message.strip() if do_strip else message
+     logger.info("get_prompt message=%s", message)
+     texts.append(f'{message} [/INST]')
+     logger.info("get_prompt final texts=%s", texts)
+     return ''.join(texts)
+
+
+ def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
+     """Return the prompt's token count so callers can enforce a context limit."""
+     logger.info("get_input_token_length=%s", message)
+     prompt = get_prompt(message, chat_history, system_prompt)
+     logger.info("prompt=%s", prompt)
+     input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
+     logger.info("input_ids=%s", input_ids)
+     return input_ids.shape[-1]
+
+
+ def run(message: str,
+         chat_history: list[tuple[str, str]],
+         system_prompt: str,
+         max_new_tokens: int = 1024,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+         top_k: int = 50) -> Iterator[str]:
+     """Stream the model's reply, yielding the accumulated text after each chunk."""
+     prompt = get_prompt(message, chat_history, system_prompt)
+     inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
+
+     streamer = TextIteratorStreamer(tokenizer,
+                                     timeout=10.0,
+                                     skip_prompt=True,
+                                     skip_special_tokens=True)
+     generate_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+     )
+     # generate() blocks, so run it on a worker thread and drain the streamer here.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield ''.join(outputs)
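
For quick verification, here is a minimal sketch of how these helpers might be driven. The system prompt, message, and the 2048-token context limit are illustrative assumptions, not part of the commit:

# Hypothetical smoke test for model.py; the prompt, message and context
# limit below are assumptions for illustration.
from model import get_input_token_length, run

system_prompt = 'You are a helpful assistant.'
chat_history = []  # no prior turns
message = 'What is the capital of France?'

# Keep the prompt inside an assumed 2048-token Llama-2 context window.
if get_input_token_length(message, chat_history, system_prompt) < 2048:
    response = ''
    # run() yields the accumulated reply, so the last value is the full text.
    for response in run(message, chat_history, system_prompt, max_new_tokens=256):
        pass
    print(response)

For this call, get_prompt() produces a prompt in Llama-2's chat convention: <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]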