raghss123123 committed on
Commit
48b2e5c
·
verified ·
1 Parent(s): f9680ae

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from flask import Flask, request, jsonify
6
+ import logging
7
+
8
# Logging setup: INFO level, with a logger named after this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application object that the route decorators below attach to
app = Flask(__name__)

# Embedding model settings (Qwen3-Embedding-4B, used for retrieval)
MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
EMBEDDING_DIM = 2560  # Max dimension for Qwen3-Embedding-4B
# Use the GPU when torch reports one, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
class EmbeddingModel:
    """Wrapper around the embedding checkpoint named by ``MODEL_NAME``.

    The tokenizer and model are loaded once in ``__init__``; :meth:`encode`
    turns texts into L2-normalised embedding vectors.
    """

    def __init__(self):
        """Load the tokenizer and model, move the model to ``DEVICE``, set eval mode."""
        logger.info(f"Loading {MODEL_NAME} on {DEVICE}")
        # Keep the device on the instance so ``encode`` does not read globals.
        self.device = DEVICE
        # The padding *side* is a tokenizer setting; the ``padding`` argument of
        # the tokenizer call only selects the strategy (True / "longest" / ...).
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
        self.model = AutoModel.from_pretrained(MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()
        logger.info("✅ Model loaded successfully")

    @staticmethod
    def _last_token_pool(hidden_states, attention_mask):
        """Return the hidden state of the last non-padding token of each row.

        The index is derived from ``attention_mask`` (last position whose mask
        value is 1), so it is correct for both left- and right-padded batches.
        A row whose mask is all zeros falls back to the final position.
        """
        seq_len = attention_mask.shape[1]
        # argmax on the reversed mask gives the distance of the last 1 from the end.
        last_idx = seq_len - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)
        rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[rows, last_idx.to(hidden_states.device)]

    def encode(self, texts, batch_size=16):
        """Encode texts to embeddings using Qwen3-Embedding-4B.

        Args:
            texts: A single string or a list of strings.
            batch_size: Number of texts sent through the model per forward pass.

        Returns:
            A list with one 1-D ``numpy`` array per input text, each scaled to
            unit L2 norm. An empty input list yields an empty list.
        """
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []

        for start in range(0, len(texts), batch_size):
            # Qwen3 instruction format for retrieval
            prompts = [
                f"Instruct: Retrieve semantically similar text.\nQuery: {text}"
                for text in texts[start:start + batch_size]
            ]

            inputs = self.tokenizer(
                prompts,
                padding=True,  # pad to the longest prompt in the batch
                truncation=True,
                max_length=32768,  # Qwen3 supports up to 32k context
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                pooled = self._last_token_pool(
                    outputs.last_hidden_state, inputs["attention_mask"]
                )
                # ``normalize`` clamps the norm with an epsilon, so an all-zero
                # vector does not produce NaNs; ``float()`` keeps the numpy
                # conversion valid for half-precision model outputs.
                pooled = torch.nn.functional.normalize(pooled.float(), p=2, dim=1)

            # Extending with a 2-D array appends one row (vector) per text.
            embeddings.extend(pooled.cpu().numpy())

        return embeddings
68
+
69
# Process-wide model holder; stays None until get_model() is first called
embedding_model = None


def get_model():
    """Return the shared EmbeddingModel, constructing it on the first call."""
    global embedding_model
    if embedding_model is not None:
        return embedding_model
    embedding_model = EmbeddingModel()
    return embedding_model
77
+
78
@app.route("/", methods=["GET"])
def health_check():
    """Report a fixed "healthy" status plus the model configuration as JSON."""
    payload = dict(
        status="healthy",
        model=MODEL_NAME,
        device=DEVICE,
        embedding_dim=EMBEDDING_DIM,
        max_context=32768,
    )
    return jsonify(payload)
87
+
88
@app.route("/embed", methods=["POST"])
def embed_texts():
    """Embed texts and return embeddings.

    Expects a JSON object with a ``texts`` field holding a string or a list of
    strings. Responds with the embeddings, model name, vector dimension and
    count. Invalid requests get a 400; failures while encoding get a 500.
    """
    # silent=True makes a malformed or non-JSON body yield None instead of
    # raising, so client mistakes are reported as 400 rather than 500.
    data = request.get_json(silent=True)

    # Require a JSON object: indexing a list/string body would raise TypeError.
    if not isinstance(data, dict) or "texts" not in data:
        return jsonify({"error": "Missing 'texts' field"}), 400

    texts = data["texts"]
    if not isinstance(texts, list):
        texts = [texts]

    # Reject non-string items up front instead of failing inside the tokenizer.
    if not all(isinstance(text, str) for text in texts):
        return jsonify({"error": "'texts' must be a string or a list of strings"}), 400

    logger.info(f"Embedding {len(texts)} texts")

    try:
        model = get_model()
        embeddings = model.encode(texts)

        return jsonify({
            "embeddings": [embedding.tolist() for embedding in embeddings],
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]) if embeddings else 0,
            "count": len(embeddings)
        })

    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500
116
+
117
@app.route("/embed_single", methods=["POST"])
def embed_single():
    """Embed single text (convenience endpoint).

    Expects a JSON object with a string ``text`` field. Responds with the
    embedding, model name, vector dimension and the input length. Invalid
    requests get a 400; failures while encoding get a 500.
    """
    # silent=True makes a malformed or non-JSON body yield None instead of
    # raising, so client mistakes are reported as 400 rather than 500.
    data = request.get_json(silent=True)

    # Require a JSON object: indexing a list/string body would raise TypeError.
    if not isinstance(data, dict) or "text" not in data:
        return jsonify({"error": "Missing 'text' field"}), 400

    text = data["text"]
    # Slicing and len() below assume a string; reject anything else as a 400.
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string"}), 400

    logger.info(f"Embedding single text: {text[:100]}...")

    try:
        model = get_model()
        embeddings = model.encode([text])

        return jsonify({
            "embedding": embeddings[0].tolist(),
            "model": MODEL_NAME,
            "dimension": len(embeddings[0]),
            "text_length": len(text)
        })

    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Single embedding error: {str(e)}")
        return jsonify({"error": str(e)}), 500
142
+
143
if __name__ == "__main__":
    # Load the model before serving so the first request does not pay for it
    logger.info("🚀 Starting embedding service...")
    get_model()
    logger.info("✅ Service ready!")

    # PORT comes from the environment, defaulting to 7860; listen on all interfaces
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))