---
library_name: gemma_torch
tags:
  - pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: >-
  To access CodeGemma on Hugging Face, you’re required to review and agree to
  Google’s usage license. To do this, please ensure you’re logged in to Hugging
  Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
---

# CodeGemma Model Card

This repository corresponds to the CodeGemma 7B IT checkpoint for use with Gemma PyTorch. If you're looking for the `transformers` implementation or a more detailed model card, visit https://huggingface.co/google/codegemma-7b-it.

**Model Page**: CodeGemma

**Resources and Technical Documentation**:

**Terms of Use**: [Terms](https://ai.google.dev/gemma/terms)

**Authors**: Google

## Sample Usage

First, install the Gemma PyTorch library:

```shell
pip install git+https://github.com/google/gemma_pytorch.git
```
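The sample below expects the checkpoint and tokenizer in a local directory. One way to fetch them is with `huggingface_hub`; a minimal sketch, assuming this repository is published as `google/codegemma-7b-it-pytorch` (the `repo_id` is inferred from this model card, not stated in it):

```python
from huggingface_hub import snapshot_download

# Download the checkpoint and tokenizer into ./codegemma-7b-it-pytorch.
# repo_id is an assumption based on this model card's name.
snapshot_download(
    repo_id="google/codegemma-7b-it-pytorch",
    local_dir="codegemma-7b-it-pytorch",
)
```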
```python
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

VARIANT = "7b-it"
MACHINE_TYPE = "cpu"  # set to "cuda" to run on a GPU
weights_dir = "codegemma-7b-it-pytorch"  # local directory with the checkpoint and tokenizer

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Temporarily sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  try:
    yield
  finally:
    torch.set_default_dtype(torch.float)

# Pick the model configuration matching the checkpoint variant.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

# Build the model in the checkpoint's dtype, load the weights,
# and move it to the target device in inference mode.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f"codegemma-{VARIANT}.pt")
  model.load_weights(ckpt_path)
  model = model.to(device).eval()

# Prompts use Gemma's chat turn format.
PROMPT = """<start_of_turn>user
Write a Python function to calculate the nth Fibonacci number.<end_of_turn>
<start_of_turn>model
"""

model.generate(
    PROMPT,
    device=device,
    output_len=100,  # maximum number of tokens to generate
)
```
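The `<start_of_turn>`/`<end_of_turn>` markers above are Gemma's chat turn format; the trailing `<start_of_turn>model` line cues the model to respond. A small helper can build such prompts for arbitrary user messages (`format_turn` is an illustrative name, not part of the library):

```python
def format_turn(user_message: str) -> str:
  """Wrap a user message in Gemma's chat turn markers (illustrative helper)."""
  return (
      "<start_of_turn>user\n"
      f"{user_message}<end_of_turn>\n"
      "<start_of_turn>model\n"
  )

model.generate(
    format_turn("Write a Python function to reverse a string."),
    device=device,
    output_len=100,
)
```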