Model Card for peft-internal-testing/tiny-random-gemma4-E2B

The code to create this checkpoint is based on https://huggingface.co/tiny-random/gemma-4-e with a few small changes:

import json
import os

import torch
from huggingface_hub import hf_hub_download

from transformers import (
    AutoConfig,
    AutoProcessor,
    Gemma4ForConditionalGeneration,
    set_seed,
)

source_model_id = "google/gemma-4-E4B"
save_folder = "/tmp/peft/tiny-random-gemma4"

# The processor is reused unchanged from the source checkpoint; only the
# model config (and therefore the model) is shrunk below.
processor = AutoProcessor.from_pretrained(source_model_id)


# Start from the source model's config.json and overwrite the size-related
# fields of each sub-config so the randomly initialised model is tiny.
with open(
    hf_hub_download(source_model_id, filename="config.json", repo_type="model"), "r", encoding="utf-8",
) as f:
    config_json = json.load(f)

config_json["audio_config"].update(
    {
        "num_attention_heads": 2,
        "num_hidden_layers": 2,
        "hidden_size": 64,
        "output_proj_dims": 32,
    }
)
config_json["text_config"].update(
    {
        "global_head_dim": 64,
        "head_dim": 32,
        "hidden_size": 8,
        "hidden_size_per_layer_input": 2,
        "intermediate_size": 64,
        # One entry per layer: 4 entries to match num_hidden_layers below.
        "layer_types": [
            "sliding_attention",
            "full_attention",
            "sliding_attention",
            "full_attention",
        ],
        "num_attention_heads": 8,
        "num_hidden_layers": 4,
        "num_key_value_heads": 4,
        "num_kv_shared_layers": 2,
    }
)
config_json["vision_config"].update(
    {
        "num_hidden_layers": 2,
        "hidden_size": 8,
        "intermediate_size": 64,
        "head_dim": 32,
        "global_head_dim": 32,
        "num_attention_heads": 4,
        "num_key_value_heads": 4,
    }
)

# save_folder is not guaranteed to exist (nothing above creates it). Without
# this, open(..., "w") raises FileNotFoundError on a fresh machine.
os.makedirs(save_folder, exist_ok=True)
with open(os.path.join(save_folder, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config_json, f, indent=2)
# Round-trip through disk so AutoConfig resolves the correct config class
# from the patched JSON.
config = AutoConfig.from_pretrained(save_folder)

# Build the model while bfloat16 is the default dtype, so its parameters are
# created in bf16, then switch the global default back to float32.
torch.set_default_dtype(torch.bfloat16)
model = Gemma4ForConditionalGeneration(config)
torch.set_default_dtype(torch.float32)
set_seed(42)
model = model.cpu()

# Parameters in name-sorted order. This fixes the order in which they consume
# the seeded RNG during re-initialisation below.
named_params = sorted(model.named_parameters())
all_numels = sum(param.numel() for _, param in named_params)

# Overwrite every parameter with N(0, 0.2) samples and report each one's share
# of the total parameter count.
with torch.no_grad():
    for name, p in named_params:
        torch.nn.init.normal_(p, 0, 0.2)
        print(name, p.shape, f"{p.numel() / all_numels * 100: .4f}%")


# Upload the processor first, then the model, to the same Hub repository.
# The token is read from HF_TOKEN (None when the variable is unset).
# NOTE(review): the repo id says "E2B" while source_model_id is
# "google/gemma-4-E4B" — confirm this naming is intentional.
token = os.environ.get("HF_TOKEN")
for artifact in (processor, model):
    artifact.push_to_hub("peft-internal-testing/tiny-random-gemma4-E2B", token=token)
Downloads last month: -
Safetensors
Model size: 4.72M params
Tensor type: BF16
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support