"""FastAPI service that turns a natural-language instruction into a JavaScript
``transform(row)`` function using the ``misalsathsara/phi1.5-js-codegen`` model.

``POST /generate`` accepts a JSON body with an ``instruction`` string and an
optional ``field_names`` list, and returns only the generated function as
plain text.
"""

import os
import re
from typing import Iterator, List, Optional, Tuple

# Set cache directory for HF Spaces. These are set before transformers is
# imported so the Hugging Face libraries pick them up.
# NOTE: recent transformers releases deprecate TRANSFORMERS_CACHE in favour of
# HF_HOME; both are kept so the cache location stays unchanged.
os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"

import torch
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Limit PyTorch to a single intra-op CPU thread. Presumably chosen to avoid
# thread oversubscription on small shared CPUs (e.g. free HF Spaces) —
# benchmark before changing.
torch.set_num_threads(1)

app = FastAPI()

# Load model + tokenizer once at import time; every request reuses them.
model_id = "misalsathsara/phi1.5-js-codegen"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# torch.compile(model) may be tried on PyTorch >= 2; it is left disabled here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Upper bound on generated tokens, kept small for speed. Longer functions can
# be cut off; _complete_truncated() repairs those on a best-effort basis.
MAX_NEW_TOKENS = 100

# JS assistant system prompt
system_prompt = """
You are a smart javascript assistant that only generates only the best simple javascript functions without any comments like this:
function transform(row) {
  row['Latitude'] = row['Location'].split(',')[0];
  row['Longitude'] = row['Location'].split(',')[1];
  return row;
}
when user gives a prompt like "convert the location field into separate latitude and longitude fields".
Generate simple javascript functions that should take a single row of data as input and the generated function name is always transform.
The user may use the words column, item or field to mean each column.
Guard against null and undefined for items in the row.
${fieldList}
Field names are case sensitive.
For parsing something into a date, assume a function called parseAnyDate is available.
If the code requires some numeric calculation - ensure the value is converted to a number first. Don't assume its always the correct data type.
When doing any string comparison, make it case insensitive.
When replacing characters in a string, make sure to use the correct replacement literal.
For example, to replace hyphens with spaces, use: .replace(/-/g, ' ')
The function should not include a single comment before or after the function.
Don't add any text except for the function code.
Don't add any markdown block markers either.
Every function must end with return row;
"""

# Placeholder in system_prompt that is replaced with the caller's field names.
_FIELD_LIST_PLACEHOLDER = "${fieldList}"
# Statement every generated function is required to end with.
_RETURN_ROW = "return row;"
# Header of the generated function, up to and including the body's "{".
_TRANSFORM_HEADER = re.compile(r"function\s+transform\s*\([^)]*\)\s*\{")


class RequestData(BaseModel):
    """Request body for ``POST /generate``."""

    # Natural-language description of the desired row transformation.
    instruction: str
    # Optional names of the fields present in each row. When given they are
    # listed in the prompt in place of the ``${fieldList}`` placeholder.
    field_names: Optional[List[str]] = None


def _build_prompt(instruction: str, field_names: Optional[List[str]] = None) -> str:
    """Return the full model prompt for *instruction*.

    The ``${fieldList}`` placeholder in ``system_prompt`` is replaced with a
    sentence naming *field_names*, or with an empty line when none are given,
    so the raw template text is never sent to the model.
    """
    if field_names:
        names = ", ".join(f"'{name}'" for name in field_names)
        field_line = f"The row has the following fields: {names}."
    else:
        field_line = ""
    prompt = system_prompt.replace(_FIELD_LIST_PLACEHOLDER, field_line)
    return prompt + f"\n### Instruction:\n{instruction}\n\n### Response:\n"


def _iter_braces(code: str) -> Iterator[Tuple[int, str]]:
    """Yield ``(index, char)`` for each ``{`` or ``}`` outside JS string literals.

    Single-quoted, double-quoted and template (backtick) strings are skipped,
    honouring backslash escapes. Regex literals and comments are not
    recognised; a quote character inside them can confuse the scan.
    """
    quote = None
    escaped = False
    for index, char in enumerate(code):
        if quote is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
        elif char in "'\"`":
            quote = char
        elif char in "{}":
            yield index, char


def _brace_depth(code: str) -> int:
    """Return how many braces in *code* are still open (may be negative)."""
    return sum(1 if char == "{" else -1 for _, char in _iter_braces(code))


def _complete_truncated(snippet: str) -> str:
    """Best-effort repair of a function whose braces never balance.

    This typically happens when generation stops at ``MAX_NEW_TOKENS``.
    Everything after the first ``return row;`` is dropped, then every block
    that is still open is closed so nested ``if`` bodies do not leave the
    output unbalanced. Text cut off mid-statement cannot be repaired here.
    """
    head, found, _ = snippet.partition(_RETURN_ROW)
    depth = max(_brace_depth(head), 1)
    if found:
        # Keep the model's own return statement where it put it, then close
        # every block that is still open (the function body included).
        repaired = head + _RETURN_ROW + "\n}" * depth
    else:
        # No return yet: close inner blocks first so the return runs
        # unconditionally, then close the function body.
        repaired = head.rstrip() + "\n}" * (depth - 1) + "\n" + _RETURN_ROW + "\n}"
    return repaired.strip()


def _extract_transform(generated_text: str) -> str:
    """Return the ``transform`` function contained in *generated_text*.

    The function is located by its header and cut at the brace that balances
    the body's opening brace. Nested blocks (e.g. null guards) are therefore
    kept intact, and any text the model emits before or after the function is
    dropped. Falls back to ``_complete_truncated`` when no balanced function
    is found.
    """
    header = _TRANSFORM_HEADER.search(generated_text)
    if header is None:
        return _complete_truncated(generated_text)

    open_index = header.end() - 1  # position of the body's "{"
    depth = 0
    for offset, char in _iter_braces(generated_text[open_index:]):
        depth += 1 if char == "{" else -1
        if depth == 0:
            return generated_text[header.start() : open_index + offset + 1].strip()
    return _complete_truncated(generated_text[header.start() :])


@app.post("/generate")
def generate_code(data: RequestData):
    """Generate a JavaScript ``transform(row)`` function for *data*.

    Args:
        data: Request body with the natural-language ``instruction`` and an
            optional list of ``field_names``.

    Returns:
        PlainTextResponse containing only the extracted function source.
    """
    full_prompt = _build_prompt(data.instruction, data.field_names)
    encoded = tokenizer(full_prompt, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=encoded["input_ids"],
            # pad_token_id equals eos_token_id below, so generate() cannot
            # infer the mask from the ids; pass the tokenizer's mask instead.
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.3,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = tokenizer.decode(
        output_ids[0][prompt_length:], skip_special_tokens=True
    )
    return PlainTextResponse(_extract_transform(generated_text))