mustapha committed on
Commit
0e93535
·
1 Parent(s): a421983

create main.py

Browse files
Files changed (1) hide show
  1. main.py +144 -0
main.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Python file to serve as the frontend"""
2
+ import streamlit as st
3
+ from streamlit_chat import message
4
+
5
+ from langchain.chains import ConversationChain, LLMChain
6
+ from langchain import PromptTemplate
7
+ from langchain.llms.base import LLM
8
+ from langchain.memory import ConversationBufferWindowMemory
9
+ from typing import Optional, List, Mapping, Any
10
+
11
+ import torch
12
+ from peft import PeftModel
13
+ import transformers
14
+
15
+ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
16
+
17
# Load the base LLaMA tokenizer and model, then apply the Alpaca LoRA adapter.
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# load presumably runs on every rerun; consider caching it.
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model, "tloen/alpaca-lora-7b",
    torch_dtype=torch.float16
)
# Inference only: switch the module to evaluation mode.
model.eval()

# Device the prompt tensors are moved to before generation. The model is
# dispatched with device_map="auto", so a hard-coded "cpu" would leave the
# inputs on a different device than the weights whenever a GPU is in use.
device = "cuda" if torch.cuda.is_available() else "cpu"
32
def evaluate_raw_prompt(
    prompt: str,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    **kwargs,
):
    """Generate a completion for *prompt* and return the response text.

    Args:
        prompt: Fully formatted prompt; it is tokenized as-is.
        temperature, top_p, top_k, num_beams: Forwarded to
            ``GenerationConfig``.
        **kwargs: Extra ``GenerationConfig`` options, forwarded unchanged.

    Returns:
        The stripped text that follows the first ``### Response:`` marker in
        the decoded output (up to the next marker, if the model emits one).
        When the marker is absent, the whole decoded output, stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    # NOTE(review): do_sample is not set here; confirm whether the sampling
    # knobs (temperature/top_p/top_k) take effect under beam search.
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
    output = tokenizer.decode(generation_output.sequences[0])

    # The decoded sequence contains the prompt as well as the completion, so
    # the answer is whatever follows the response marker. partition() avoids
    # the IndexError that split(...)[1] raised for prompts without a marker.
    marker = "### Response:"
    _, found, tail = output.partition(marker)
    if not found:
        return output.strip()
    return tail.split(marker)[0].strip()
61
+
62
class AlpacaLLM(LLM):
    """LangChain LLM wrapper that delegates to ``evaluate_raw_prompt``.

    The four fields are the generation settings passed along on every call.
    """

    temperature: float
    top_p: float
    top_k: int
    num_beams: int

    @property
    def _llm_type(self) -> str:
        """Return the type tag reported for this LLM."""
        return "custom"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return dict(
            top_p=self.top_p,
            top_k=self.top_k,
            num_beams=self.num_beams,
            temperature=self.temperature,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the model on *prompt*; stop sequences are rejected.

        Raises:
            ValueError: If ``stop`` is anything other than ``None``.
        """
        if stop is None:
            # The identifying params are exactly the generation settings.
            return evaluate_raw_prompt(prompt, **self._identifying_params)
        raise ValueError("stop kwargs are not permitted.")
91
+
92
+
93
# Alpaca-style prompt: a fixed instruction, the chat so far ({history}), the
# newest user message ({human_input}), and the response marker that
# evaluate_raw_prompt splits the decoded output on.
template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
You are a chatbot, you should answer my last question very briefly. You are consistent and non repetitive.
### Chat:
{history}
Human: {human_input}
### Response:"""

# Template object handed to the chain; its variables match the placeholders
# used in the template text above.
prompt = PromptTemplate(
    template=template,
    input_variables=["history", "human_input"],
)
105
+
106
+
107
def load_chain():
    """Build and return the chain that answers chat messages.

    The chain combines the Alpaca LLM wrapper, the module-level ``prompt``
    and a ``ConversationBufferWindowMemory`` created with ``k=2``.
    """
    alpaca = AlpacaLLM(temperature=0.1, num_beams=4, top_k=40, top_p=0.75)
    window_memory = ConversationBufferWindowMemory(k=2)
    return LLMChain(prompt=prompt, llm=alpaca, memory=window_memory)
113
+
114
# From here down is all the StreamLit UI.
st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
st.header("LangChain Demo")

# Streamlit re-executes this script on every interaction. Building the chain
# unconditionally would give each rerun a brand-new conversation memory, so
# the chat history would never reach the prompt. Keep one chain per session.
if "chain" not in st.session_state:
    st.session_state["chain"] = load_chain()
chain = st.session_state["chain"]

# Bot replies, in the order they were produced.
if "generated" not in st.session_state:
    st.session_state["generated"] = []

# User messages, index-aligned with "generated".
if "past" not in st.session_state:
    st.session_state["past"] = []
125
+
126
+
127
def get_text():
    """Render the chat input box and return its current text."""
    return st.text_input("Human: ", "Hello, how are you?", key="input")
130
+
131
+
132
user_input = get_text()

if user_input:
    # Ask the chain for a reply, then record both sides of the exchange.
    output = chain.predict(human_input=user_input)

    st.session_state["past"].append(user_input)
    st.session_state["generated"].append(output)

replies = st.session_state["generated"]
if replies:
    questions = st.session_state["past"]
    # Newest exchange first; each reply is shown above the message it answers.
    for idx in reversed(range(len(replies))):
        message(replies[idx], key=str(idx))
        message(questions[idx], is_user=True, key=f"{idx}_user")