Understanding PI0

This repository contains basic scripts for understanding the PI0 policy.

Setting up the repository

We use the Hugging Face implementation of the PI0 policy. To inspect the source code, it is recommended to clone both the lerobot and transformers repositories.

git clone https://github.com/huggingface/lerobot.git

Note that to use transformers with lerobot, you need to check out this specific branch:

git clone https://github.com/huggingface/transformers.git
cd ./transformers/
git checkout fix/lerobot_openpi
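
With both repositories cloned and installed (e.g. with pip install -e . in each clone), the policy can be loaded and its module tree printed. The sketch below assumes the import path and the lerobot/pi0 checkpoint id from upstream lerobot; both may differ between versions:

# Minimal sketch: load the pretrained PI0 policy and print its module tree.
# Import path and checkpoint id are assumptions; adjust to your lerobot version.
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

policy = PI0Policy.from_pretrained("lerobot/pi0")  # downloads the pretrained weights
print(policy)                                      # prints the module structure shown below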

PI0 Model

This is the structure of the PI0Policy:

PI0Policy(
  (model): PI0Pytorch(
    (paligemma_with_expert): PaliGemmaWithExpertModel(
      (paligemma): PaliGemmaForConditionalGeneration(
        (model): PaliGemmaModel(
          (vision_tower): SiglipVisionModel(
            (vision_model): SiglipVisionTransformer(
              (embeddings): SiglipVisionEmbeddings(
                (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
                (position_embedding): Embedding(256, 1152)
              )
              (encoder): SiglipEncoder(
                (layers): ModuleList(
                  (0-26): 27 x SiglipEncoderLayer(
                    (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                    (self_attn): SiglipAttention(
                      (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                      (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                      (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                      (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
                    )
                    (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                    (mlp): SiglipMLP(
                      (activation_fn): PytorchGELUTanh()
                      (fc1): Linear(in_features=1152, out_features=4304, bias=True)
                      (fc2): Linear(in_features=4304, out_features=1152, bias=True)
                    )
                  )
                )
              )
              (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            )
          )
          (multi_modal_projector): PaliGemmaMultiModalProjector(
            (linear): Linear(in_features=1152, out_features=2048, bias=True)
          )
          (language_model): GemmaModel(
            (embed_tokens): Embedding(257152, 2048, padding_idx=0)
            (layers): ModuleList(
              (0-17): 18 x GemmaDecoderLayer(
                (self_attn): GemmaAttention(
                  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
                  (k_proj): Linear(in_features=2048, out_features=256, bias=False)
                  (v_proj): Linear(in_features=2048, out_features=256, bias=False)
                  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
                )
                (mlp): GemmaMLP(
                  (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
                  (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
                  (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
                  (act_fn): PytorchGELUTanh()
                )
                (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
                (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
              )
            )
            (norm): GemmaRMSNorm((2048,), eps=1e-06)
            (rotary_emb): GemmaRotaryEmbedding()
          )
        )
        (lm_head): Linear(in_features=2048, out_features=257152, bias=False)
      )
      (gemma_expert): GemmaForCausalLM(
        (model): GemmaModel(
          (embed_tokens): None
          (layers): ModuleList(
            (0-17): 18 x GemmaDecoderLayer(
              (self_attn): GemmaAttention(
                (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
                (k_proj): Linear(in_features=1024, out_features=256, bias=False)
                (v_proj): Linear(in_features=1024, out_features=256, bias=False)
                (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
              )
              (mlp): GemmaMLP(
                (gate_proj): Linear(in_features=1024, out_features=4096, bias=False)
                (up_proj): Linear(in_features=1024, out_features=4096, bias=False)
                (down_proj): Linear(in_features=4096, out_features=1024, bias=False)
                (act_fn): PytorchGELUTanh()
              )
              (input_layernorm): GemmaRMSNorm((1024,), eps=1e-06)
              (post_attention_layernorm): GemmaRMSNorm((1024,), eps=1e-06)
            )
          )
          (norm): GemmaRMSNorm((1024,), eps=1e-06)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (lm_head): Linear(in_features=1024, out_features=257152, bias=False)
      )
    )
    (action_in_proj): Linear(in_features=16, out_features=1024, bias=True)
    (action_out_proj): Linear(in_features=1024, out_features=16, bias=True)
    (state_proj): Linear(in_features=32, out_features=1024, bias=True)
    (action_time_mlp_in): Linear(in_features=2048, out_features=1024, bias=True)
    (action_time_mlp_out): Linear(in_features=1024, out_features=1024, bias=True)
  )
)

Inference can be separated into three parts: the vision_tower encodes the input RGB images into the token space, the language_model is the main VLM backbone, and the gemma_expert is the smaller action expert that generates the target positions using a flow-matching process.
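
The action expert does not produce actions in a single forward pass: starting from Gaussian noise, it repeatedly predicts a velocity field and integrates it with fixed Euler steps until a clean action chunk remains. The sketch below only illustrates this idea; predict_velocity, the horizon, and the step count are placeholders rather than the lerobot API (the action dimension of 16 matches action_in_proj above):

import torch

def flow_matching_sample(predict_velocity, action_dim=16, horizon=50, num_steps=10):
    # Conceptual flow-matching sampling loop (not the lerobot API).
    # predict_velocity(x_t, t) stands in for the action expert's forward pass,
    # which attends to the cached key/value states of the VLM prefix.
    x_t = torch.randn(1, horizon, action_dim)  # start from Gaussian noise at t = 1
    dt = -1.0 / num_steps                      # integrate from t = 1 (noise) to t = 0 (actions)
    t = torch.ones(1)
    for _ in range(num_steps):
        v_t = predict_velocity(x_t, t)         # predicted velocity field
        x_t = x_t + dt * v_t                   # Euler integration step
        t = t + dt
    return x_t                                 # denoised action chunk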

The total parameter count is 3,501,339,392 (3.50 B), of which the SigLIP vision model accounts for 412,442,352 (412.44 M), the Gemma language model for 2,508,531,712 (2,508.53 M), and the flow-matching action expert for 574,788,608 (574.79 M) parameters.
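
These counts can be reproduced with standard PyTorch parameter counting on the loaded policy; the attribute paths below follow the module names in the structure printout above:

def count_params(module):
    # Sum the number of elements over all parameter tensors of a module.
    return sum(p.numel() for p in module.parameters())

pwe = policy.model.paligemma_with_expert
print("total         :", count_params(policy.model))
print("vision tower  :", count_params(pwe.paligemma.model.vision_tower))
print("language model:", count_params(pwe.paligemma.model.language_model))
print("action expert :", count_params(pwe.gemma_expert))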

One inference step requires a total of 4,354,614,038,072 FLOPs (4.35 T).
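
One way to check this number is PyTorch's built-in FLOP counter (torch.utils.flop_counter, available in recent PyTorch releases). The observation below is a placeholder; a real run needs a batch formatted the way lerobot expects (camera images, proprioceptive state, and the task instruction):

from torch.utils.flop_counter import FlopCounterMode

# observation is a placeholder for a properly formatted lerobot batch.
with FlopCounterMode(display=False) as flop_counter:
    policy.select_action(observation)

print(f"{flop_counter.get_total_flops():,} FLOPs per inference step")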
