Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,8 +3,7 @@ import logging
|
|
3 |
from logging.handlers import RotatingFileHandler
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
-
from
|
7 |
-
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
|
8 |
from langchain_huggingface import HuggingFacePipeline
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
from langchain.chains import LLMChain
|
@@ -20,20 +19,22 @@ logger.addHandler(file_handler)
|
|
20 |
logger.debug("Application started")
|
21 |
|
22 |
model_id = "google/gemma-2-9b-it"
|
23 |
-
tokenizer =
|
24 |
|
25 |
# Load model with GPU availability check
|
26 |
if torch.cuda.is_available():
|
27 |
logger.debug("GPU is available. Proceeding with GPU setup.")
|
28 |
model = AutoModelForCausalLM.from_pretrained(
|
29 |
model_id,
|
30 |
-
device_map="auto",
|
|
|
31 |
)
|
32 |
else:
|
33 |
logger.warning("GPU is not available. Proceeding with CPU setup.")
|
34 |
model = AutoModelForCausalLM.from_pretrained(
|
35 |
model_id,
|
36 |
-
|
|
|
37 |
)
|
38 |
|
39 |
model.eval()
|
@@ -53,6 +54,7 @@ pipe = pipeline(
|
|
53 |
# Initialize HuggingFacePipeline model for LangChain
|
54 |
chat_model = HuggingFacePipeline(pipeline=pipe)
|
55 |
|
|
|
56 |
|
57 |
# Define the conversation template for LangChain
|
58 |
template = """<|im_start|>system
|
@@ -68,12 +70,12 @@ template = """<|im_start|>system
|
|
68 |
prompt = PromptTemplate(
|
69 |
template=template, input_variables=["system_prompt", "history", "human_input"]
|
70 |
)
|
71 |
-
chain = prompt
|
72 |
|
73 |
# Prediction function using LangChain and model
|
74 |
-
def predict(message,
|
75 |
formatted_history = "\n".join(
|
76 |
-
[f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in
|
77 |
)
|
78 |
system_prompt = "You are a helpful coding assistant."
|
79 |
|
@@ -91,14 +93,12 @@ def predict(message, chat_history=[]):
|
|
91 |
# Gradio UI
|
92 |
interface = gr.Interface(
|
93 |
fn=predict,
|
94 |
-
inputs=
|
95 |
-
|
96 |
-
|
97 |
-
],
|
98 |
-
outputs="text", allow_flagging='never',
|
99 |
live=True,
|
100 |
)
|
101 |
|
102 |
interface.launch()
|
103 |
|
104 |
-
logger.debug("Chat interface initialized and launched")
|
|
|
3 |
from logging.handlers import RotatingFileHandler
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
7 |
from langchain_huggingface import HuggingFacePipeline
|
8 |
from langchain.prompts import PromptTemplate
|
9 |
from langchain.chains import LLMChain
|
|
|
19 |
logger.debug("Application started")
|
20 |
|
21 |
model_id = "google/gemma-2-9b-it"
|
22 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
23 |
|
24 |
# Load model with GPU availability check
|
25 |
if torch.cuda.is_available():
|
26 |
logger.debug("GPU is available. Proceeding with GPU setup.")
|
27 |
model = AutoModelForCausalLM.from_pretrained(
|
28 |
model_id,
|
29 |
+
device_map="auto",
|
30 |
+
torch_dtype=torch.bfloat16,
|
31 |
)
|
32 |
else:
|
33 |
logger.warning("GPU is not available. Proceeding with CPU setup.")
|
34 |
model = AutoModelForCausalLM.from_pretrained(
|
35 |
model_id,
|
36 |
+
low_cpu_mem_usage=True,
|
37 |
+
use_auth_token=os.getenv('HF_TOKEN'),
|
38 |
)
|
39 |
|
40 |
model.eval()
|
|
|
54 |
# Initialize HuggingFacePipeline model for LangChain
|
55 |
chat_model = HuggingFacePipeline(pipeline=pipe)
|
56 |
|
57 |
+
logger.debug("Model and tokenizer loaded successfully")
|
58 |
|
59 |
# Define the conversation template for LangChain
|
60 |
template = """<|im_start|>system
|
|
|
70 |
prompt = PromptTemplate(
|
71 |
template=template, input_variables=["system_prompt", "history", "human_input"]
|
72 |
)
|
73 |
+
chain = LLMChain(llm=chat_model, prompt=prompt)
|
74 |
|
75 |
# Prediction function using LangChain and model
|
76 |
+
def predict(message, history=[]):
|
77 |
formatted_history = "\n".join(
|
78 |
+
[f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in history]
|
79 |
)
|
80 |
system_prompt = "You are a helpful coding assistant."
|
81 |
|
|
|
93 |
# Gradio UI
|
94 |
interface = gr.Interface(
|
95 |
fn=predict,
|
96 |
+
inputs=gr.Textbox(label="User input"),
|
97 |
+
outputs="text",
|
98 |
+
allow_flagging='never',
|
|
|
|
|
99 |
live=True,
|
100 |
)
|
101 |
|
102 |
interface.launch()
|
103 |
|
104 |
+
logger.debug("Chat interface initialized and launched")
|