jed-tiotuico committed
Commit d47ab3c
Parent: 673d242

changed to ft model

Files changed (1): app.py (+20, -23)
app.py CHANGED
@@ -6,16 +6,16 @@ import threading
 import streamlit as st
 import random
 from typing import Iterable
-# from unsloth import FastLanguageModel
+from unsloth import FastLanguageModel
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, PreTrainedTokenizerFast
 from datetime import datetime
 from threading import Thread
 
-# fine_tuned_model_name = "jed-tiotuico/twitter-llama"
-# sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
-fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
-sota_model_name = "MBZUAI/LaMini-GPT-124M"
+fine_tuned_model_name = "jed-tiotuico/twitter-llama"
+sota_model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
+# fine_tuned_model_name = "MBZUAI/LaMini-GPT-124M"
+# sota_model_name = "MBZUAI/LaMini-GPT-124M"
 
 alpaca_input_text_format = "### Instruction:\n{}\n\n### Response:\n"
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 # if device is cpu try mps?
@@ -24,16 +24,15 @@ if device == "cpu":
     device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 
 def get_model_tokenizer(sota_model_name):
-    tokenizer = AutoTokenizer.from_pretrained(
-        sota_model_name,
-        cache_dir="/data/.hf_cache",
-        trust_remote_code=True
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name = "jed-tiotuico/twitter-llama",
+        max_seq_length = 200,
+        dtype = None,
+        load_in_4bit = True,
+        cache_dir = "/data/.cache/hf-models",
+        token=st.secrets["HF_TOKEN"]
     )
-    model = AutoModelForCausalLM.from_pretrained(
-        sota_model_name,
-        cache_dir="/data/.hf_cache",
-        trust_remote_code=True
-    ).to(device)
+    FastLanguageModel.for_inference(model)
 
     return model, tokenizer
 
@@ -61,16 +60,14 @@ def write_stream_user_chat_message(user_chat, model, token, prompt):
     return new_customer_msg
 
 def get_mistral_model_tokenizer(sota_model_name):
-    tokenizer = AutoTokenizer.from_pretrained(
-        sota_model_name,
-        cache_dir="/data/.hf_cache",
-        trust_remote_code=True
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
+        max_seq_length = max_seq_length,
+        dtype = dtype,
+        load_in_4bit = load_in_4bit,
+        cache_dir = "/data/.cache/hf-models",
     )
-    model = AutoModelForCausalLM.from_pretrained(
-        sota_model_name,
-        cache_dir="/data/.hf_cache",
-        trust_remote_code=True
-    ).to(device)
+    FastLanguageModel.for_inference(model)
 
     return model, tokenizer
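
Review note: the new get_mistral_model_tokenizer body references max_seq_length, dtype, and load_in_4bit, which are not defined in the hunks shown; they presumably exist at module level elsewhere in app.py, otherwise the function raises NameError. Note also that get_model_tokenizer now ignores its sota_model_name argument and hard-codes "jed-tiotuico/twitter-llama".

For context, here is a minimal sketch of how these loaders would plug into the streaming pattern the file already imports (TextIteratorStreamer, Thread). It is not part of the commit: generate_reply and the st.cache_resource wrapper are hypothetical, and it assumes the module-level names from app.py above (get_model_tokenizer, fine_tuned_model_name, alpaca_input_text_format).

import threading

import streamlit as st
from transformers import TextIteratorStreamer

@st.cache_resource  # hypothetical wrapper: load the 4-bit model once per Streamlit process
def load_fine_tuned():
    # get_model_tokenizer comes from this commit; it loads
    # "jed-tiotuico/twitter-llama" in 4-bit via unsloth
    return get_model_tokenizer(fine_tuned_model_name)

def generate_reply(instruction: str) -> str:
    model, tokenizer = load_fine_tuned()
    prompt = alpaca_input_text_format.format(instruction)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # model.generate blocks, so run it on a worker thread and drain
    # the streamer on the Streamlit side as tokens arrive
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=200),
    )
    thread.start()
    placeholder, text = st.empty(), ""
    for chunk in streamer:
        text += chunk
        placeholder.markdown(text)
    thread.join()
    return text

FastLanguageModel.for_inference(model) is unsloth's documented switch into its inference-optimized path; calling it once right after from_pretrained, as both new loaders do, matches the usage the unsloth README recommends before generate().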