import os
from fastapi import FastAPI
from typing import Union, List, Dict, Tuple, Optional
from pydantic import BaseModel, Field
from angle_emb import AnglE

class EmbeddingInput(BaseModel):
    input: Union[List[str], Tuple[str, ...], List[Dict], str] = Field(..., description="The input to be encoded")
    model: Optional[str] = None
    encoding_format: Optional[str] = 'float'
    dimensions: Optional[int] = None
    user: Optional[str] = None

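# Example request body accepted by this schema (an illustrative sketch; only
# "input" is required, the other fields mirror OpenAI's embeddings request,
# and only "dimensions" is used by the handler below):
#   {"input": "Your text string goes here", "dimensions": 512}
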
app = FastAPI()

# Get the model name and path from the environment variables
model_name = os.getenv('MODEL_NAME', default='WhereIsAI/UAE-Large-V1')
model_path = os.getenv('MODEL_PATH', default='models/WhereIsAI/UAE-Large-V1')

# Load the model
try:
    angle_model = AnglE.from_pretrained(model_path, pooling_strategy='cls').to('cpu')
except Exception as e:
    print(f"Failed to load model from path {model_path}. Error: {str(e)}")
    # Re-raise so the service fails fast instead of starting without a loaded model
    raise

@app.get("/")
def read_root():
    return {
        "model_name": model_name,
        "model_path": model_path,
        "message": "Model is up and running",
        "route_info": {
            "/": "Returns the model info",
            "/health": "Returns the health status of the application",
            "/v1/embeddings": 'POST route to get embeddings. Usage: curl -H "Content-Type: application/json" -d \'{ "input": "Your text string goes here" }\' http://localhost:8080/v1/embeddings'
        }
    }

@app.get("/health")
def health_check():
    return {"health": "ok"}

@app.post("/v1/embeddings")
def get_embeddings(embedding_input: EmbeddingInput):
    # Return an empty result if the input is an empty string
    if isinstance(embedding_input.input, str) and not embedding_input.input.strip():
        return {
            "object": "list",
            "data": [],
            "model": model_name,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # Encode the input text using the model, passing the requested number of
    # dimensions through to AnglE's embedding_size parameter
    embeddings = angle_model.encode(embedding_input.input, embedding_size=embedding_input.dimensions)

    # Create a response format compatible with OpenAI's API
    response = {
        "object": "list",
        "data": [
            {"object": "embedding", "index": i, "embedding": emb.tolist()}
            for i, emb in enumerate(embeddings)
        ],
        "model": model_name,
        "usage": {"prompt_tokens": len(embedding_input.input), "total_tokens": len(embedding_input.input)},
    }

    return response

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)
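
# Example client call (a minimal sketch, assuming the server is running on
# localhost:8080 and the optional 'requests' package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8080/v1/embeddings",
#       json={"input": "Your text string goes here"},
#   )
#   print(resp.json()["data"][0]["embedding"][:8])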