library_name: peft
base_model: microsoft/phi-2

Model Card for phi-2-mongodb

phi-2-mongodb is a fine-tuned version of microsoft/phi-2 that generates MongoDB pipeline queries from natural-language requests. It was fine-tuned on a custom-curated dataset of natural-language-to-MongoDB query pairs, which I'll be releasing next week.

Model Details

Further details about the fine-tuned model can be found at https://github.com/Chirayu-Tripathi/nl2query. The model can also be used through the nl2query library.

Model Description

Prompt Template

# Plain (non-f) string template: fill it with prompt_template.format(db_schema=..., text=...)
prompt_template = """<s> 
Task Description:
Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.

MongoDB Schema: 
{db_schema}

### Instruct:
{text}

### Output:
"""

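The template has two slots, `{db_schema}` and `{text}`. If you keep your schema as a Python dict rather than a hand-written string, a minimal sketch like the one below serializes it with `json.dumps` and fills both slots with `str.format`. The helper name `fill_prompt` and the shortened template and schema in the example are hypothetical; in practice you would pass the full template above as a plain string (without the `f` prefix).

```python
import json


def fill_prompt(template: str, schema: dict, instruction: str) -> str:
    """Fill the {db_schema} and {text} slots of a prompt template.

    Assumes (hypothetically) a plain, non-f string template whose only braces
    are the {db_schema} and {text} placeholders; literal braces would need doubling.
    """
    # 2-space-indented JSON, the same layout as the db_schema string in the sample.
    schema_text = json.dumps(schema, indent=2)
    return template.format(db_schema=schema_text, text=instruction)


# Hypothetical, shortened template and schema, only to show the substitution.
template = "MongoDB Schema: \n{db_schema}\n\n### Instruct:\n{text}\n\n### Output:\n"
schema = {"collections": [{"name": "shipwrecks"}], "version": 1}
prompt = fill_prompt(template, schema, "Count all shipwrecks")
print(prompt)
```

Braces inside the serialized schema are safe here because `str.format` does not re-parse substituted values.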
How to Get Started with the Model

Use the code sample below to interact with the model. It loads the base model in 4-bit precision with bitsandbytes and places it on GPU 0, so a CUDA GPU is required.

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
import torch
from peft import PeftModel

db_schema = '''{
  "collections": [
    {
      "name": "shipwrecks",
      "indexes": [
        {
          "key": {
            "_id": 1
          }
        },
        {
          "key": {
            "feature_type": 1
          }
        },
        {
          "key": {
            "chart": 1
          }
        },
        {
          "key": {
            "latdec": 1,
            "londec": 1
          }
        }
      ],
      "uniqueIndexes": [],
      "document": {
        "properties": {
          "_id": {
            "bsonType": "string"
          },
          "recrd": {
            "bsonType": "string"
          },
          "vesslterms": {
            "bsonType": "string"
          },
          "feature_type": {
            "bsonType": "string"
          },
          "chart": {
            "bsonType": "string"
          },
          "latdec": {
            "bsonType": "double"
          },
          "londec": {
            "bsonType": "double"
          },
          "gp_quality": {
            "bsonType": "string"
          },
          "depth": {
            "bsonType": "string"
          },
          "sounding_type": {
            "bsonType": "string"
          },
          "history": {
            "bsonType": "string"
          },
          "quasou": {
            "bsonType": "string"
          },
          "watlev": {
            "bsonType": "string"
          },
          "coordinates": {
            "bsonType": "array",
            "items": {
              "bsonType": "double"
            }
          }
        }
      }
    }
  ],
  "version": 1
}'''

text = '''Find the count of shipwrecks for each unique combination of "latdec" and "londec"'''
prompt = f"""<s> 
        Task Description:
        Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.

        MongoDB Schema: 
        {db_schema}

        ### Instruct:
        {text}

        ### Output:
        """

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
compute_dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)
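model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    quantization_config=bnb_config,
    # A pull-request revision of microsoft/phi-2. The flash_attn, flash_rotary and
    # fused_dense flags below appear to be options of that revision's remote code,
    # not standard transformers arguments.
    revision="refs/pr/23",
    device_map={"": 0},  # load the whole model on GPU 0
    torch_dtype="auto",
    flash_attn=True,
    flash_rotary=True,
    fused_dense=True,
)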
adapter = 'Chirayu/phi-2-mongodb'

# device_map above already placed the quantized base model on GPU 0.
model = PeftModel.from_pretrained(model, adapter)
model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
output = model.generate(
    **model_inputs,
    max_length=1024,
    no_repeat_ngram_size=10,
    repetition_penalty=1.02,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)[0]

prompt_length = model_inputs['input_ids'].shape[1]
query = tokenizer.decode(output[prompt_length:], skip_special_tokens=False)
# Keep only the text before the first "</s>" the model emits, if any.
stop_idx = query.find("</s>")
if stop_idx == -1:
    stop_idx = len(query)
print(query[:stop_idx].strip())
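The task description asks the model to rely only on keys present in the schema, but nothing in the pipeline enforces that. A rough post-generation check could compare the quoted `"$field"` references in the generated text against the top-level property names in `db_schema`. This is a heuristic sketch and not part of the model or of nl2query: the names `FIELD_REF`, `schema_fields` and `unknown_fields` are hypothetical, it assumes field paths are written as quoted strings (e.g. `"$latdec"`), it only inspects top-level `properties`, and it ignores fields that appear as plain keys (for example inside `$match` filters).

```python
import json
import re

# Quoted "$field" or '$field' values that are NOT followed by a colon, i.e. field
# paths used as values rather than operator keys such as "$group": {...}.
FIELD_REF = re.compile(r"""["']\$([A-Za-z_][\w.]*)["'](?!\s*:)""")


def schema_fields(db_schema: str) -> set[str]:
    """Collect top-level property names from every collection in the schema JSON."""
    schema = json.loads(db_schema)
    fields: set[str] = set()
    for collection in schema.get("collections", []):
        properties = collection.get("document", {}).get("properties", {})
        fields.update(properties)  # iterating a dict yields its keys
    return fields


def unknown_fields(query: str, db_schema: str) -> set[str]:
    """Return referenced top-level fields that the schema does not define."""
    # "$coordinates.0" -> "coordinates": compare only the top-level segment.
    referenced = {m.group(1).split(".")[0] for m in FIELD_REF.finditer(query)}
    return referenced - schema_fields(db_schema)


# Hypothetical example: a pipeline that misspells "londec" as "longdec".
demo_schema = json.dumps({
    "collections": [{
        "name": "shipwrecks",
        "document": {"properties": {"latdec": {"bsonType": "double"},
                                    "londec": {"bsonType": "double"}}},
    }]
})
demo_query = ('db.shipwrecks.aggregate([{"$group": {"_id": {"latdec": "$latdec", '
              '"longdec": "$longdec"}, "count": {"$sum": 1}}}])')
print(unknown_fields(demo_query, demo_schema))  # {'longdec'}
```

After running the sample above, `unknown_fields(query, db_schema)` lists any top-level field that the generated query references but the schema lacks; an empty set only means this heuristic found none.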
Framework versions

- PEFT 0.10.0