pvyas96 committed on
Commit
d41a141
1 Parent(s): a600542

Update pages/1_Simple_Chat_UI.py

Browse files
Files changed (1) hide show
  1. pages/1_Simple_Chat_UI.py +5 -5
pages/1_Simple_Chat_UI.py CHANGED
@@ -2,12 +2,12 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
3
 
4
  def load_model_tokenizer(model_name, hf_api_key):
5
- if model_name == "Mistral-7B":
6
- model_name="NousResearch/Nous-Hermes-2-Mistral-7B-DPO"
7
  model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_api_key)
8
  tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer=hf_api_key)
9
- elif model_name == "blenderbot-400M-distill":
10
- model_name = "facebook/blenderbot-400M-distill"
11
  model = AutoModelForCausalLM.from_pretrained(model_name)
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
  return (model,tokenizer)
@@ -38,7 +38,7 @@ with st.sidebar:
38
  if "messages" not in st.session_state.keys():
39
  st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
40
 
41
- model_name = st.radio("Select model to chat", options=["Mistral-7B", "LLaMa-2B", "blenderbot-400M-distill"], horizontal=True, key='model_selection')
42
  model, tokenizer = load_model_tokenizer(model_name, hf_api_key)
43
 
44
  for message in st.session_state.messages:
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
3
 
4
def load_model_tokenizer(model_name, hf_api_key):
    """Resolve a UI model label to its Hugging Face repo and load both halves.

    Args:
        model_name: Label selected in the Streamlit radio widget
            ("LLaMa-2B" or "Red-Pajamas-3b").
        hf_api_key: Hugging Face access token, forwarded for the gated
            LLaMa download.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the supported labels.
            (Previously an unknown label fell through both branches and the
            final ``return`` raised ``UnboundLocalError`` instead.)
    """
    if model_name == "LLaMa-2B":
        model_name = "llmware/bling-sheared-llama-2.7b-0.1"
        model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_api_key)
        # BUG FIX: was `tokenizer=hf_api_key` — `tokenizer` is not a valid
        # kwarg here; the credential must be passed as `token=`, matching
        # the model-loading call above.
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key)
    elif model_name == "Red-Pajamas-3b":
        # Public repo — no auth token required for the download.
        model_name = "llmware/bling-red-pajamas-3b-0.1"
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    else:
        raise ValueError(f"Unsupported model selection: {model_name!r}")
    return (model, tokenizer)
 
38
  if "messages" not in st.session_state.keys():
39
  st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
40
 
41
+ model_name = st.radio("Select model to chat", options=["LLaMa-2B", "Red-Pajamas-3b"], horizontal=True, key='model_selection')
42
  model, tokenizer = load_model_tokenizer(model_name, hf_api_key)
43
 
44
  for message in st.session_state.messages: