IC4T committed
Commit 7d040e6 · 1 Parent(s): ace0fb5
Files changed (1): app.py (+14 -6)
app.py CHANGED
@@ -15,11 +15,15 @@ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbed
 from langchain.prompts.prompt import PromptTemplate
 from langchain import PromptTemplate, LLMChain
 from langchain.llms import HuggingFacePipeline
+from instruct_pipeline import InstructionTextGenerationPipeline
 
-from training.generate import InstructionTextGenerationPipeline, load_model_tokenizer_for_generate
+from training.generate import load_model_tokenizer_for_generate
+
+# from training.generate import InstructionTextGenerationPipeline, load_model_tokenizer_for_generate
 # from googletrans import Translator
 # translator = Translator()
 
+
 load_dotenv()
 
 embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
@@ -52,14 +56,18 @@ retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
 # Prepare the LLM
 # callbacks = [StreamingStdOutCallbackHandler()]
 
+
+
 match model_type:
     case "dolly-v2-3b":
         model, tokenizer = load_model_tokenizer_for_generate(model_path)
-        llm = HuggingFacePipeline(
-            pipeline=InstructionTextGenerationPipeline(
-                # Return the full text, because this is what the HuggingFacePipeline expects.
-                model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation", max_new_tokens=model_n_ctx))#, max_new_tokens=model_n_ctx
-        #))
+        generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)
+        llm = HuggingFacePipeline(pipeline=generate_text)
+        # llm = HuggingFacePipeline(
+        #     pipeline=InstructionTextGenerationPipeline(
+        #         # Return the full text, because this is what the HuggingFacePipeline expects.
+        #         model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation", max_new_tokens=model_n_ctx))#, max_new_tokens=model_n_ctx
+        #     #))
     # case "GPT4All":
     #     llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
     case _default:
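
The change follows the LangChain recipe from the databricks/dolly-v2-3b model card: InstructionTextGenerationPipeline (the instruct_pipeline.py module distributed alongside the model) subclasses transformers' Pipeline, so LangChain's HuggingFacePipeline can wrap it directly. Below is a minimal standalone sketch of that wiring, assuming instruct_pipeline.py is on the import path; it loads the model straight from the Hub rather than through this repo's load_model_tokenizer_for_generate helper, and the prompt template is illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain import PromptTemplate, LLMChain
from langchain.llms import HuggingFacePipeline
from instruct_pipeline import InstructionTextGenerationPipeline  # ships with the dolly model repo

# Load from the Hub (the commit loads from a local model_path instead);
# the instruct pipeline expects left-padding.
tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    "databricks/dolly-v2-3b", device_map="auto", torch_dtype=torch.bfloat16
)

# return_full_text=True makes the pipeline echo the prompt back, which is what
# HuggingFacePipeline expects when it strips the prompt from the generated text.
generate_text = InstructionTextGenerationPipeline(
    model=model, tokenizer=tokenizer, return_full_text=True, task="text-generation"
)
llm = HuggingFacePipeline(pipeline=generate_text)

prompt = PromptTemplate(input_variables=["instruction"], template="{instruction}")
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run("Explain the difference between a fork and a clone."))

The old (now commented-out) call passed return_full_text=True and task="text-generation" for exactly the reason its inline comment gives; the sketch keeps both, since that is what the model card shows for LangChain use, while the commit's new call relies on the pipeline's defaults.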