| | import torch |
| | import re |
| | from html import unescape |
| | from transformers import GPT2LMHeadModel, GPT2Tokenizer |
| | from peft import PeftModel |
| | from transformers import StoppingCriteria, StoppingCriteriaList |
| | from difflib import SequenceMatcher |
| | from flask import Flask, request, jsonify |
| |
|
| | |
| | |
| | |
# Pick the compute device once at startup: prefer CUDA when available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🚀 Running on device: {device}")
| |
|
| | |
| | |
| | |
# Model artifacts (config, weights, tokenizer, adapter) live in the CWD.
model_path = "./"
try:
    # Load the GPT-2 tokenizer from the local directory.
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    # GPT-2 ships without a pad token; reuse EOS so padding works in generate().
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Tokenizer loaded successfully")
except Exception as e:
    # The server is useless without a tokenizer — abort startup.
    print(f"❌ Error loading tokenizer: {e}")
    exit()
| |
|
| | |
| | |
| | |
# Optional 4-bit (NF4) quantization config — GPU only, and only when the
# bitsandbytes package is importable; otherwise stays None and the model
# loads at full precision.
quant_config = None
if torch.cuda.is_available():
    try:
        from transformers import BitsAndBytesConfig
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,   # second quantization pass on the quant constants
            bnb_4bit_quant_type="nf4",        # NormalFloat4 — recommended for LLM weights
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        print("✅ Using 4-bit quantization (GPU mode)")
    except Exception as e:
        # Best-effort: fall back to unquantized weights rather than abort.
        print("⚠️ BitsAndBytes not available, continuing without quantization:", e)
else:
    print("💡 CPU mode — quantization disabled")
| |
|
# ── Base model ──────────────────────────────────────────────────────────────
# Load GPT-2 weights, 4-bit quantized when quant_config was built above.
try:
    base_model = GPT2LMHeadModel.from_pretrained(
        model_path,
        quantization_config=quant_config,
        # device_map places weights on GPU 0 directly during loading.
        device_map={"": 0} if torch.cuda.is_available() else None,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    # FIX: the original unconditionally chained `.to(device)`. Calling .to()
    # on a bitsandbytes-quantized model raises a ValueError, and with
    # device_map set the weights are already on the GPU. Only move the model
    # manually when no quantization is in play (then it is a cheap no-op on
    # CPU and a safe explicit move otherwise).
    if quant_config is None:
        base_model = base_model.to(device)
    print("✅ Base model loaded successfully")
except Exception as e:
    # No model, no server — abort startup.
    print(f"❌ Error loading base model: {e}")
    exit()
| |
|
| | |
| | |
| | |
# ── LoRA / PEFT adapter ─────────────────────────────────────────────────────
# Attach the fine-tuned adapter on top of the base model; fall back to the
# bare base model if the adapter files are missing or incompatible.
try:
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,   # inference only
        device_map={"": 0} if torch.cuda.is_available() else None
    )
    # FIX: the original always called model.to(device), which raises on a
    # 4-bit quantized base model (.to() is unsupported by bitsandbytes).
    # The adapter already lives wherever the base weights were placed, so
    # only move explicitly in the unquantized case.
    if quant_config is None:
        model.to(device)
    print("✅ PEFT model loaded successfully")
except Exception as e:
    # Best-effort: degrade to the base model rather than abort.
    print(f"⚠️ Warning: Failed to load PEFT adapter, using base model. ({e})")
    model = base_model
| |
|
| | |
| | |
| | |
# System prompt prepended to every request (see generate_response).
# FIX: corrected the misspelling "cooherent" -> "coherent" in the prompt text.
system_prompt = """You are GPT-A, a friendly AI assistant made by LuxAI.
You must answer very short and coherent."""
| |
|
| | |
| | |
| | |
class CustomStoppingCriteria(StoppingCriteria):
    """Stop generation at the configured stop token or after 512 tokens."""

    def __init__(self, stop_token_id):
        # Token id (typically EOS) that terminates generation.
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        sequence = input_ids[0]
        hit_stop_token = sequence[-1] == self.stop_token_id
        return hit_stop_token or len(sequence) > 512


# Shared criteria list used by every generate() call: stop on EOS.
stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(tokenizer.eos_token_id)])
| |
|
| | |
| | |
| | |
def clean_response(text):
    """Strip HTML tags, HTML entities, Markdown markers and extra whitespace."""
    before = text
    without_tags = re.sub(r"<[^>]+>", " ", text)       # drop HTML tags
    decoded = unescape(without_tags)                   # &amp; -> & etc.
    without_markdown = re.sub(r"[*#`_~]+", "", decoded)
    result = re.sub(r"\s+", " ", without_markdown).strip()
    if result != before:
        print("🧹 Cleaned response.")
    return result
| |
|
| |
|
def remove_repetitions(text, similarity_threshold=0.8):
    """Drop sentences that are near-duplicates of the previously kept one."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) <= 1:
        return text
    kept = []
    for raw in sentences:
        candidate = raw.strip()
        if not candidate:
            continue
        # Only compare against the most recently kept sentence.
        is_repeat = bool(kept) and SequenceMatcher(
            None, candidate, kept[-1]
        ).ratio() >= similarity_threshold
        if not is_repeat:
            kept.append(candidate)
    return " ".join(kept)
| |
|
| |
|
def truncate_to_last_sentence(text):
    """Cut off any trailing fragment after the last completed sentence."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    # Scan backwards for the last sentence ending in ., ! or ?.
    for index in reversed(range(len(sentences))):
        if re.search(r'[.!?]$', sentences[index].strip()):
            return " ".join(sentences[:index + 1]).strip()
    # No complete sentence found — return the input as-is (trimmed).
    return text.strip()
| |
|
| | |
| | |
| | |
def generate_response(
    user_input,
    max_length=2048,
    temperature=0.7,
    top_k=50,
    top_p=0.7,
    repetition_penalty=1.2,
    num_beams=4,
    early_stopping=True,
    do_sample=True
):
    """Generate a cleaned assistant reply for *user_input*.

    Builds the chat prompt from the module-level system_prompt, runs
    model.generate(), takes the text after the final "Assistant:" marker
    and post-processes it (clean_response -> remove_repetitions ->
    truncate_to_last_sentence).

    Returns the reply string, or None when generation fails.
    """
    try:
        prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        print(f"📥 Input on device: {inputs['input_ids'].device}")

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # NOTE(review): the module's CustomStoppingCriteria caps
                # sequences at 512 tokens and GPT-2's context is 1024, so
                # values of max_length beyond that are never reached.
                max_length=max_length,
                # Sampling knobs are neutralized when do_sample is False.
                temperature=temperature if do_sample else 1.0,
                top_k=top_k if do_sample else None,
                top_p=top_p if do_sample else None,
                # FIX: default was 10.0, which penalizes every previously
                # generated token so hard that output degenerates; HF docs
                # describe 1.0 as "no penalty" with useful values ~1.1-1.5.
                repetition_penalty=repetition_penalty,
                num_beams=num_beams,
                early_stopping=early_stopping if num_beams > 1 else False,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=do_sample,
                stopping_criteria=stopping_criteria
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Everything after the last "Assistant:" marker is the new reply.
        response = generated_text.split("Assistant:")[-1].strip()

        response = clean_response(response)
        response = remove_repetitions(response)
        response = truncate_to_last_sentence(response)

        return response

    except Exception as e:
        # The HTTP layer maps None to a 500 response.
        print(f"❌ Error during generation: {e}")
        return None
| |
|
| | |
| | |
| | |
app = Flask(__name__)


@app.route('/generate', methods=['POST'])
def generate_text():
    """POST /generate — JSON body {"user_input": str} -> {"response": str}."""
    payload = request.get_json()

    # Reject requests with no JSON body or without the required key.
    if not payload or 'user_input' not in payload:
        return jsonify({'error': 'Missing user_input parameter'}), 400

    reply = generate_response(payload['user_input'])
    if reply is None:
        return jsonify({'error': 'Failed to generate response'}), 500

    return jsonify({'response': reply})
| |
|
| | |
| | |
| | |
# Entry point: serve the API on all interfaces, port 7860 (Flask dev server).
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
| |
|