Leri777 committed on
Commit 4df36c7 · verified · 1 Parent(s): a5219f1

Update app.py

Files changed (1)
app.py +14 -14
app.py CHANGED
@@ -3,8 +3,7 @@ import logging
 from logging.handlers import RotatingFileHandler
 import gradio as gr
 import torch
-from accelerate import Accelerator
-from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from langchain_huggingface import HuggingFacePipeline
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
@@ -20,20 +19,22 @@ logger.addHandler(file_handler)
 logger.debug("Application started")
 
 model_id = "google/gemma-2-9b-it"
-tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 # Load model with GPU availability check
 if torch.cuda.is_available():
     logger.debug("GPU is available. Proceeding with GPU setup.")
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="auto", torch_dtype=torch.bfloat16,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
     )
 else:
     logger.warning("GPU is not available. Proceeding with CPU setup.")
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map="auto", low_cpu_mem_usage=True, token=os.getenv('HF_TOKEN'),
+        low_cpu_mem_usage=True,
+        use_auth_token=os.getenv('HF_TOKEN'),
     )
 
 model.eval()
@@ -53,6 +54,7 @@ pipe = pipeline(
 # Initialize HuggingFacePipeline model for LangChain
 chat_model = HuggingFacePipeline(pipeline=pipe)
 
+logger.debug("Model and tokenizer loaded successfully")
 
 # Define the conversation template for LangChain
 template = """<|im_start|>system
@@ -68,12 +70,12 @@ template = """<|im_start|>system
 prompt = PromptTemplate(
     template=template, input_variables=["system_prompt", "history", "human_input"]
 )
-chain = prompt | chat_model
+chain = LLMChain(llm=chat_model, prompt=prompt)
 
 # Prediction function using LangChain and model
-def predict(message, chat_history=[]):
+def predict(message, history=[]):
     formatted_history = "\n".join(
-        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history]
+        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in history]
     )
     system_prompt = "You are a helpful coding assistant."
 
@@ -91,14 +93,12 @@ def predict(message, chat_history=[]):
 # Gradio UI
 interface = gr.Interface(
     fn=predict,
-    inputs=[
-        gr.Textbox(label="User input"),
-        gr.State(),
-    ],
-    outputs="text", allow_flagging='never',
+    inputs=gr.Textbox(label="User input"),
+    outputs="text",
+    allow_flagging='never',
     live=True,
 )
 
 interface.launch()
 
-logger.debug("Chat interface initialized and launched")
+logger.debug("Chat interface initialized and launched")
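A note on invoking the reworked chain: `chain = prompt | chat_model` (a LangChain runnable) is called with `chain.invoke({...})`, whereas the `LLMChain` this commit switches to is typically driven with `chain.run(...)` keyword arguments matching the template's `input_variables`. The body of `predict` between `formatted_history` and the Gradio UI is unchanged context not shown in this diff, so the following is only a sketch of how the new chain would plausibly be called; the sample history entry and the `reply` variable are hypothetical, not part of the commit.

# Sketch only, assuming the objects defined in app.py above.
# LLMChain.run accepts keyword arguments matching
# PromptTemplate(input_variables=["system_prompt", "history", "human_input"]).
history = [{"role": "user", "content": "Explain Python decorators."}]  # dict shape expected by predict()
formatted_history = "\n".join(
    f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in history
)
reply = chain.run(
    system_prompt="You are a helpful coding assistant.",
    history=formatted_history,
    human_input="Show a short example.",
)
print(reply)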