htigenai committed
Commit c8ef1f7
1 Parent(s): 2153031

Update app.py

Files changed (1)
  1. app.py +16 -15
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
 import logging
 import sys
 import gc
-from contextlib import contextmanager

 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -17,22 +16,25 @@ if torch.cuda.is_available():

 try:
     logger.info("Loading tokenizer...")
-    model_id = "htigenai/finetune_test_2"
+    # Use the base model's tokenizer instead
+    base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
     tokenizer = AutoTokenizer.from_pretrained(
-        model_id,
-        use_fast=False  # Use slow tokenizer to save memory
+        base_model_id,
+        use_fast=True,
+        trust_remote_code=True
     )
     tokenizer.pad_token = tokenizer.eos_token
     logger.info("Tokenizer loaded successfully")

-    logger.info("Loading model in 8-bit...")
+    logger.info("Loading fine-tuned model in 8-bit...")
+    model_id = "htigenai/finetune_test_2"
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         device_map="auto",
-        load_in_8bit=True,  # Load in 8-bit instead of 4-bit
+        load_in_8bit=True,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
-        max_memory={0: "12GB", "cpu": "4GB"}  # Limit memory usage
+        max_memory={0: "12GB", "cpu": "4GB"}
     )
     model.eval()
     logger.info("Model loaded successfully in 8-bit")
@@ -43,16 +45,15 @@ try:

     def generate_text(prompt, max_tokens=100, temperature=0.7):
         try:
-            # Format the prompt
+            # Format prompt with chat template
             formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"

-            # Generate with memory-efficient settings
             inputs = tokenizer(
                 formatted_prompt,
                 return_tensors="pt",
                 padding=True,
                 truncation=True,
-                max_length=256  # Limit input length
+                max_length=256
             ).to(model.device)

             with torch.inference_mode():
@@ -72,11 +73,11 @@ try:

             response = tokenizer.decode(outputs[0], skip_special_tokens=True)

-            # Extract only the assistant's response
+            # Extract assistant's response
             if "### Assistant:" in response:
                 response = response.split("### Assistant:")[-1].strip()

-            # Clean up memory after generation
+            # Clean up
             del outputs, inputs
             gc.collect()
             torch.cuda.empty_cache()
@@ -87,7 +88,7 @@ try:
             logger.error(f"Error during generation: {str(e)}")
             return f"Error generating response: {str(e)}"

-    # Create a more memory-efficient Gradio interface
+    # Create Gradio interface
     iface = gr.Interface(
         fn=generate_text,
         inputs=[
@@ -117,7 +118,7 @@ try:
             lines=5
         ),
         title="HTIGENAI Reflection Analyzer (8-bit)",
-        description="8-bit quantized text generation. Please keep prompts concise for best results.",
+        description="Using Llama 3.1 base tokenizer with fine-tuned model. Keep prompts concise for best results.",
         examples=[
             ["What is machine learning?", 50, 0.7],
             ["Explain quantum computing", 50, 0.7],
@@ -125,7 +126,7 @@ try:
         cache_examples=False
     )

-    # Launch with minimal memory usage
+    # Launch interface
     iface.launch(
         server_name="0.0.0.0",
         share=False,
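Note on the 8-bit load: the updated app.py still passes load_in_8bit=True directly to from_pretrained. Recent transformers releases deprecate that bare flag in favor of a BitsAndBytesConfig passed through quantization_config. A minimal equivalent sketch, assuming bitsandbytes is installed and keeping the same repos and memory caps as the diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"  # tokenizer source, as in the diff
model_id = "htigenai/finetune_test_2"                          # fine-tuned weights

tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # replaces the bare load_in_8bit=True
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    max_memory={0: "12GB", "cpu": "4GB"},
)
model.eval()

Because the tokenizer now comes from the base repo while the weights come from the fine-tune, it may also be worth checking that len(tokenizer) does not exceed model.config.vocab_size before serving.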
 
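The new comment reads "Format prompt with chat template", but generate_text still builds a manual "### Human: ... ### Assistant:" string. If the base Llama 3.1 Instruct tokenizer ships a chat template, apply_chat_template could build the prompt instead; whether the fine-tune was actually trained on that template rather than the ### format is an assumption to verify, and the helper name build_inputs below is only illustrative.

# Sketch: build generation inputs from the tokenizer's own chat template.
# Assumes the fine-tune expects the Llama 3.1 Instruct template; if it was
# trained on the "### Human/### Assistant" format, keep the manual string.
def build_inputs(tokenizer, model, prompt):
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # append the assistant turn header
        return_tensors="pt",
    )
    return input_ids.to(model.device)

# Example usage (mirrors the app's first Gradio example):
# inputs = build_inputs(tokenizer, model, "What is machine learning?")
# outputs = model.generate(inputs, max_new_tokens=50)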