Update app.py
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
 import logging
 import sys
 import gc
-from contextlib import contextmanager

 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -17,22 +16,25 @@ if torch.cuda.is_available():

 try:
     logger.info("Loading tokenizer...")
-    model_id = "htigenai/finetune_test_2"
+    # Use the base model's tokenizer instead
+    base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
     tokenizer = AutoTokenizer.from_pretrained(
-        model_id,
-        use_fast=True
+        base_model_id,
+        use_fast=True,
+        trust_remote_code=True
     )
     tokenizer.pad_token = tokenizer.eos_token
     logger.info("Tokenizer loaded successfully")

-    logger.info("Loading model in 8-bit...")
+    logger.info("Loading fine-tuned model in 8-bit...")
+    model_id = "htigenai/finetune_test_2"
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         device_map="auto",
-        load_in_8bit=True,
+        load_in_8bit=True,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
-        max_memory={0: "12GB", "cpu": "4GB"}
+        max_memory={0: "12GB", "cpu": "4GB"}
     )
     model.eval()
     logger.info("Model loaded successfully in 8-bit")
@@ -43,16 +45,15 @@ try:

     def generate_text(prompt, max_tokens=100, temperature=0.7):
         try:
-            # Format
+            # Format prompt with chat template
             formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"

-            # Generate with memory-efficient settings
             inputs = tokenizer(
                 formatted_prompt,
                 return_tensors="pt",
                 padding=True,
                 truncation=True,
-                max_length=256
+                max_length=256
             ).to(model.device)

             with torch.inference_mode():
@@ -72,11 +73,11 @@ try:

             response = tokenizer.decode(outputs[0], skip_special_tokens=True)

-            # Extract
+            # Extract assistant's response
             if "### Assistant:" in response:
                 response = response.split("### Assistant:")[-1].strip()

-            # Clean up
+            # Clean up
             del outputs, inputs
             gc.collect()
             torch.cuda.empty_cache()
@@ -87,7 +88,7 @@ try:
             logger.error(f"Error during generation: {str(e)}")
             return f"Error generating response: {str(e)}"

-    # Create
+    # Create Gradio interface
     iface = gr.Interface(
         fn=generate_text,
         inputs=[
@@ -117,7 +118,7 @@ try:
             lines=5
         ),
         title="HTIGENAI Reflection Analyzer (8-bit)",
-        description="
+        description="Using Llama 3.1 base tokenizer with fine-tuned model. Keep prompts concise for best results.",
         examples=[
             ["What is machine learning?", 50, 0.7],
             ["Explain quantum computing", 50, 0.7],
@@ -125,7 +126,7 @@ try:
         cache_examples=False
     )

-    # Launch
+    # Launch interface
     iface.launch(
         server_name="0.0.0.0",
         share=False,
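Note on the 8-bit load path: the diff keeps passing load_in_8bit=True directly to from_pretrained. Recent transformers releases deprecate that keyword in favor of a BitsAndBytesConfig passed as quantization_config, and 8-bit loading also needs bitsandbytes and accelerate installed plus a CUDA GPU on the Space hardware. The sketch below shows the quantization_config variant; it reuses the repo IDs from the diff but is an illustration, not the Space's actual app.py.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Repo IDs taken from the diff above.
base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"  # tokenizer source
model_id = "htigenai/finetune_test_2"                          # fine-tuned weights

tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

# 8-bit quantization expressed through BitsAndBytesConfig instead of the
# deprecated load_in_8bit keyword argument.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.eval()

If the Space pins an older transformers version, the original keyword still works; in that case the runtime error is more likely a missing bitsandbytes install or GPU memory pressure, which the max_memory map in the diff is trying to contain.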
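Because the commit borrows the base Llama 3.1 tokenizer while keeping the fine-tuned checkpoint's weights, a mismatch between the tokenizer's vocabulary and the model's embedding matrix would only surface later as indexing errors or garbled generations. A quick sanity check such as the following (not part of app.py, just a suggested guard) can fail fast at startup:

# Hypothetical guard: confirm the borrowed tokenizer fits the fine-tuned model.
embedding_rows = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_rows:
    raise ValueError(
        f"Tokenizer defines {len(tokenizer)} tokens but the model embedding "
        f"matrix has only {embedding_rows} rows; tokenizer and checkpoint do not match."
    )
if tokenizer.eos_token_id is None:
    raise ValueError("Tokenizer has no EOS token, so the pad_token fallback above would fail.")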
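The hunks jump from with torch.inference_mode(): straight to tokenizer.decode, so the actual model.generate call is not visible in this diff. A plausible shape for that elided block, using only the names the diff already defines (inputs, max_tokens, temperature), is sketched below; the exact arguments in app.py may differ.

# Hypothetical reconstruction of the generation block elided from the diff.
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,            # slider value from the Gradio inputs
        temperature=temperature,
        do_sample=temperature > 0,            # greedy decoding when temperature is 0
        pad_token_id=tokenizer.eos_token_id,  # matches the pad_token = eos_token setup
    )

After this call, the decode, "### Assistant:" extraction, and clean-up lines shown in the diff take over.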