| import logging
|
| import torch
|
| import json
|
| import os
|
| from flask import Flask, render_template, request, Response, stream_with_context
|
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
|
| from threading import Thread
|
| import random
|
| import numpy as np
|
|
|
# Runtime configuration: logging, Flask app, model id, device and RNG seeds.
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)

# Hugging Face Hub id of the chat model served by this app.
MODEL_NAME = "Mattimax/DATA-AI_Chat_3_0.5B"

# Prefer the GPU when available; the code below branches on this flag.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG in use (Python, NumPy, Torch) so sampling is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda":
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
|
|
|
|
|
# Quantization settings: load the model in 4-bit with fp16 compute on GPU;
# no quantization config on CPU.
bnb_config = (
    BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    if device == "cuda"
    else None
)
|
|
|
logging.info("Caricamento tokenizer e modello: %s (device=%s)", MODEL_NAME, device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Try to read the model's own chat template from its tokenizer config on the
# Hub; chat_template stays None if the lookup fails or none is published.
chat_template = None
try:
    from transformers.models.auto.tokenization_auto import get_tokenizer_config
    config_dict = get_tokenizer_config(MODEL_NAME)
    chat_template = config_dict.get("chat_template")
    logging.info("Chat template caricato: %s", chat_template[:100] if chat_template else "Non disponibile")
except Exception as e:
    logging.warning("Impossibile caricare chat_template: %s", e)

if not chat_template:
    # Fallback template, applied ONLY when the repo ships no chat template,
    # so a successfully fetched template is never overwritten.
    chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% else %}{{ message['role'] }}: {{ message['content'] }}\n{% endif %}{% endfor %}"

if tokenizer.pad_token_id is None:
    # Models without a dedicated pad token conventionally reuse EOS for padding.
    tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
|
# Load the model: 4-bit quantized with automatic device placement on GPU,
# full precision pinned to the CPU otherwise.
if device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
    )
else:
    # nn.Module.to returns the module itself, so chaining is equivalent to
    # a separate .to("cpu") call.
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cpu")

# Inference only: disable dropout / training-mode behavior.
model.eval()
|
|
|
|
|
# System prompt (Italian) prepended to every conversation; defines the
# assistant persona "DAC". Runtime text — do not translate.
SYSTEM_PROMPT = """Tu sei DAC, un assistente intelligente e amichevole. Rispondi in modo coerente, chiaro e utile.
Se non conosci la risposta, ammettilo con sincerità. Mantieni il tono professionale ma accessibile."""
|
|
|
|
|
@app.route('/')
def index():
    """Serve the chat front-end page."""
    return render_template('index.html')
|
|
|
|
|
@app.route('/chat', methods=['POST'])
def chat():
    """Stream a model reply to the posted message as Server-Sent Events.

    Expects a JSON body ``{"message": "..."}``. Returns an SSE stream whose
    events each carry ``{"token": "<text piece>"}``, or a JSON 400 error
    when the message is missing or empty.
    """
    # silent=True: a missing or non-JSON body yields None instead of Flask
    # raising a 400/415, so the client always gets the uniform error below.
    data = request.get_json(silent=True) or {}
    user_input = data.get("message", "")
    if not user_input:
        return Response(
            json.dumps({"error": "empty message"}),
            status=400,
            mimetype="application/json",
        )

    prompt_text = _build_prompt(user_input)
    logging.info("Prompt generato: %s", prompt_text[:200])

    # skip_prompt=True: only newly generated tokens are streamed back.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    inputs = tokenizer(prompt_text, return_tensors="pt")
    if device == "cuda":
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # NOTE: early_stopping is omitted — it only applies to beam search and
    # emits a warning when combined with sampling.
    generation_kwargs = dict(
        input_ids=inputs.get("input_ids"),
        attention_mask=inputs.get("attention_mask"),
        streamer=streamer,
        max_new_tokens=2048,
        temperature=0.5,
        do_sample=True,
        top_p=0.80,
        top_k=40,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=4,
    )

    def run_generate():
        # Generation runs on a daemon thread; the streamer hands tokens to
        # the response generator below as they are produced.
        try:
            with torch.no_grad():
                model.generate(**generation_kwargs)
        except Exception:
            logging.exception("Errore durante la generazione:")

    thread = Thread(target=run_generate, daemon=True)
    thread.start()

    def generate():
        # Consume the streamer and re-emit each fragment as an SSE event.
        try:
            for new_text in streamer:
                yield f"data: {json.dumps({'token': new_text})}\n\n"
        except GeneratorExit:
            logging.info("Client disconnected dalla stream")
        except Exception:
            logging.exception("Errore nello stream")

    headers = {
        'Cache-Control': 'no-cache',
        # Disable reverse-proxy (nginx) buffering so tokens flush immediately.
        'X-Accel-Buffering': 'no'
    }
    return Response(stream_with_context(generate()), mimetype='text/event-stream', headers=headers)


def _build_prompt(user_input):
    """Render the full prompt, preferring the tokenizer's chat template.

    Falls back to a plain "System:/User:/Assistant:" layout when no template
    is available or applying it fails.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_input}
    ]
    if chat_template and hasattr(tokenizer, 'apply_chat_template'):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception as e:
            logging.warning("Errore applicando chat_template: %s, fallback a prompt semplice", e)
    return f"System: {SYSTEM_PROMPT}\nUser: {user_input}\nAssistant:"
|
|
|
|
|
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the port commonly used by HF Spaces
    # demos — presumably this app is deployed there (TODO confirm).
    logging.info("Avvio app su 0.0.0.0:7860")
    # threaded=True lets the SSE stream and other requests run concurrently.
    app.run(host='0.0.0.0', port=7860, threaded=True)