lara1510 committed on
Commit
6028f6f
1 Parent(s): f0dd4e4

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +24 -6
chatbot.py CHANGED
@@ -11,6 +11,10 @@ from langchain_community.llms import Ollama
11
  from langchain_core.messages import HumanMessage, AIMessage
12
  from langchain_core.prompts import ChatPromptTemplate
13
  from langchain_core.prompts import MessagesPlaceholder
 
 
 
 
14
 
15
 
16
 
@@ -43,12 +47,26 @@ def create_chain(chains, pdf_doc):
43
 
44
 
45
  def create_model():
46
-
47
- llm = HuggingFaceLLM.from_pretrained(
48
- repo_id="google/flan-t5-base",
49
- temperature=1.0,
50
- max_new_tokens=250
51
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return llm
53
 
54
 
 
11
  from langchain_core.messages import HumanMessage, AIMessage
12
  from langchain_core.prompts import ChatPromptTemplate
13
  from langchain_core.prompts import MessagesPlaceholder
14
+ import torch
15
+ from transformers import pipeline
16
+ from transformers import AutoTokenizer, AutoModelForCausalLM
17
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
18
 
19
 
20
 
 
47
 
48
 
49
def create_model():
    """Build a LangChain-compatible LLM backed by a local GPT-2 pipeline.

    Loads the public ``openai-community/gpt2`` checkpoint, wraps it in a
    Hugging Face ``text-generation`` pipeline, and adapts it for LangChain
    via ``HuggingFacePipeline``.

    Returns:
        HuggingFacePipeline: a LangChain LLM wrapper around the pipeline.
    """
    model_id = "openai-community/gpt2"

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # NOTE(review): the previous version passed use_auth_token=True
    # (deprecated in transformers in favor of `token=`; also unnecessary —
    # GPT-2 is a public checkpoint) and load_in_8bit=True (requires the
    # bitsandbytes package and conflicts with an explicit fp16 dtype).
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )

    # The model object already carries its device placement and dtype, so
    # device_map / torch_dtype must not be passed again here (the previous
    # version mixed float16 at load time with bfloat16 in the pipeline).
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # model_kwargs={'temperature': 0} was dropped: model_kwargs are
    # load-time kwargs (not generation parameters), and temperature=0
    # combined with do_sample=True raises in transformers generation.
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm
71
 
72