File size: 1,900 Bytes
b269ebb
b54a7c5
 
 
b269ebb
 
 
adb76ba
b54a7c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b269ebb
 
 
 
 
 
 
 
b54a7c5
 
b269ebb
 
 
b54a7c5
b269ebb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, TFAutoModel
import tensorflow as tf
import numpy as np

app = Flask(__name__)

# Window length and step, in tokens, used when splitting long inputs
# into overlapping pieces for the model.
MAX_LEN = 256
STRIDE = 128

# PhoBERT (TensorFlow weights) is loaded once at import time and reused by
# every request handled by this process.
MODEL_NAME = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = TFAutoModel.from_pretrained(MODEL_NAME)

def split_text_into_chunks(text, *, max_len=None, stride=None, tok=None):
    """Split *text* into overlapping, fixed-length windows of token ids.

    The text is tokenized once WITHOUT special tokens. The id sequence is then
    cut into windows starting every ``stride`` content tokens. Each window is
    wrapped with the tokenizer's own special tokens (for PhoBERT this is
    ``<s> ... </s>``), so every window is a well-formed model input and not
    only the first/last one. Each window is then right-padded with
    ``pad_token_id`` to exactly ``max_len`` ids.

    Args:
        text: Input string.
        max_len: Total window length INCLUDING special tokens.
            Defaults to the module-level ``MAX_LEN``.
        stride: Step between consecutive window starts, counted in content
            tokens. Defaults to the module-level ``STRIDE``. A stride larger
            than the window's content capacity skips tokens between windows.
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer``.

    Returns:
        A non-empty list of windows. Each window is a list of ``max_len``
        ints. Empty text yields a single window holding only the special
        tokens plus padding.

    Raises:
        ValueError: if ``stride`` is not positive, or ``max_len`` leaves no
            room for content once the special tokens are added.
    """
    tok = tokenizer if tok is None else tok
    max_len = MAX_LEN if max_len is None else max_len
    stride = STRIDE if stride is None else stride

    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    # Reserve room in every window for the special tokens added below.
    body_len = max_len - tok.num_special_tokens_to_add(pair=False)
    if body_len <= 0:
        raise ValueError(f"max_len={max_len} leaves no room for content tokens")

    ids = tok.encode(text, add_special_tokens=False)
    n_ids = len(ids)

    chunks = []
    start = 0
    while True:
        body = ids[start:start + body_len]
        chunk = list(tok.build_inputs_with_special_tokens(body))
        chunk += [tok.pad_token_id] * (max_len - len(chunk))
        chunks.append(chunk)
        # Stop once this window reaches the end of the sequence.
        if start + body_len >= n_ids:
            break
        start += stride
    return chunks

def embed_text(text):
    """Embed *text* as one vector and return it as a plain Python list.

    The text is cut into windows by ``split_text_into_chunks``. Each window is
    run through ``model`` on its own. Its ``last_hidden_state`` is averaged
    over the positions whose id differs from ``tokenizer.pad_token_id``.
    The per-window vectors are then averaged with equal weight. A short final
    window counts as much as a full one.
    """
    pad_id = tokenizer.pad_token_id
    pooled_per_chunk = []

    for ids in split_text_into_chunks(text):
        batch = tf.constant([ids])  # a batch holding a single window
        attn = tf.cast(batch != pad_id, tf.int32)
        hidden = model(input_ids=batch, attention_mask=attn).last_hidden_state

        # Zero out padded positions, then divide by the number of real tokens.
        weights = tf.cast(attn[..., tf.newaxis], tf.float32)
        totals = tf.reduce_sum(hidden * weights, axis=1)
        n_real = tf.reduce_sum(weights, axis=1)
        pooled_per_chunk.append((totals / n_real).numpy()[0])

    return np.mean(pooled_per_chunk, axis=0).tolist()

@app.route('/embed', methods=['POST'])
def embed():
    """Return the embedding of the request's ``text`` field as JSON.

    Expects a JSON object body such as ``{"text": "..."}``.

    Responds with ``{"embedding": [...]}`` on success. Responds with 400 and
    an ``error`` message when:
    - the body is not a JSON object (missing, malformed, or e.g. a list);
    - ``text`` is not a string;
    - ``text`` is empty.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so every bad-body case gets the same 400 response below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get('text', '')
    # A non-string value (number, list, ...) would otherwise crash the
    # tokenizer and surface as a 500.
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string"}), 400
    if not text:
        return jsonify({"error": "No text provided"}), 400

    embedding = embed_text(text)
    return jsonify({"embedding": embedding})

@app.route('/', methods=['GET'])
def index():
    """Liveness check: reply with a fixed plain-text message."""
    status_message = "PhoBERT vector API is running!"
    return status_message

if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")