Integrate with Sentence Transformers v5.4

#8
by tomaarsen

Hello!

Pull Request overview

  • Integrate zerank-2 with Sentence Transformers (v5.4+) so the model loads via the standard CrossEncoder API with no extra kwargs.

Details

This PR supersedes #2, which requires trust_remote_code. Instead, this integration uses the stock CrossEncoder with a causal-LM pipeline, Transformer (text-generation) -> LogitScore, built only from classes that ship with Sentence Transformers v5.4+, so no trust_remote_code=True is required and no custom modeling code is added on top of what already ships in the repo.

LogitScore returns the raw "Yes" logit at the last position, which can be used directly for ranking. To recover the 0-1 score range that the original predict() produces, callers can apply the temperature-scaled sigmoid sigmoid(score / 5) themselves; the README shows the one-liner. Keeping this transformation client-side avoids the need for a custom score head and keeps the integration purely config-driven.
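For reference, the post-processing is plain arithmetic and needs nothing from this repo; the following sketch re-derives the README probabilities from the raw logits shown further down (values hard-coded purely for illustration):

import math

# Raw "Yes" logits from the example below, post-processed client-side:
raw_scores = [5.40625, -4.5]
probabilities = [1 / (1 + math.exp(-score / 5)) for score in raw_scores]
print(probabilities)
# roughly [0.7467, 0.2891] in float64; the bf16 output below rounds to [0.7461, 0.2891]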

chat_template.jinja gets a small new branch at the top: when the input messages carry query / document roles (the convention Sentence Transformers passes to a CrossEncoder), it renders them as <|im_start|>system\n{query}<|im_end|>\n<|im_start|>user\n{document}<|im_end|>\n<|im_start|>assistant\n, which is byte-for-byte identical to the chat-templated string the original format_pointwise_datapoints produces. Any other role configuration falls through to the original Qwen3 logic untouched, so direct tokenizer.apply_chat_template(...) usage is unaffected.
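As a quick sanity check (illustrative only; it assumes the tokenizer is loaded from this PR branch), applying the template to query / document messages directly reproduces that string:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("zeroentropy/zerank-2", revision="refs/pr/8")
messages = [
    {"role": "query", "content": "What is 2+2?"},
    {"role": "document", "content": "4"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# <|im_start|>system
# What is 2+2?<|im_end|>
# <|im_start|>user
# 4<|im_end|>
# <|im_start|>assistant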

CrossEncoder("zeroentropy/zerank-2").predict(...) loads in bf16 automatically (picked up from config.json) and matches the reference baseline numbers for the README example: [5.40625, -4.5] raw, [0.7461, 0.2891] after (scores / 5).sigmoid().

Added files:

  • modules.json: the Transformer -> LogitScore pipeline using ST's built-in classes (sentence_transformers.base.modules.transformer.Transformer and sentence_transformers.cross_encoder.modules.logit_score.LogitScore).
  • sentence_bert_config.json: declares transformer_task: text-generation, the text / message modality config (with format: flat), module_output_name: causal_logits, and processing_kwargs.chat_template.add_generation_prompt: true.
  • config_sentence_transformers.json: model_type: CrossEncoder with activation_fn: torch.nn.modules.linear.Identity and prompts: {}.
  • 1_LogitScore/config.json: stores true_token_id: 9454 (the "Yes" token), with false_token_id left null so the score is the raw "Yes" logit rather than a "Yes"-vs-"No" log-odds.
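As an aside, a quick way to double-check the true_token_id above (assuming the vocabulary entry is the bare string "Yes" rather than a leading-space variant):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("zeroentropy/zerank-2")
print(tokenizer.convert_tokens_to_ids("Yes"))
# 9454, i.e. the value stored as true_token_id in 1_LogitScore/config.json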

Modified files:

  • chat_template.jinja: prepended a query / document branch as described above; the original Qwen3 ChatML logic is preserved unchanged below it.
  • README.md: Expanded the "How to Use" section with a pip install, expected predict output, the sigmoid(score / 5) post-processing one-liner, and a short model.rank example.
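For reference, the README's new model.rank example is along these lines (paraphrased here, not a verbatim copy of the README):

from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2", revision="refs/pr/8")

query = "What is 2+2?"
documents = ["4", "The answer is definitely 1 million"]

ranking = model.rank(query, documents)
print(ranking)
# e.g. [{'corpus_id': 0, 'score': 5.40625}, {'corpus_id': 1, 'score': -4.5}], sorted by raw "Yes" logit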

The main changes

from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2", revision="refs/pr/8")

query_documents = [
    ("What is 2+2?", "4"),
    ("What is 2+2?", "The answer is definitely 1 million"),
]

scores = model.predict(query_documents, convert_to_tensor=True)
print(scores)
# tensor([ 5.4062, -4.5000], device='cuda:0', dtype=torch.bfloat16)

probabilities = (scores / 5).sigmoid()
print(probabilities)
# tensor([0.7461, 0.2891], device='cuda:0', dtype=torch.bfloat16)

Note that revision="refs/pr/8" means you can test the PR branch right away, without having to clone anything.

A few small deltas vs. the original predict() to be aware of (rankings are unaffected in all of them):

  • Padding side: Sentence Transformers forces padding_side="left" for causal LMs, while the original modeling_zeranker.py uses padding_side="right". This slightly shifts the RoPE position IDs of padded items in a batch, drifting per-pair scores by up to ~0.01 on the shortest item in a mixed-length batch; the longest item in any batch is bit-identical to the original. A custom Transformer subclass that flips back to right-padding can match exactly, at the cost of needing trust_remote_code=True (see the sketch after this list).
  • Sigmoid lives client-side: the original predict() returns sigmoid(yes_logit / 5) directly, whereas here the caller applies (scores / 5).sigmoid(). This scaled sigmoid can't be expressed with the built-in activation functions without trust_remote_code, but it could also be made to output 0...1 out of the box if we go the trust_remote_code=True route.
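For completeness, a rough, untested sketch of the right-padding workaround mentioned above. It assumes the v5.4 Transformer module exposes its tokenizer as self.tokenizer, and using it would require trust_remote_code=True plus a pointer from modules.json:

from sentence_transformers.base.modules.transformer import Transformer

class RightPaddedTransformer(Transformer):
    """Hypothetical variant that restores the original right-padding behaviour."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assumption: the underlying tokenizer is reachable as self.tokenizer.
        self.tokenizer.padding_side = "right"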

I kept the existing modeling_zeranker.py, although it's no longer used by this integration. Feel free to remove it as well.

Please let me know if you have any questions or feedback!

  • Tom Aarsen
tomaarsen changed pull request status to open