TomData committed
Commit 3ebff47 · 1 Parent(s): cef6758

changed to mistral model

Files changed (1)
  1. src/chatbot.py +8 -3
src/chatbot.py CHANGED
@@ -12,11 +12,14 @@ import os
 #load_dotenv(find_dotenv())
 
 
+
 embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2")
 llm = HuggingFaceHub(
     # Try different model here
-    # repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
-    repo_id="CohereForAI/c4ai-command-r-v01",
+    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+    # repo_id="CohereForAI/c4ai-command-r-v01", # too large 69gb
+    # repo_id="CohereForAI/c4ai-command-r-v01-4bit", # too large 22 gb
+    # repo_id="meta-llama/Meta-Llama-3-8B", # too large 16 gb
     task="text-generation",
     model_kwargs={
         "max_new_tokens": 512,
@@ -24,6 +27,8 @@ llm = HuggingFaceHub(
         "temperature": 0.1,
         "repetition_penalty": 1.03,
     }
+    #,huggingfacehub_api_token
+
 )
 # To Do: Experiment with different templates replying in german or english depending on the input language
 prompt1 = ChatPromptTemplate.from_template("""<s>[INST]
@@ -55,7 +60,7 @@ db = get_vectorstore(embeddings=embeddings, folder_path=folder_path, index_name=
 
 def chatbot(message, history, db=db, llm=llm, prompt=prompt2):
     raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
-    response = raw_response['answer']#.split("Antwort: ")[1]
+    response = raw_response['answer'].split("Antwort: ")[1]
     return response
 
 # Retrieve speech contents based on keywords
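
The last hunk re-enables the "Antwort: " split. This makes sense if the prompt templates (prompt1/prompt2, not shown in this diff) end with an "Antwort: " cue and the Hub text-generation endpoint echoes the prompt before the completion; splitting on that marker then keeps only the generated answer. A minimal sketch of that extraction step under those assumptions, with a hypothetical raw_response value:

# Sketch (not part of the commit): illustrates the answer extraction the diff
# re-enables, assuming the prompt ends with an "Antwort: " cue that the model
# echoes back before its completion.

def extract_answer(raw_response: dict) -> str:
    """Keep only the text after the last 'Antwort: ' marker."""
    full_text = raw_response["answer"]
    # The committed one-liner uses split(...)[1], which assumes the marker appears
    # exactly once; rsplit with maxsplit=1 also tolerates the marker showing up
    # in the retrieved context.
    return full_text.rsplit("Antwort: ", 1)[-1].strip()

# Hypothetical echoed output from the text-generation endpoint:
raw_response = {"answer": "Kontext: ... Frage: Was wurde beschlossen? Antwort: Der Antrag wurde angenommen."}
print(extract_answer(raw_response))  # -> Der Antrag wurde angenommen.

If the model ever omits the marker, the committed split(...)[1] raises an IndexError, whereas the rsplit variant in this sketch falls back to returning the full string.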