---
frameworks:
- Pytorch
license: other
tasks:
- text-generation
---

# Model Card for NL2SQL-StarCoder-15B

## Model Intro

NL2SQL-StarCoder-15B is a natural-language-to-SQL (NL2SQL) model fine-tuned with QLoRA on the StarCoder 15B code LLM.

## Requirements

- python>=3.8
- pytorch>=2.0.0
- transformers==4.32.0
- CUDA 11.4

## Data Format

Each training sample is a single string assembled in the format below. The input prompt for inference is assembled the same way:

````python
"""
<|user|>
/* Given the following database schema: */
CREATE TABLE "table_name" (
"col1" int,
...
...
)

/* Write a sql to answer the following question: {Question} */
<|assistant|>
```sql
{Output SQL}
```<|end|>
"""
````

## Quick Start

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "gabrielpondc/NL2SQL-StarCoder-15B"

# Device and dtype options apply to the model only, not to the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
tokenizer.padding_side = "left"
# Assumed: StarCoder's "<fim_pad>" special token is used for padding;
# verify it against the tokenizer vocabulary before relying on it.
tokenizer.pad_token = "<fim_pad>"
tokenizer.eos_token = "<|endoftext|>"
# pad_token_id and eos_token_id are derived from the token strings set above.

model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model.eval()

# The prompt follows the Data Format section and ends with the <|assistant|> tag,
# so the model continues with the SQL answer.
text = (
    '<|user|>\n'
    '/* Given the following database schema: */\n'
    'CREATE TABLE "singer" (\n'
    '"Singer_ID" int,\n'
    '"Name" text,\n'
    '"Country" text,\n'
    '"Song_Name" text,\n'
    '"Song_release_year" text,\n'
    '"Age" int,\n'
    '"Is_male" bool,\n'
    'PRIMARY KEY ("Singer_ID")\n'
    ')\n\n'
    '/* Write a sql to answer the following question: '
    'Show countries where a singer above age 40 and a singer below 30 are from. */\n'
    '<|assistant|>\n'
)

inputs = tokenizer(text, return_tensors="pt", padding=True, add_special_tokens=False).to("cuda")
outputs = model.generate(
    inputs=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=512,
    do_sample=False,  # greedy decoding; top_p and temperature would be ignored here
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# Decode only the newly generated tokens, skipping the prompt.
gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(gen_text)
```
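
The prompt layout from the Data Format section can also be assembled programmatically. The following is a minimal sketch, not part of the released code: `build_prompt` and `extract_sql` are hypothetical helper names, and it assumes the model's answer arrives wrapped in a `sql` fence followed by `<|end|>`, as in the training format.

```python
import re

# Three backticks, built at runtime so this example nests cleanly in Markdown.
FENCE = "`" * 3

# Template mirroring the Data Format section (hypothetical constant name).
PROMPT_TEMPLATE = (
    "<|user|>\n"
    "/* Given the following database schema: */\n"
    "{schema}\n\n"
    "/* Write a sql to answer the following question: {question} */\n"
    "<|assistant|>\n"
)

# Matches the body of a sql-fenced block in the generated text.
_SQL_BLOCK = re.compile(FENCE + r"sql\s*(.*?)\s*" + FENCE, re.DOTALL)


def build_prompt(schema: str, question: str) -> str:
    """Fill the template with CREATE TABLE statements and a question."""
    return PROMPT_TEMPLATE.format(schema=schema.strip(), question=question.strip())


def extract_sql(generated: str) -> str:
    """Return the SQL inside the first sql fence, or the cleaned text if none is found."""
    match = _SQL_BLOCK.search(generated)
    if match:
        return match.group(1).strip()
    # Fallback (assumption): the model may omit the fence, so strip known markers.
    cleaned = generated.replace("<|end|>", "").replace(FENCE + "sql", "").replace(FENCE, "")
    return cleaned.strip()
```

With these helpers, `build_prompt(schema, question)` can stand in for the hand-written `text` string in the Quick Start, and `extract_sql(gen_text[0])` can be applied to the decoded output to obtain the bare SQL statement.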