---
license: apache-2.0
---

## 简介

该款自然语言生成 SQL 的模型(NL2SQL/Text2SQL)是以 [replit-code-v1-3b](https://huggingface.co/replit/replit-code-v1-3b) 代码续写预训练模型为基础进行 LoRA 微调的,这里仅提供 LoRA 权重(大概 11M),推理时需要结合原始预训练模型一起使用,具体参考下文示例。

## 用法

NL2SQL 任务中输入参数含有用户查询文本+数据库表信息,目前按照以下格式拼接模型的输入文本:

```
# Table Allergy_Type , columns = [ Allergy , AllergyType ]
# Table Has_Allergy , columns = [ StuID , Allergy ]
# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]
# primary keys: [ Allergy_Type.Allergy , Student.StuID ]
# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]
# Create a query for question: 显示所有男生的学生ID。
query =
```

具体使用方法参考以下示例:

```python
import sqlparse
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

device = 'cuda'
base_model_path = 'replit/replit-code-v1-3b'
lora_model_path = 'DMetaSoul/nl2sql-chinese-standard-3b-lora'
sampling = False

# Left padding keeps every prompt flush against the generated tokens, so all
# rows of a batch share the same prompt length.
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.float16)
if lora_model_path:
    # Attach the LoRA adapter weights on top of the base model.
    model = PeftModel.from_pretrained(model, lora_model_path, torch_dtype=torch.float16)
model.eval()
model.to(device)

input_texts = [
    "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n# Table Has_Allergy , columns = [ StuID , Allergy ]\n# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n# Create a query for question: 显示所有女学生的名、 姓氏、年龄。他们的性别是“女”.\nquery =",
    "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n# Table Has_Allergy , columns = [ StuID , Allergy ]\n# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n# Create a query for question: 显示所有男生的学生ID。\nquery =",
]
inputs = tokenizer(input_texts, max_length=512, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# NOTE: `return_full_text` is a TextGenerationPipeline option, not a
# `generate()` argument, so it is not passed here; the prompt is stripped
# from the output explicitly below instead.
with torch.no_grad():
    if sampling:
        outputs = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95, temperature=1.0,
                                 num_return_sequences=1, max_length=512,
                                 return_dict_in_generate=True, output_scores=True)
    else:
        outputs = model.generate(**inputs, num_beams=4, num_return_sequences=1, max_length=512,
                                 return_dict_in_generate=True, output_scores=True)

# `sequences` holds prompt + continuation for decoder-only models; drop the
# (left-padded, hence equal-length) prompt so only the generated SQL is decoded.
prompt_length = inputs['input_ids'].shape[1]
output_ids = outputs.sequences[:, prompt_length:]
results = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
for question, sql in zip(input_texts, results):
    print(question)
    print('SQL: {}'.format(sqlparse.format(sql, reindent=True)))
```