from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from pydantic import BaseModel
import os
import tarfile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Debug: log environment variables, masking values that look like secrets
logger.info("Environment variables: %s", {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()})

app = FastAPI()

model_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
model_path = "/app/granite-8b-finetuned-ascii"

# Extract the tarball if the model directory doesn't exist yet.
# This runs at import time, before any request is served, so raise a plain
# RuntimeError on failure rather than an HTTPException (which is only
# meaningful inside a request handler).
if not os.path.exists(model_path):
    logger.info(f"Extracting model tarball: {model_tarball}")
    try:
        with tarfile.open(model_tarball, "r:gz") as tar:
            tar.extractall(path="/app")
        logger.info("Model tarball extracted successfully")
    except Exception as e:
        logger.error(f"Failed to extract model tarball: {str(e)}")
        raise RuntimeError(f"Model tarball extraction failed: {str(e)}") from e

try:
    logger.info("Loading tokenizer and model")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = 'right'
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {str(e)}")
    # Startup failure: HTTPException belongs in request handlers, not here.
    raise RuntimeError(f"Model initialization failed: {str(e)}") from e

class EditRequest(BaseModel):
    text: str

@app.get("/")
def greet_json():
    return {"status": "Model is ready", "model": model_path}

@app.post("/generate")
def generate(request: EditRequest):
    # Deliberately a plain (sync) handler: model.generate() is blocking, and
    # FastAPI runs non-async endpoints in a worker thread pool instead of
    # stalling the event loop.
    try:
        prompt = f"Edit this AsciiDoc sentence: {request.text}"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # max_new_tokens bounds only the generated continuation; max_length
        # also counts the prompt tokens and can silently truncate the output.
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=200)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Generated response for prompt: {prompt}")
        return {"response": response}
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
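
# A minimal way to run the service locally, as a sketch: assumes uvicorn is
# installed and that port 8000 is free (both are assumptions, not part of the
# original deployment, which may use its own launcher).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the /generate endpoint (hypothetical host/port):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "This sentence need editing."}'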