# phi15-js-api / app.py
# Last commit by misalsathsara: "Update app.py" (c14d360, verified)
import os
# Redirect the Hugging Face caches into /tmp for the HF Spaces runtime
# (presumably because /tmp is writable there — confirm for other hosts).
# These must be in place before `transformers` is imported further down,
# because the library reads them at import time.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME / HF_HUB_CACHE; it is kept here to preserve the cache path.
os.environ.update(
    {
        "HF_HOME": "/tmp",
        "TRANSFORMERS_CACHE": "/tmp/hf-cache",
    }
)
import re
from typing import List, Optional

import torch
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Restrict PyTorch to one intra-op thread. The original intent was a CPU
# speed-up (presumably avoiding oversubscription on a small shared CPU) —
# TODO confirm with a benchmark on the target hardware.
torch.set_num_threads(1)

app = FastAPI()

# Hub checkpoint served by this app (by its name, a Phi-1.5 fine-tune for
# JavaScript code generation).
model_id = "misalsathsara/phi1.5-js-codegen"
# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download (or reuse from cache) the tokenizer and weights, place the model on
# the chosen device, and switch it to evaluation mode (disables dropout).
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model.eval()
# Optional on PyTorch >= 2: `model = torch.compile(model)` (left disabled;
# re-enable only if it works in the deployment environment).
# System prompt prepended to every /generate request. It shows the model one
# example `transform(row)` function and lists the rules the generated
# JavaScript must follow.
# NOTE: `${fieldList}` below is a JavaScript-template-style token. Python does
# not interpolate it, so it reaches the model verbatim unless the request
# handler substitutes it (e.g. with the row's field names) first.
# NOTE(review): the text has minor typos ("only generates only", "its" for
# "it's"); fixing them changes the model input, so they are left untouched.
system_prompt = """
You are a smart javascript assistant that only generates only the best simple javascript functions without any comments like this:
function transform(row) {
row['Latitude'] = row['Location'].split(',')[0];
row['Longitude'] = row['Location'].split(',')[1];
return row;
}
when user gives a prompt like "convert the location field into separate latitude and longitude fields".
Generate simple javascript functions that should take a single row of data as input and the generated function name is always transform.
The user may use the words column, item or field to mean each column.
Guard against null and undefined for items in the row.
${fieldList}
Field names are case sensitive.
For parsing something into a date, assume a function called parseAnyDate is available.
If the code requires some numeric calculation - ensure the value is converted to a number first. Don't assume its always the correct data type.
When doing any string comparison, make it case insensitive.
When replacing characters in a string, make sure to use the correct replacement literal. For example, to replace hyphens with spaces, use: .replace(/-/g, ' ')
The function should not include a single comment before or after the function.
Don't add any text except for the function code.
Don't add any markdown block markers either.
Every function must end with return row;
"""
class RequestData(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    # Natural-language description of the row transformation to generate.
    instruction: str
    # Optional names of the fields present in each row. When provided they
    # replace the ``${fieldList}`` placeholder in ``system_prompt``; when
    # omitted the placeholder is dropped so the raw token never reaches the
    # model. The None default keeps existing clients working unchanged.
    fields: Optional[List[str]] = None


# Placeholder token inside ``system_prompt`` that stands for the row's fields.
_FIELD_LIST_PLACEHOLDER = "${fieldList}"
# Locates the start of the generated ``function transform(`` header.
_TRANSFORM_HEADER = re.compile(r"function\s+transform\s*\(")


def _build_prompt(instruction: str, fields: Optional[List[str]] = None) -> str:
    """Return the model prompt: system prompt (field list filled in) plus the instruction."""
    field_text = f"The row contains these fields: {', '.join(fields)}." if fields else ""
    prompt = system_prompt.replace(_FIELD_LIST_PLACEHOLDER, field_text)
    return prompt + f"\n### Instruction:\n{instruction}\n\n### Response:\n"


def _extract_transform(generated_text: str) -> str:
    """Return the first complete ``transform`` function found in *generated_text*.

    Braces are matched by nesting depth, so bodies that contain nested blocks
    (``if (...) { ... }``) are kept whole. Braces inside JS string or regex
    literals are not recognised; this is a heuristic, not a JS parser.
    If no balanced function is found (typically because generation stopped at
    ``max_new_tokens``), the text is repaired by ``_close_truncated``.
    """
    header = _TRANSFORM_HEADER.search(generated_text)
    # Drop any preamble the model emitted before the function header.
    snippet = generated_text[header.start():] if header else generated_text
    depth = 0
    for index, char in enumerate(snippet):
        if char == "{":
            depth += 1
        elif char == "}" and depth > 0:
            depth -= 1
            if depth == 0:
                return snippet[: index + 1].strip()
    return _close_truncated(snippet)


def _close_truncated(snippet: str) -> str:
    """Best-effort repair of a ``transform`` function cut off before its final brace."""
    # Keep everything up to the last complete statement or brace; whatever
    # follows is a partial line left by the truncated generation.
    cut = max(snippet.rfind(char) for char in ";{}")
    head = snippet[: cut + 1] if cut != -1 else snippet
    # The outermost open brace is the function body; close any deeper blocks.
    inner_open = max(head.count("{") - head.count("}") - 1, 0)
    tail = "}\n" * inner_open
    # Every function must end with `return row;` — add it unless already last.
    if inner_open or not head.rstrip().endswith("return row;"):
        tail += "return row;\n"
    return (head + "\n" + tail + "}").strip()


@app.post("/generate")
def generate_code(data: RequestData):
    """Generate a JavaScript ``transform(row)`` function for ``data.instruction``.

    Returns the extracted function source as a ``text/plain`` response.
    """
    full_prompt = _build_prompt(data.instruction, data.fields)
    # Pass the attention mask explicitly: pad_token_id equals eos_token_id
    # below, so generate() cannot infer the mask from the ids on its own.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,  # short for latency; truncated output is repaired
            temperature=0.3,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
    )
    return PlainTextResponse(_extract_transform(generated_text))