| import safetensors.torch |
| import torch |
| import sys |
|
|
| |
|
|
# Pick an optional target dtype from the output filename (sys.argv[1]).
# First matching tag wins; if no tag is present the tensors keep their
# source dtype (cast_to stays None).
cast_to = None
for tag, dtype in (
    ("fp8_e4m3fn", torch.float8_e4m3fn),
    ("fp16", torch.float16),
    ("bf16", torch.bfloat16),
):
    if tag in sys.argv[1]:
        cast_to = dtype
        break
|
|
# Key-rename table applied (as plain substring replacements) to every tensor
# name after the q/k/v fusion step, mapping the source checkpoint's naming
# scheme onto the target one. The patterns are mutually disjoint, so the
# order replacements are applied in does not matter.
replace_keys = {
    # fold the "all_*.2-1." prefixed modules into their plain names
    "all_final_layer.2-1.": "final_layer.",
    "all_x_embedder.2-1.": "x_embedder.",
    # attention out-projection and q/k norm renames
    # (the to_out.0.bias entry is kept for completeness even though that
    # key is dropped before renaming elsewhere in this script)
    ".attention.to_out.0.weight": ".attention.out.weight",
    ".attention.to_out.0.bias": ".attention.out.bias",
    ".attention.norm_q.weight": ".attention.q_norm.weight",
    ".attention.norm_k.weight": ".attention.k_norm.weight",
}
|
|
# Merge every input checkpoint (sys.argv[2:]) into one state dict, fusing the
# separate q/k/v projection weights of each attention block into a single
# "qkv" weight and renaming keys via replace_keys, then write the merged
# result to sys.argv[1].
out_sd = {}
for path in sys.argv[2:]:
    sd = safetensors.torch.load_file(path)
    qkv_parts = None  # holds [q, k, v] weights for the attention block in progress
    for key, tensor in sd.items():
        if cast_to is not None:
            tensor = tensor.to(cast_to)

        k_out = key
        # the target format has no separate out-projection bias; drop it
        if k_out.endswith(".attention.to_out.0.bias"):
            continue
        # NOTE(review): the fusion below assumes keys arrive in to_k, to_q,
        # to_v order within each attention block (qkv_parts is created on
        # to_k and consumed on to_v) — confirm for these checkpoint files.
        if k_out.endswith(".attention.to_k.weight"):
            qkv_parts = [tensor]
            continue
        if k_out.endswith(".attention.to_q.weight"):
            qkv_parts.insert(0, tensor)
            continue
        if k_out.endswith(".attention.to_v.weight"):
            qkv_parts.append(tensor)
            # concatenate along dim 0 -> stacked [q; k; v] projection matrix
            tensor = torch.cat(qkv_parts, dim=0)
            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")

        for old, new in replace_keys.items():
            k_out = k_out.replace(old, new)
        out_sd[k_out] = tensor


safetensors.torch.save_file(out_sd, sys.argv[1])
|
|