HusainHG commited on
Commit
c1b2914
Β·
verified Β·
1 Parent(s): c043e5f

Upload 10 files

Browse files
Files changed (2) hide show
  1. app.py +198 -200
  2. dataset.py +0 -0
app.py CHANGED
@@ -1,200 +1,198 @@
1
- from flask import Flask, request, jsonify, send_from_directory
2
- from flask_cors import CORS
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
- import torch
5
- import os
6
- import sys
7
-
8
- app = Flask(__name__, static_folder='static')
9
- CORS(app)
10
-
11
- MODEL_NAME = "KASHH-4/phi_finetuned"
12
-
13
- print("\n" + "="*80)
14
- print("πŸš€ LEGALDOCS AI - MODEL INITIALIZATION")
15
- print("="*80)
16
- print(f"πŸ“¦ Model: {MODEL_NAME}")
17
- print(f"🐍 Python: {torch.__version__}")
18
- print(f"πŸ”₯ PyTorch: {torch.__version__}")
19
- print(f"πŸ€— Transformers: Loading...")
20
- print("="*80 + "\n")
21
-
22
- print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
23
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
24
-
25
- if tokenizer.pad_token is None:
26
- tokenizer.pad_token = tokenizer.eos_token
27
-
28
- print("βœ… Tokenizer loaded successfully!")
29
- print(f" - Vocab size: {tokenizer.vocab_size}")
30
- print(f" - Model max length: {tokenizer.model_max_length}")
31
- print(f" - Pad token: {tokenizer.pad_token}")
32
-
33
- print("Loading YOUR model weights...")
34
- # Optimized for 18GB RAM with 4-bit quantization
35
- quantization_config = BitsAndBytesConfig(
36
- load_in_4bit=True,
37
- bnb_4bit_compute_dtype=torch.float16,
38
- bnb_4bit_quant_type="nf4",
39
- bnb_4bit_use_double_quant=True,
40
- )
41
-
42
- model = AutoModelForCausalLM.from_pretrained(
43
- MODEL_NAME,
44
- quantization_config=quantization_config,
45
- device_map="auto",
46
- low_cpu_mem_usage=True,
47
- trust_remote_code=True,
48
- torch_dtype=torch.float16,
49
- )
50
-
51
- print("βœ… Model loaded successfully!")
52
- print(f" - Device: {model.device}")
53
- print(f" - Model type: {type(model).__name__}")
54
- print(f" - Quantization: 4-bit NF4")
55
- print(f" - Compute dtype: float16")
56
-
57
- # Memory info
58
- if torch.cuda.is_available():
59
- print(f" - GPU: {torch.cuda.get_device_name(0)}")
60
- print(f" - GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
61
- print(f" - GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
62
- else:
63
- print(f" - Running on CPU")
64
-
65
- print("\n" + "="*80)
66
- print("βœ… MODEL READY - Server starting...")
67
- print("="*80 + "\n")
68
-
69
-
70
- @app.route('/')
71
- def index():
72
- return send_from_directory('static', 'index.html')
73
-
74
-
75
- @app.route('/api/generate', methods=['POST'])
76
- def generate():
77
- import time
78
- try:
79
- print("\n" + "="*80, flush=True)
80
- print("πŸš€ NEW GENERATION REQUEST RECEIVED", flush=True)
81
- print("="*80, flush=True)
82
- sys.stdout.flush()
83
-
84
- data = request.json
85
-
86
- if not data or 'prompt' not in data:
87
- print("❌ ERROR: Missing prompt in request body", flush=True)
88
- sys.stdout.flush()
89
- return jsonify({'error': 'Missing prompt in request body'}), 400
90
-
91
- prompt = data['prompt']
92
- max_new_tokens = data.get('max_new_tokens', 400)
93
- temperature = data.get('temperature', 0.7)
94
- top_p = data.get('top_p', 0.9)
95
-
96
- print(f"\nπŸ“ REQUEST PARAMETERS:", flush=True)
97
- print(f" - Prompt length: {len(prompt)} characters", flush=True)
98
- print(f" - Prompt preview: {prompt[:200]}...", flush=True)
99
- print(f" - Max new tokens: {max_new_tokens}", flush=True)
100
- print(f" - Temperature: {temperature}", flush=True)
101
- print(f" - Top P: {top_p}", flush=True)
102
- sys.stdout.flush()
103
-
104
- print(f"\nπŸ”„ TOKENIZING INPUT...", flush=True)
105
- sys.stdout.flush()
106
- tokenize_start = time.time()
107
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
108
- tokenize_time = time.time() - tokenize_start
109
- input_token_count = inputs['input_ids'].shape[1]
110
- print(f" βœ… Tokenization complete in {tokenize_time:.2f}s", flush=True)
111
- print(f" - Input tokens: {input_token_count}", flush=True)
112
- print(f" - Device: {model.device}", flush=True)
113
- sys.stdout.flush()
114
-
115
- print(f"\n🧠 GENERATING TEXT WITH MODEL...", flush=True)
116
- print(f" Model: {MODEL_NAME}", flush=True)
117
- print(f" Status: Running inference...", flush=True)
118
- sys.stdout.flush()
119
- generation_start = time.time()
120
-
121
- # Use controlled sampling for better JSON generation
122
- with torch.no_grad():
123
- torch.set_num_threads(2) # Use both CPU cores
124
- outputs = model.generate(
125
- **inputs,
126
- max_new_tokens=400,
127
- temperature=0.7,
128
- top_p=0.9,
129
- do_sample=True,
130
- pad_token_id=tokenizer.eos_token_id,
131
- eos_token_id=tokenizer.eos_token_id,
132
- repetition_penalty=1.1
133
- )
134
-
135
- generation_time = time.time() - generation_start
136
- output_token_count = outputs.shape[1]
137
- tokens_generated = output_token_count - input_token_count
138
- tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0
139
-
140
- print(f" βœ… Generation complete in {generation_time:.2f}s", flush=True)
141
- print(f" - Output tokens: {output_token_count}", flush=True)
142
- print(f" - New tokens generated: {tokens_generated}", flush=True)
143
- print(f" - Speed: {tokens_per_second:.2f} tokens/second", flush=True)
144
- sys.stdout.flush()
145
-
146
- print(f"\nπŸ”„ DECODING OUTPUT...", flush=True)
147
- sys.stdout.flush()
148
- decode_start = time.time()
149
- # Decode the full output
150
- full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
151
- decode_time = time.time() - decode_start
152
- print(f" βœ… Decoding complete in {decode_time:.2f}s", flush=True)
153
- sys.stdout.flush()
154
-
155
- # Remove the prompt from the output to return only the generated text
156
- generated_text = full_output[len(prompt):].strip()
157
-
158
- print(f"\nπŸ“Š FINAL RESULTS:", flush=True)
159
- print(f" - Generated text length: {len(generated_text)} characters", flush=True)
160
- print(f" - Generated text preview: {generated_text[:200]}...", flush=True)
161
- print(f" - Total processing time: {(time.time() - tokenize_start):.2f}s", flush=True)
162
- sys.stdout.flush()
163
-
164
- print(f"\nβœ… REQUEST COMPLETED SUCCESSFULLY", flush=True)
165
- print("="*80 + "\n", flush=True)
166
- sys.stdout.flush()
167
-
168
- return jsonify({
169
- 'generated_text': generated_text,
170
- 'prompt': prompt
171
- })
172
-
173
- except Exception as e:
174
- print(f"\n❌ ERROR DURING GENERATION:", flush=True)
175
- print(f" Error type: {type(e).__name__}", flush=True)
176
- print(f" Error message: {str(e)}", flush=True)
177
- sys.stdout.flush()
178
- import traceback
179
- print(f" Traceback:\n{traceback.format_exc()}", flush=True)
180
- print("="*80 + "\n", flush=True)
181
- sys.stdout.flush()
182
- return jsonify({'error': str(e)}), 500
183
-
184
-
185
- @app.route('/api/health', methods=['GET'])
186
- def health():
187
- return jsonify({
188
- 'status': 'ok',
189
- 'model': MODEL_NAME,
190
- 'device': str(model.device)
191
- })
192
-
193
-
194
- if __name__ == '__main__':
195
- port = int(os.environ.get('PORT', 7860))
196
- print(f"\n🌐 Starting Flask server on port {port}...")
197
- print(f"πŸ”— Access the app at: http://localhost:{port}")
198
- print(f"πŸ“Š Health check: http://localhost:{port}/api/health")
199
- print(f"πŸš€ API endpoint: http://localhost:{port}/api/generate\n")
200
- app.run(host='0.0.0.0', port=port, debug=False)
 
1
+ from flask import Flask, request, jsonify, send_from_directory
2
+ from flask_cors import CORS
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
+ import torch
5
+ import os
6
+ import sys
7
+
8
+ app = Flask(__name__, static_folder='static')
9
+ CORS(app)
10
+
11
+ MODEL_NAME = "KASHH-4/phi_finetuned"
12
+
13
+ print("\n" + "="*80)
14
+ print("πŸš€ LEGALDOCS AI - MODEL INITIALIZATION")
15
+ print("="*80)
16
+ print(f"πŸ“¦ Model: {MODEL_NAME}")
17
+ print(f"🐍 Python: {torch.__version__}")
18
+ print(f"πŸ”₯ PyTorch: {torch.__version__}")
19
+ print(f"πŸ€— Transformers: Loading...")
20
+ print("="*80 + "\n")
21
+
22
+ print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
23
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
24
+
25
+ if tokenizer.pad_token is None:
26
+ tokenizer.pad_token = tokenizer.eos_token
27
+
28
+ print("βœ… Tokenizer loaded successfully!")
29
+ print(f" - Vocab size: {tokenizer.vocab_size}")
30
+ print(f" - Model max length: {tokenizer.model_max_length}")
31
+ print(f" - Pad token: {tokenizer.pad_token}")
32
+
33
+ print("Loading YOUR model weights...")
34
+ # Optimized for 18GB RAM with 4-bit quantization
35
+ quantization_config = BitsAndBytesConfig(
36
+ load_in_4bit=True,
37
+ bnb_4bit_compute_dtype=torch.float16,
38
+ bnb_4bit_quant_type="nf4",
39
+ bnb_4bit_use_double_quant=True,
40
+ )
41
+
42
+ model = AutoModelForCausalLM.from_pretrained(
43
+ MODEL_NAME,
44
+ quantization_config=quantization_config,
45
+ device_map="auto",
46
+ low_cpu_mem_usage=True,
47
+ trust_remote_code=True,
48
+ torch_dtype=torch.float16,
49
+ )
50
+
51
+ print("βœ… Model loaded successfully!")
52
+ print(f" - Device: {model.device}")
53
+ print(f" - Model type: {type(model).__name__}")
54
+ print(f" - Quantization: 4-bit NF4")
55
+ print(f" - Compute dtype: float16")
56
+
57
+ # Memory info
58
+ if torch.cuda.is_available():
59
+ print(f" - GPU: {torch.cuda.get_device_name(0)}")
60
+ print(f" - GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
61
+ print(f" - GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
62
+ else:
63
+ print(f" - Running on CPU")
64
+
65
+ print("\n" + "="*80)
66
+ print("βœ… MODEL READY - Server starting...")
67
+ print("="*80 + "\n")
68
+
69
+
70
+ @app.route('/')
71
+ def index():
72
+ return send_from_directory('static', 'index.html')
73
+
74
+
75
+ @app.route('/api/generate', methods=['POST'])
76
+ def generate():
77
+ import time
78
+ try:
79
+ print("\n" + "="*80, flush=True)
80
+ print("πŸš€ NEW GENERATION REQUEST RECEIVED", flush=True)
81
+ print("="*80, flush=True)
82
+ sys.stdout.flush()
83
+
84
+ data = request.json
85
+
86
+ if not data or 'prompt' not in data:
87
+ print("❌ ERROR: Missing prompt in request body", flush=True)
88
+ sys.stdout.flush()
89
+ return jsonify({'error': 'Missing prompt in request body'}), 400
90
+
91
+ prompt = data['prompt']
92
+ max_new_tokens = data.get('max_new_tokens', 400)
93
+ temperature = data.get('temperature', 0.7)
94
+ top_p = data.get('top_p', 0.9)
95
+
96
+ print(f"\nπŸ“ REQUEST PARAMETERS:", flush=True)
97
+ print(f" - Prompt length: {len(prompt)} characters", flush=True)
98
+ print(f" - Prompt preview: {prompt[:200]}...", flush=True)
99
+ print(f" - Max new tokens: {max_new_tokens}", flush=True)
100
+ print(f" - Temperature: {temperature}", flush=True)
101
+ print(f" - Top P: {top_p}", flush=True)
102
+ sys.stdout.flush()
103
+
104
+ print(f"\nπŸ”„ TOKENIZING INPUT...", flush=True)
105
+ sys.stdout.flush()
106
+ tokenize_start = time.time()
107
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
108
+ tokenize_time = time.time() - tokenize_start
109
+ input_token_count = inputs['input_ids'].shape[1]
110
+ print(f" βœ… Tokenization complete in {tokenize_time:.2f}s", flush=True)
111
+ print(f" - Input tokens: {input_token_count}", flush=True)
112
+ print(f" - Device: {model.device}", flush=True)
113
+ sys.stdout.flush()
114
+
115
+ print(f"\n🧠 GENERATING TEXT WITH MODEL...", flush=True)
116
+ print(f" Model: {MODEL_NAME}", flush=True)
117
+ print(f" Status: Running inference...", flush=True)
118
+ sys.stdout.flush()
119
+ generation_start = time.time()
120
+
121
+ # Use controlled sampling optimized for Phi-3
122
+ with torch.no_grad():
123
+ torch.set_num_threads(2) # Use both CPU cores
124
+ outputs = model.generate(
125
+ **inputs,
126
+ max_new_tokens=400,
127
+ do_sample=False, # Phi-3 works better with greedy decoding
128
+ pad_token_id=tokenizer.pad_token_id,
129
+ eos_token_id=tokenizer.eos_token_id,
130
+ use_cache=True
131
+ )
132
+
133
+ generation_time = time.time() - generation_start
134
+ output_token_count = outputs.shape[1]
135
+ tokens_generated = output_token_count - input_token_count
136
+ tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0
137
+
138
+ print(f" βœ… Generation complete in {generation_time:.2f}s", flush=True)
139
+ print(f" - Output tokens: {output_token_count}", flush=True)
140
+ print(f" - New tokens generated: {tokens_generated}", flush=True)
141
+ print(f" - Speed: {tokens_per_second:.2f} tokens/second", flush=True)
142
+ sys.stdout.flush()
143
+
144
+ print(f"\nπŸ”„ DECODING OUTPUT...", flush=True)
145
+ sys.stdout.flush()
146
+ decode_start = time.time()
147
+ # Decode the full output
148
+ full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
149
+ decode_time = time.time() - decode_start
150
+ print(f" βœ… Decoding complete in {decode_time:.2f}s", flush=True)
151
+ sys.stdout.flush()
152
+
153
+ # Remove the prompt from the output to return only the generated text
154
+ generated_text = full_output[len(prompt):].strip()
155
+
156
+ print(f"\nπŸ“Š FINAL RESULTS:", flush=True)
157
+ print(f" - Generated text length: {len(generated_text)} characters", flush=True)
158
+ print(f" - Generated text preview: {generated_text[:200]}...", flush=True)
159
+ print(f" - Total processing time: {(time.time() - tokenize_start):.2f}s", flush=True)
160
+ sys.stdout.flush()
161
+
162
+ print(f"\nβœ… REQUEST COMPLETED SUCCESSFULLY", flush=True)
163
+ print("="*80 + "\n", flush=True)
164
+ sys.stdout.flush()
165
+
166
+ return jsonify({
167
+ 'generated_text': generated_text,
168
+ 'prompt': prompt
169
+ })
170
+
171
+ except Exception as e:
172
+ print(f"\n❌ ERROR DURING GENERATION:", flush=True)
173
+ print(f" Error type: {type(e).__name__}", flush=True)
174
+ print(f" Error message: {str(e)}", flush=True)
175
+ sys.stdout.flush()
176
+ import traceback
177
+ print(f" Traceback:\n{traceback.format_exc()}", flush=True)
178
+ print("="*80 + "\n", flush=True)
179
+ sys.stdout.flush()
180
+ return jsonify({'error': str(e)}), 500
181
+
182
+
183
+ @app.route('/api/health', methods=['GET'])
184
+ def health():
185
+ return jsonify({
186
+ 'status': 'ok',
187
+ 'model': MODEL_NAME,
188
+ 'device': str(model.device)
189
+ })
190
+
191
+
192
+ if __name__ == '__main__':
193
+ port = int(os.environ.get('PORT', 7860))
194
+ print(f"\n🌐 Starting Flask server on port {port}...")
195
+ print(f"πŸ”— Access the app at: http://localhost:{port}")
196
+ print(f"πŸ“Š Health check: http://localhost:{port}/api/health")
197
+ print(f"πŸš€ API endpoint: http://localhost:{port}/api/generate\n")
198
+ app.run(host='0.0.0.0', port=port, debug=False)
 
 
dataset.py ADDED
The diff for this file is too large to render. See raw diff