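"""Minimal FastAPI service exposing text generation from the
Artples/L-MChat-7b model (served on CPU) via one POST endpoint."""
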
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Request body schema: the prompt to complete
class UserRequest(BaseModel):
    prompt: str

app = FastAPI()

# Load the model and tokenizer once at startup (a 7B model, so this
# can take a while and needs substantial RAM on CPU)
model_name = "Artples/L-MChat-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Run inference on CPU
device = torch.device("cpu")
model.to(device)

@app.post("/generate/")
def generate(request: UserRequest):
    # A plain `def` endpoint runs in FastAPI's thread pool, so the
    # blocking model.generate() call does not stall the event loop.
    try:
        # Tokenize the prompt (returns input_ids plus an attention mask)
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)

        # Generate a response; max_new_tokens bounds the completion
        # length regardless of how long the prompt is
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=100,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
            )
        response_text = tokenizer.decode(output[0], skip_special_tokens=True)

        return {"response": response_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
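
# Example request once the server is running (prompt is illustrative):
#   curl -X POST http://localhost:8000/generate/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Tell me a joke."}'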