# Benerin chatbot API — Flask service wrapping a Keras seq2seq model.
# (The original "Spaces: / Sleeping / Sleeping" lines were Hugging Face
#  Space page-scrape residue, not code, and were removed.)
# Standard library
import json
import random
import re

# Third-party
import numpy as np
import pandas as pd
import sentencepiece as spm
import tensorflow as tf
from flask import Flask, request, jsonify
from flask_cors import CORS
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json
# ========== Global Variables ==========
# All populated lazily by load_assets() on the first request; None until then.
model = None              # Keras seq2seq chatbot model (loaded from .keras file)
encoder_tokenizer = None  # Keras tokenizer for encoder-side input text
sp = None                 # SentencePiece processor for decoder-side tokens
vectorizer = None         # TF-IDF vectorizer fit on the dataset instructions
custom_vectors = None     # TF-IDF matrix of the dataset instructions
custom = None             # pandas DataFrame of the curated instruction/response dataset
# ========== Util Functions ==========
def masked_loss(y_true, y_pred):
    """Sparse categorical cross-entropy averaged over non-padding targets.

    Tokens equal to 0 are treated as padding and excluded from both the
    loss sum and the normalizer.
    """
    per_token = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    valid = tf.cast(tf.not_equal(y_true, 0), tf.float32)
    return tf.reduce_sum(per_token * valid) / tf.reduce_sum(valid)
def clean_text(text):
    """Normalize raw instruction text.

    Strips template placeholders, double-quoted spans, and URLs, drops any
    character outside the allowed alphanumeric/punctuation set, collapses
    runs of dots, then trims and lowercases the result.
    """
    substitutions = (
        (r'\{\{.*?\}\}', ''),          # {{template}} placeholders
        (r'\".*?\"', ''),              # double-quoted spans
        (r'https?://\S+', ''),         # URLs
        (r'[^a-zA-Z0-9\s.,!?]', ''),   # symbols outside the allowed set
        (r'\.{2,}', '.'),              # ellipses -> single dot
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip().lower()
def postprocess(text):
    """Tidy model output for display.

    Collapses whitespace, removes stray symbols, re-attaches punctuation
    to the preceding word, and capitalizes the first character.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    cleaned = re.sub(r'[^\w\s.,!?]', '', collapsed)
    tightened = re.sub(r'\s([.,!?])', r'\1', cleaned)
    return tightened.strip().capitalize()
def load_assets():
    """Lazily load the model, tokenizers, dataset and TF-IDF index.

    Idempotent: returns immediately once `model` is populated, so it is
    safe to call at the top of every request handler.
    """
    global model, encoder_tokenizer, sp, custom, vectorizer, custom_vectors
    if model is not None:
        return  # already initialized
    print("Loading model and assets...")
    # Seq2seq model with its custom training loss registered.
    model = tf.keras.models.load_model(
        "model/chatbot_benerin_v6_4.keras",
        custom_objects={"masked_loss": masked_loss},
    )
    with open("model/encoder_tokenizer.json", "r", encoding="utf-8") as f:
        encoder_tokenizer = tokenizer_from_json(f.read())
    sp = spm.SentencePieceProcessor()
    sp.load("model/spm_tokenizer.model")
    # Curated dataset used for TF-IDF fallback responses.
    custom = pd.read_csv("model/benerin_electronics_dataset_v6.csv")
    custom['instruction'] = custom['instruction'].apply(clean_text)
    custom['response'] = custom['response'].apply(
        lambda x: re.sub(r'\s+', ' ', str(x)).strip()
    )
    vectorizer = TfidfVectorizer()
    custom_vectors = vectorizer.fit_transform(custom['instruction'].tolist())
def nucleus_sampling_with_fallback(input_text, model, top_p=0.9, temperature=0.7, max_len=80, rerank_k=3):
    """Generate `rerank_k` candidate replies with nucleus (top-p) sampling,
    return the one most similar (TF-IDF cosine) to the input, or fall back
    to the closest curated dataset response when similarity is too low.

    Args:
        input_text: raw user instruction string.
        model: Keras seq2seq model taking [encoder_seq, decoder_seq].
        top_p: cumulative-probability cutoff defining the sampling nucleus.
        temperature: softmax temperature (<1 sharpens the distribution).
        max_len: maximum number of tokens generated per candidate.
        rerank_k: number of candidates to generate and rerank.

    Returns:
        Best candidate string, or a dataset fallback response (capitalized).
    """
    cleaned_input = clean_text(input_text)
    seq = encoder_tokenizer.texts_to_sequences([cleaned_input])
    seq = pad_sequences(seq, maxlen=64, padding='post')
    candidates = []
    for _ in range(rerank_k):
        generated = [sp.bos_id()]
        for _ in range(max_len):
            dec_input = pad_sequences([generated], maxlen=80, padding='post')
            preds = model.predict([seq, dec_input], verbose=0)[0, len(generated) - 1]
            # Temperature-scaled softmax over the vocabulary.
            preds = np.log(preds + 1e-8) / temperature
            preds = np.exp(preds) / np.sum(np.exp(preds))
            # Build the nucleus: smallest set of top tokens whose mass >= top_p.
            sorted_idx = np.argsort(preds)[::-1]
            cumulative_prob, filtered = 0.0, []
            for idx in sorted_idx:
                cumulative_prob += preds[idx]
                filtered.append(idx)
                if cumulative_prob >= top_p:
                    break
            if not filtered:
                break
            # BUG FIX: sample proportionally to the renormalized nucleus
            # probabilities. The original `random.choice(filtered)` drew
            # UNIFORMLY from the nucleus, ignoring the model's distribution.
            nucleus_probs = preds[filtered]
            nucleus_probs = nucleus_probs / nucleus_probs.sum()
            sampled_token = int(np.random.choice(filtered, p=nucleus_probs))
            if sampled_token == sp.eos_id():
                break
            generated.append(sampled_token)
        # Strip special tokens before decoding back to text.
        decoded = sp.decode_ids([int(t) for t in generated
                                 if t not in (sp.bos_id(), sp.eos_id(), sp.pad_id())])
        candidates.append(postprocess(decoded))
    # Rerank candidates by TF-IDF cosine similarity to the cleaned input.
    input_vec = vectorizer.transform([cleaned_input])
    cand_vecs = vectorizer.transform(candidates)
    sims = cosine_similarity(input_vec, cand_vecs)[0]
    best_score = max(sims)
    best_resp = candidates[int(np.argmax(sims))]
    if best_score < 0.3:
        # Low confidence: answer with the closest curated dataset response.
        fallback_idx = int(np.argmax(cosine_similarity(input_vec, custom_vectors)[0]))
        return custom.iloc[fallback_idx]['response'].capitalize()
    return best_resp
# ========== Flask App ==========
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the web frontend


# BUG FIX: the view function was never registered with a route, so the
# endpoint was unreachable. NOTE(review): the "/chat" path is assumed —
# confirm it against the frontend/client code.
@app.route("/chat", methods=["POST"])
def chat():
    """POST endpoint: JSON {"instruction": "..."} -> {"response": "..."}."""
    load_assets()  # lazy one-time initialization of model/tokenizers/dataset
    # silent=True tolerates a missing or non-JSON body so the 400 branch
    # below handles it instead of Flask raising before we get here.
    data = request.get_json(silent=True) or {}
    instr_text = data.get("instruction", "")
    if not instr_text:
        return jsonify({"error": "No instruction provided"}), 400
    try:
        response = nucleus_sampling_with_fallback(instr_text, model)
        return jsonify({"response": response})
    except Exception as e:
        # Top-level service boundary: surface failures as a 500 JSON payload.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Listen on all interfaces; port 7860 is the Hugging Face Spaces default.
    app.run(debug=False, port=7860, host="0.0.0.0")