| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| #include <torch/extension.h>
|
| #include <vector>
|
|
|
| |
| |
| |
|
|
// Declarations of the CUDA implementations, compiled in only when WITH_CUDA
// is defined. Their definitions are not in this file (presumably a separate
// CUDA translation unit — confirm against the build setup). The parameter
// lists below must match those definitions exactly, otherwise the symbols
// will be unresolved when the extension is linked/loaded.
#ifdef WITH_CUDA

// CUDA path of bitlinear_forward; called by it when x.is_cuda() is true.
// Takes the same arguments as bitlinear_cpu_forward.
torch::Tensor bitlinear_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
);

// CUDA path of multi_ternary_forward; called by it when x.is_cuda() is true.
// Takes the same arguments as multi_ternary_cpu_forward.
torch::Tensor multi_ternary_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
);

#endif  // WITH_CUDA
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| torch::Tensor bitlinear_cpu_forward(
|
| torch::Tensor x,
|
| torch::Tensor W_ternary,
|
| torch::Tensor gamma,
|
| torch::optional<torch::Tensor> bias
|
| ) {
|
|
|
| auto x_shape = x.sizes().vec();
|
| int64_t batch_size = 1;
|
| for (size_t i = 0; i < x_shape.size() - 1; i++) {
|
| batch_size *= x_shape[i];
|
| }
|
| int64_t in_features = x_shape.back();
|
| int64_t out_features = W_ternary.size(0);
|
|
|
|
|
| auto x_2d = x.view({batch_size, in_features});
|
|
|
|
|
|
|
| auto output = torch::matmul(x_2d, W_ternary.t());
|
|
|
|
|
|
|
| output = output * gamma.unsqueeze(0);
|
|
|
|
|
| if (bias.has_value() && bias.value().defined()) {
|
| output = output + bias.value().unsqueeze(0);
|
| }
|
|
|
|
|
| std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
|
| out_shape.push_back(out_features);
|
| output = output.view(out_shape);
|
|
|
| return output;
|
| }
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| torch::Tensor multi_ternary_cpu_forward(
|
| torch::Tensor x,
|
| torch::Tensor W_ternary,
|
| torch::Tensor gammas,
|
| torch::optional<torch::Tensor> bias
|
| ) {
|
|
|
|
|
| int64_t k = W_ternary.size(0);
|
| int64_t out_features = W_ternary.size(1);
|
| int64_t in_features = W_ternary.size(2);
|
|
|
|
|
| auto x_shape = x.sizes().vec();
|
| int64_t batch_size = 1;
|
| for (size_t i = 0; i < x_shape.size() - 1; i++) {
|
| batch_size *= x_shape[i];
|
| }
|
|
|
|
|
| auto x_2d = x.view({batch_size, in_features});
|
|
|
|
|
| auto output = torch::zeros({batch_size, out_features}, x.options());
|
|
|
|
|
| for (int64_t i = 0; i < k; i++) {
|
|
|
| auto W_i = W_ternary[i];
|
| auto gamma_i = gammas[i];
|
|
|
|
|
| auto component = torch::matmul(x_2d, W_i.t());
|
| component = component * gamma_i.unsqueeze(0);
|
|
|
|
|
| output = output + component;
|
| }
|
|
|
|
|
| if (bias.has_value() && bias.value().defined()) {
|
| output = output + bias.value().unsqueeze(0);
|
| }
|
|
|
|
|
| std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
|
| out_shape.push_back(out_features);
|
| output = output.view(out_shape);
|
|
|
| return output;
|
| }
|
|
|
| |
| |
| |
| |
| |
|
|
| torch::Tensor bitlinear_forward(
|
| torch::Tensor x,
|
| torch::Tensor W_ternary,
|
| torch::Tensor gamma,
|
| torch::optional<torch::Tensor> bias
|
| ) {
|
|
|
| TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
|
| TORCH_CHECK(W_ternary.dim() == 2, "W_ternary must be 2D");
|
| TORCH_CHECK(gamma.dim() == 1 || gamma.dim() == 2, "gamma must be 1D or 2D");
|
|
|
|
|
| if (x.is_cuda()) {
|
| #ifdef WITH_CUDA
|
| return bitlinear_cuda_forward(x, W_ternary, gamma, bias);
|
| #else
|
| AT_ERROR("BitLinear CUDA kernels not compiled. Rebuild with CUDA support.");
|
| #endif
|
| } else {
|
| return bitlinear_cpu_forward(x, W_ternary, gamma, bias);
|
| }
|
| }
|
|
|
| |
| |
|
|
| torch::Tensor multi_ternary_forward(
|
| torch::Tensor x,
|
| torch::Tensor W_ternary,
|
| torch::Tensor gammas,
|
| torch::optional<torch::Tensor> bias
|
| ) {
|
|
|
| TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
|
| TORCH_CHECK(W_ternary.dim() == 3, "W_ternary must be 3D [k, out_features, in_features]");
|
| TORCH_CHECK(gammas.dim() == 2, "gammas must be 2D [k, out_features]");
|
|
|
|
|
| if (x.is_cuda()) {
|
| #ifdef WITH_CUDA
|
| return multi_ternary_cuda_forward(x, W_ternary, gammas, bias);
|
| #else
|
| AT_ERROR("Multi-ternary CUDA kernels not compiled. Rebuild with CUDA support.");
|
| #endif
|
| } else {
|
| return multi_ternary_cpu_forward(x, W_ternary, gammas, bias);
|
| }
|
| }
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| torch::Tensor pack_ternary_base3_cpp(torch::Tensor W_ternary) {
|
|
|
| auto flat = W_ternary.flatten().to(torch::kCPU).to(torch::kInt8);
|
| int64_t numel = flat.numel();
|
|
|
|
|
| auto mapped = (flat + 1).to(torch::kUInt8);
|
|
|
|
|
| int64_t packed_size = (numel + 4) / 5;
|
| auto packed = torch::zeros({packed_size}, torch::dtype(torch::kUInt8).device(torch::kCPU));
|
|
|
|
|
| auto mapped_ptr = mapped.data_ptr<uint8_t>();
|
| auto packed_ptr = packed.data_ptr<uint8_t>();
|
|
|
|
|
| const uint8_t powers[5] = {1, 3, 9, 27, 81};
|
|
|
|
|
| for (int64_t i = 0; i < packed_size; i++) {
|
| int64_t base_idx = i * 5;
|
| uint8_t packed_val = 0;
|
|
|
| for (int j = 0; j < 5; j++) {
|
| int64_t idx = base_idx + j;
|
| if (idx < numel) {
|
| packed_val += mapped_ptr[idx] * powers[j];
|
| } else {
|
|
|
| packed_val += 1 * powers[j];
|
| }
|
| }
|
| packed_ptr[i] = packed_val;
|
| }
|
|
|
| return packed;
|
| }
|
|
|
| |
| |
| |
| |
| |
| |
|
|
| torch::Tensor unpack_ternary_base3_cpp(
|
| torch::Tensor packed,
|
| std::vector<int64_t> original_shape
|
| ) {
|
|
|
| int64_t numel = 1;
|
| for (auto dim : original_shape) {
|
| numel *= dim;
|
| }
|
|
|
|
|
| auto packed_flat = packed.flatten().to(torch::kCPU).to(torch::kUInt8);
|
| int64_t packed_size = packed_flat.numel();
|
|
|
|
|
| auto unpacked = torch::zeros({numel}, torch::dtype(torch::kInt8).device(torch::kCPU));
|
|
|
|
|
| auto packed_ptr = packed_flat.data_ptr<uint8_t>();
|
| auto unpacked_ptr = unpacked.data_ptr<int8_t>();
|
|
|
|
|
| int64_t out_idx = 0;
|
| for (int64_t i = 0; i < packed_size && out_idx < numel; i++) {
|
| uint8_t packed_val = packed_ptr[i];
|
|
|
|
|
| for (int j = 0; j < 5 && out_idx < numel; j++) {
|
| uint8_t val = packed_val % 3;
|
| packed_val /= 3;
|
|
|
|
|
| unpacked_ptr[out_idx] = static_cast<int8_t>(val) - 1;
|
| out_idx++;
|
| }
|
| }
|
|
|
|
|
| return unpacked.view(original_shape).to(torch::kFloat32);
|
| }
|
|
|
| |
| |
| |
| |
| |
| |
|
|
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| m.def("forward", &bitlinear_forward, "BitLinear forward (CPU/CUDA)",
|
| py::arg("x"),
|
| py::arg("W_ternary"),
|
| py::arg("gamma"),
|
| py::arg("bias") = py::none());
|
|
|
| m.def("multi_ternary_forward", &multi_ternary_forward,
|
| "Multi-ternary linear forward (CPU/CUDA)",
|
| py::arg("x"),
|
| py::arg("W_ternary"),
|
| py::arg("gammas"),
|
| py::arg("bias") = py::none());
|
|
|
| m.def("pack_ternary_base3", &pack_ternary_base3_cpp,
|
| "Pack ternary weights to base-3 (CPU)",
|
| py::arg("W_ternary"));
|
|
|
| m.def("unpack_ternary_base3", &unpack_ternary_base3_cpp,
|
| "Unpack base-3 ternary weights (CPU)",
|
| py::arg("packed"),
|
| py::arg("original_shape"));
|
| }
|
|
|