lara1510 commited on
Commit
42aece6
1 Parent(s): 9f0ea65

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +8 -11
chatbot.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
  from langchain.text_splitter import RecursiveCharacterTextSplitter
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import Chroma
@@ -46,27 +46,24 @@ def create_chain(chains, pdf_doc):
46
 
47
 
48
  def create_model():
 
49
  tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
50
-
51
  model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2",
52
  device_map='auto',
53
  torch_dtype=torch.float16,
54
- use_auth_token=True,
55
- load_in_8bit=True,
56
- )
57
  pipe = pipeline("text-generation",
58
  model=model,
59
- tokenizer= tokenizer,
60
  torch_dtype=torch.bfloat16,
61
  device_map="auto",
62
- max_new_tokens = 1024,
63
  do_sample=True,
64
  top_k=10,
65
  num_return_sequences=1,
66
- eos_token_id=tokenizer.eos_token_id
67
- )
68
- llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature':0})
69
- return llm
70
 
71
 
72
  def create_vector_db(doc):
 
1
+ import os
2
  from langchain.text_splitter import RecursiveCharacterTextSplitter
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import Chroma
 
46
 
47
 
48
  def create_model():
49
+ hf_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
50
  tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
 
51
  model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2",
52
  device_map='auto',
53
  torch_dtype=torch.float16,
54
+ token=hf_api_token,
55
+ load_in_8bit=True)
 
56
  pipe = pipeline("text-generation",
57
  model=model,
58
+ tokenizer=tokenizer,
59
  torch_dtype=torch.bfloat16,
60
  device_map="auto",
61
+ max_new_tokens=1024,
62
  do_sample=True,
63
  top_k=10,
64
  num_return_sequences=1,
65
+ eos_token_id=tokenizer.eos_token_id)
66
+ llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})
 
 
67
 
68
 
69
  def create_vector_db(doc):