File size: 2,982 Bytes
b9b96cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bd5de9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from flask import Flask, render_template, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
from scipy.stats import percentileofscore

app = Flask(__name__)

# First entry of the model list handed to index.html.
DEFAULT_MODEL = "gpt2"

# Loaded models and tokenizers, keyed by model name and filled lazily by
# get_model_and_tokenizer(); nothing in this file ever evicts an entry.
model_cache, tokenizer_cache = {}, {}


def get_model_and_tokenizer(model_name):
    """Return the ``(model, tokenizer)`` pair for ``model_name``, loading it once.

    On first use both objects are created with ``from_pretrained`` and stored
    in the module-level ``model_cache`` / ``tokenizer_cache`` dicts; later
    calls return the cached objects. ``trust_remote_code`` is enabled only for
    ``microsoft/phi-1_5``.

    Args:
        model_name: Hugging Face model identifier (or local path) accepted by
            ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Whatever ``from_pretrained`` raises when the model or tokenizer cannot
        be loaded. In that case neither cache is modified, so a later call
        simply retries.
    """
    if model_name not in model_cache or model_name not in tokenizer_cache:
        trust_code = model_name == "microsoft/phi-1_5"
        # Load both before caching either: caching the model first would let
        # a tokenizer failure leave a half-populated entry, after which every
        # lookup for this name would raise KeyError instead of reloading.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        model_cache[model_name] = model
        tokenizer_cache[model_name] = tokenizer
    return model_cache[model_name], tokenizer_cache[model_name]


@app.route("/")
def index():
    """Render ``index.html`` with the list of selectable model names."""
    # Further candidates kept disabled: gpt2-medium, gpt2-large, gpt2-xl,
    # EleutherAI/pythia-1.4b, facebook/opt-1.3b, bigscience/bloom-1b7,
    # microsoft/phi-1_5.
    offered_models = [DEFAULT_MODEL, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
    return render_template("index.html", models=offered_models)


def _rank_percentiles(distribution, scores):
    """Return the percentile rank of each score within ``distribution``.

    Vectorised equivalent of calling
    ``scipy.stats.percentileofscore(distribution.flatten().tolist(), s)``
    (default ``kind="rank"``) for every ``s`` in ``scores``. The flattened
    distribution is sorted once and each score is located with two binary
    searches, instead of a full linear scan per score.

    Args:
        distribution: Tensor of any shape whose elements form the reference
            population; must have the same dtype as ``scores``.
        scores: 1-D tensor of values to rank.

    Returns:
        A list of Python floats in ``[0, 100]``, one per score; empty when
        there are no scores or the distribution is empty.
    """
    flat = distribution.reshape(-1)
    n = flat.numel()
    if n == 0 or scores.numel() == 0:
        return []
    ordered = torch.sort(flat).values
    below = torch.searchsorted(ordered, scores)  # count of elements < score
    at_or_below = torch.searchsorted(ordered, scores, right=True)  # count <= score
    # scipy's "rank" kind adds 1 when the score itself occurs in the data.
    bump = (below < at_or_below).long()
    return ((below + at_or_below + bump).double() * (50.0 / n)).tolist()


def _token_statistics(logits, input_ids, tokenizer, top_k=5):
    """Score each next-token prediction for a single sequence.

    Position ``i`` of ``logits`` is scored against token ``i + 1`` of
    ``input_ids``, so the last position is dropped and the first token
    receives no score.

    Args:
        logits: ``(seq_len, vocab)`` logits tensor for one sequence.
        input_ids: ``(seq_len,)`` tensor of the token ids fed to the model.
        tokenizer: Object whose ``convert_ids_to_tokens`` maps ids to strings.
        top_k: Number of highest-probability alternatives reported per position.

    Returns:
        ``(log_probs, percentiles, top_k_predictions)``, each a list with
        ``seq_len - 1`` entries.
    """
    step_log_probs = F.log_softmax(logits[:-1], dim=-1)  # (seq_len - 1, vocab)
    targets = input_ids[1:].unsqueeze(-1)
    actual = step_log_probs.gather(-1, targets).squeeze(-1)

    k = min(top_k, step_log_probs.shape[-1])
    top_values, top_indices = torch.topk(step_log_probs, k, dim=-1)
    top_k_predictions = [
        [
            {"token": tok, "log_prob": val}
            for tok, val in zip(tokenizer.convert_ids_to_tokens(ids), vals)
        ]
        for ids, vals in zip(top_indices.tolist(), top_values.tolist())
    ]

    # Each actual token is ranked against every vocabulary log-probability
    # at every scored position, matching the original per-token scipy call.
    percentiles = _rank_percentiles(step_log_probs, actual)
    return actual.tolist(), percentiles, top_k_predictions


@app.route("/analyze", methods=["POST"])
def analyze():
    """Score every token of the posted text under the requested model.

    Expects a JSON object ``{"text": str, "model": str}``; ``model`` falls
    back to ``DEFAULT_MODEL`` when omitted. The response contains the tokens,
    the log-probability of each token after the first, that value's
    percentile rank among all vocabulary log-probabilities at all scored
    positions, the top-5 alternatives per position, and the joint / average
    log-likelihood. Malformed requests get HTTP 400 with an ``error`` message.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object body"}), 400
    text = data.get("text")
    model_name = data.get("model", DEFAULT_MODEL)
    if not isinstance(text, str) or not isinstance(model_name, str):
        return jsonify({"error": "'text' and 'model' must be strings"}), 400

    # NOTE(security): model_name is client-controlled and is passed straight
    # to from_pretrained, so any Hub repository can be downloaded and cached
    # (unbounded memory/disk). Consider restricting it to the models offered
    # by index().
    model, tokenizer = get_model_and_tokenizer(model_name)
    model.eval()

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    if len(input_ids) < 2:
        # No token has a predecessor to be scored from; skip the forward pass
        # (this also avoids feeding the model an empty sequence).
        log_probs, percentiles, top_k_predictions = [], [], []
    else:
        with torch.no_grad():
            logits = model(**inputs).logits[0]
            log_probs, percentiles, top_k_predictions = _token_statistics(
                logits, input_ids, tokenizer
            )

    joint_log_likelihood = sum(log_probs)
    average_log_likelihood = (
        joint_log_likelihood / len(log_probs) if log_probs else 0
    )

    return jsonify({
        "tokens": tokens,
        "percentiles": percentiles,
        "log_probs": log_probs,
        "top_k_predictions": top_k_predictions,
        "joint_log_likelihood": joint_log_likelihood,
        "average_log_likelihood": average_log_likelihood,
    })

if __name__ == "__main__":
    # Serve on every network interface, port 7860.
    app.run(port=7860, host="0.0.0.0")