"""Python file to serve as the frontend"""
import streamlit as st
from streamlit_chat import message
from langchain.chains import ConversationChain, LLMChain
from langchain import PromptTemplate
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferWindowMemory
from typing import Optional, List, Mapping, Any
import torch
from peft import PeftModel
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from transformers import BitsAndBytesConfig
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    # load_in_8bit=True,
    # torch_dtype=torch.float16,
    device_map="auto",
    # device_map={"": "cpu"},
    max_memory={"cpu": "15GiB"},  # cap memory used for CPU-offloaded weights
    quantization_config=quantization_config,
)
model = PeftModel.from_pretrained(
    model,
    "tloen/alpaca-lora-7b",
    # torch_dtype=torch.float16,
    device_map={"": "cpu"},
)
device = "cpu"
print("model device :", model.device, flush=True)
# model.to(device)
model.eval()
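# Optional sanity check for the offloading setup (a sketch, assuming
# device_map="auto" had accelerate populate hf_device_map on the model):
# print(model.hf_device_map)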
def evaluate_raw_prompt(
    prompt: str,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    **kwargs,
):
    """Run the model on a raw Alpaca-style prompt and return the response text."""
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    # The decoded text contains the full prompt; keep only what follows the marker.
    return output.split("### Response:")[1].strip()
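# A minimal usage sketch (commented out so it never runs during a Streamlit
# rerun; the instruction text below is illustrative, not part of the app):
# example_prompt = (
#     "Below is an instruction that describes a task. "
#     "Write a response that appropriately completes the request.\n"
#     "### Instruction:\nName three colors.\n### Response:"
# )
# print(evaluate_raw_prompt(example_prompt, temperature=0.1))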
class AlpacaLLM(LLM):
    """Minimal LangChain LLM wrapper around the local Alpaca-LoRA model."""

    temperature: float
    top_p: float
    top_k: int
    num_beams: int

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        answer = evaluate_raw_prompt(
            prompt,
            top_p=self.top_p,
            top_k=self.top_k,
            num_beams=self.num_beams,
            temperature=self.temperature,
        )
        return answer

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "top_p": self.top_p,
            "top_k": self.top_k,
            "num_beams": self.num_beams,
            "temperature": self.temperature,
        }
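# Sketch of calling the wrapper directly, outside LangChain (commented out;
# the prompt string is a hypothetical example and must contain the
# "### Response:" marker that evaluate_raw_prompt splits on):
# llm = AlpacaLLM(top_p=0.75, top_k=40, num_beams=4, temperature=0.1)
# print(llm("### Instruction:\nSay hi.\n### Response:"))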
template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
You are a chatbot; you should answer my last question very briefly. You are consistent and non-repetitive.
### Chat:
{history}
Human: {human_input}
### Response:"""
prompt = PromptTemplate(
input_variables=["history","human_input"],
template=template,
)
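# The filled-in template looks like this (illustrative values, not app state):
# prompt.format(history="Human: Hi\nAI: Hello!", human_input="How are you?")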
def load_chain():
    """Logic for loading the chain you want to use should go here."""
    llm = AlpacaLLM(top_p=0.75, top_k=40, num_beams=4, temperature=0.1)
    # chain = ConversationChain(llm=llm)
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        memory=ConversationBufferWindowMemory(k=2),  # keep the last 2 exchanges
    )
    return chain
chain = load_chain()
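# Example call (commented out so it does not trigger a generation at import
# time; the question is illustrative). The memory supplies {history}, so only
# human_input is passed:
# print(chain.predict(human_input="What is the capital of France?"))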
# # From here down is all the StreamLit UI.
# st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
# st.header("LangChain Demo")
# if "generated" not in st.session_state:
# st.session_state["generated"] = []
# if "past" not in st.session_state:
# st.session_state["past"] = []
# def get_text():
# input_text = st.text_input("Human: ", "Hello, how are you?", key="input")
# return input_text
# user_input = get_text()
# if user_input:
# output = chain.predict(human_input=user_input)
# st.session_state.past.append(user_input)
# st.session_state.generated.append(output)
# if st.session_state["generated"]:
# for i in range(len(st.session_state["generated"]) - 1, -1, -1):
# message(st.session_state["generated"][i], key=str(i))
# message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
st.title("ChatAlpaca")
if "history" not in st.session_state:
st.session_state.history = []
st.session_state.history.append({"message": "Hey, I'm a Alpaca chatBot. Ask whatever you want!", "is_user": False})
def generate_answer():
user_message = st.session_state.input_text
inputs = tokenizer(st.session_state.input_text, return_tensors="pt")
result = model.generate(**inputs)
message_bot = tokenizer.decode(result[0], skip_special_tokens=True) # .replace("<s>", "").replace("</s>", "")
st.session_state.history.append({"message": user_message, "is_user": True})
st.session_state.history.append({"message": message_bot, "is_user": False})
st.text_input("Response", key="input_text", on_change=generate_answer)
for chat in st.session_state.history:
st_message(**chat) |