Spaces:
Runtime error
Runtime error
File size: 4,768 Bytes
1905a86 a73723b 1905a86 7ad3091 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import streamlit as st
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
import sys
# -----------------------------------------------------------------------------
# Default sampling settings. They are plain module-level names; configurator.py,
# executed at the end of this section, is what applies any overrides to them.
init_from = 'resume'  # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out'  # ignored if init_from is not 'resume'
#start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 1  # number of samples to draw
max_new_tokens = 400  # number of tokens generated in each sample
temperature = 0.3  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cpu'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
# 'bfloat16' is picked only when CUDA is available and reports bf16 support;
# every other case falls back to 'float16'.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'  # 'float32' or 'bfloat16' or 'float16'
# NOTE: this name shadows the builtin `compile`; it is kept because it is one of
# the setting names that overrides refer to.
compile = False  # use PyTorch 2.0 to compile the model to be faster
# NOTE(review): exec() runs the full contents of configurator.py in this
# module's namespace, so that file must be trusted. It is read inside a `with`
# block so the file handle is closed deterministically instead of being leaked.
with open('configurator.py') as _configurator_file:
    exec(_configurator_file.read())  # overrides from command line or config file
# Seed torch's random number generators (generic and CUDA) from the configured seed.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Allow TF32 for CUDA matmuls and for cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Device family handed to torch.autocast: any device string that contains
# 'cuda' (e.g. 'cuda:1') counts as CUDA, everything else as CPU.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Map the configured dtype name onto the matching torch dtype object.
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]

# On CPU generation runs inside a no-op context; on any other device type it
# runs under autocast with the dtype selected above.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# Load the checkpoint file (path is relative to the working directory) and
# rebuild the model from the constructor arguments stored inside it.
ckpt_path = 'FineTune_ckpt.pt'
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Strip the '_orig_mod.' prefix from any parameter names that carry it, so the
# keys line up with a freshly constructed GPT.
# NOTE(review): the prefix is presumably added when a torch.compile'd model is
# saved — confirm against the training script.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
prefixed_names = [name for name in state_dict if name.startswith(unwanted_prefix)]
for name in prefixed_names:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)
model.load_state_dict(state_dict)

# Switch to evaluation mode and move the model to the configured device.
model.eval()
model.to(device)

if compile:
    # requires PyTorch 2.0 (optional)
    model = torch.compile(model)
# Tokenizer shared by encode() and decode().
enc = tiktoken.get_encoding("gpt2")


def encode(s):
    """Encode the string ``s`` with ``enc``, allowing the ``<|endoftext|>`` special token."""
    return enc.encode(s, allowed_special={"<|endoftext|>"})


def decode(l):
    """Decode the token ids ``l`` back into a string with ``enc``."""
    return enc.decode(l)
def get_response(prompt):
    """Generate one completion for ``prompt`` and return it as text.

    The prompt is tokenized with ``encode``, passed to ``model.generate`` with
    the module-level ``max_new_tokens``, ``temperature`` and ``top_k``
    settings, and the first row of the result is decoded with ``decode``.

    Args:
        prompt: Text to tokenize and feed to the model.

    Returns:
        The decoded text up to (not including) the first ``'[EndOfText]'``
        marker, or the whole decoded text when the marker does not occur.
        NOTE(review): whether this text also contains the prompt depends on
        what ``model.generate`` returns — confirm against ``model.py``.
    """
    start_ids = encode(prompt)
    # [None, ...] puts a new leading dimension in front of the token ids.
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    # Only a single completion is ever returned, so generate exactly once
    # instead of looping ``num_samples`` times. Looping either threw away
    # every sample but one or, with ``num_samples`` set to 0, left the result
    # undefined.
    with torch.no_grad(), ctx:
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    return decode(y[0].tolist()).split('[EndOfText]')[0]
# ai_logo_path = 'https://raw.githubusercontent.com/chrischenhub/AlphaGPT/04dda92907aa0aee109b9a05c4521748288a3b7e/alphaplaylogo.svg'
# customer_logo_path = 'https://raw.githubusercontent.com/chrischenhub/AlphaGPT/04dda92907aa0aee109b9a05c4521748288a3b7e/montreal-canadiens-seeklogo.svg'
# # App title
# # Replicate Credentials
# with st.sidebar:
# col1, col2 = st.columns([1, 3])
# with col1:
# st.image(ai_logo_path, width=60)
# with col2:
# st.title('AlphaPlay')
# st.write('The following keys are used to generate system prompts')
# # Store LLM generated responses
# if "messages" not in st.session_state.keys():
# st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
# # Display or clear chat messages
# for message in st.session_state.messages:
# if message["role"]=='assistant':
# with st.chat_message(message["role"],avatar=ai_logo_path):
# st.write(message["content"])
# else:
# with st.chat_message(message["role"],avatar=customer_logo_path):
# st.write(message["content"])
# def clear_chat_history():
# st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
# st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
# # User-provided prompt
# if prompt := st.chat_input():
# st.session_state.messages.append({"role": "user", "content": prompt})
# with st.chat_message("user", avatar=customer_logo_path):
# st.write(prompt)
# # Generate a new response if last message is not from assistant
# if st.session_state.messages[-1]["role"] != "assistant":
# with st.chat_message("assistant",avatar=ai_logo_path):
# with st.spinner("Thinking..."):
# response = get_response(prompt)
# placeholder = st.empty()
# full_response = ''
# for item in response:
# full_response += item
# placeholder.markdown(full_response)
# placeholder.markdown(full_response)
# message = {"role": "assistant", "content": full_response}
# st.session_state.messages.append(message)
|