SFT-Emb-8B


This repository provides the inference implementation for SFT-Emb, a supervised fine-tuned embedding model serving as a baseline retriever in the MiA-RAG framework.

Unlike MiA-Emb, which conditions on both the query and a global summary (Mindscape), SFT-Emb operates on the query alone, without any global summary or residual connection. This makes it a standard retrieval baseline that does not leverage document-level semantic scaffolding.


✨ Key Features

  • Standard Query-Only Retrieval
    Encodes queries without any global summary, serving as a strong SFT baseline for comparison with Mindscape-aware models.

  • Dual-Granularity Retrieval

    • Chunk Retrieval for narrative passages (standard RAG)
    • Node Retrieval for knowledge graph entities (GraphRAG-style)

  • Same Architecture, Simpler Input
    Built on the same Qwen3-Embedding-8B backbone and LoRA fine-tuning as MiA-Emb, but without the Mindscape summary injection or residual embedding mechanism.


πŸš€ Usage

Installation

pip install torch "transformers>=4.53.0"

1) Initialization

SFT-Emb-8B is initialized from Qwen3-Embedding-8B.

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inference Parameters
node_delimiter = "<|repo_name|>"  # Special token for Node tasks

# Load Tokenizer (base)
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    trust_remote_code=True,
    padding_side="left"
)

# Load Model
model = AutoModel.from_pretrained(
    "MindscapeRAG/SFT-Emb-8B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": 0}
)

2) Chunk Retrieval

Use this mode to retrieve narrative text chunks. The query is encoded without any global summary.

def get_query_prompt(query):
    """Construct input prompt (query-only, no summary)."""
    task_desc = "Given a search query, retrieve relevant chunks or helpful entities summaries from the given context that answer the query"
    return (
        f"Instruct: {task_desc}\n"
        f"Query: {query}{node_delimiter}"
    )

def last_token_pool(last_hidden_states, attention_mask):
    """Extract the last non-padding token embedding."""
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

def encode_chunk(texts):
    batch = tokenizer(
        texts,
        max_length=4096,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    # Disable gradient tracking for inference
    with torch.no_grad():
        outputs = model(**batch)

    # Embedding (Last Token)
    emb = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
    emb = F.normalize(emb, p=2, dim=-1)
    return emb


# --- Example ---
query = "Who is the protagonist?"
chunk = "Harry looked at the scar on his forehead."

# Encode
q_emb = encode_chunk([get_query_prompt(query)])
c_emb = encode_chunk([chunk])

# Score
score = q_emb @ c_emb.T
print(f"Chunk Similarity: {score.item():.4f}")
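Since the embeddings are L2-normalized, the dot product equals cosine similarity, so ranking a set of candidate chunks reduces to a matrix product plus top-k. A minimal sketch with stand-in random embeddings (rank_candidates is a hypothetical helper; in practice the inputs would come from encode_chunk):

```python
import torch
import torch.nn.functional as F

def rank_candidates(q_emb, cand_embs, top_k=3):
    """Rank candidates by cosine similarity to the query.

    Both sides are assumed L2-normalized, so the dot product
    is the cosine similarity.
    """
    scores = (q_emb @ cand_embs.T).squeeze(0)  # [num_candidates]
    k = min(top_k, cand_embs.shape[0])
    top = torch.topk(scores, k=k)              # sorted descending
    return top.indices.tolist(), top.values.tolist()

# Stand-in embeddings (in practice: outputs of encode_chunk)
torch.manual_seed(0)
q = F.normalize(torch.randn(1, 8), dim=-1)
cands = F.normalize(torch.randn(5, 8), dim=-1)

idx, vals = rank_candidates(q, cands, top_k=3)
print(idx, vals)
```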

3) Node Retrieval

SFT-Emb can retrieve knowledge graph entities (Nodes). This mode extracts embeddings from the <|repo_name|> token position.

Candidate format: Entity Name : Entity Description

Example: Mary Campbell Smith : Mary Campbell Smith is mentioned as the translator...
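Building the candidate string is just a "name : description" join; a small sketch (format_entity is a hypothetical helper, not part of the repository):

```python
def format_entity(name, description):
    """Join an entity name and description into the
    'Entity Name : Entity Description' candidate format."""
    return f"{name} : {description}"

candidate = format_entity(
    "Mary Campbell Smith",
    "Mary Campbell Smith is mentioned as the translator...",
)
print(candidate)
```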

def extract_specific_token(outputs, batch, token_id):
    """Extract embedding at the position of a specific token."""
    input_ids = batch["input_ids"]
    hidden = outputs.last_hidden_state
    mask = (input_ids == token_id)
    # Take the last occurrence of the token for each sample
    positions = mask.long().cumsum(dim=1).eq(mask.long().sum(dim=1, keepdim=True)) & mask
    return hidden[positions]

def encode_node_query(texts, node_delimiter="<|repo_name|>"):
    batch = tokenizer(texts, padding=True, return_tensors="pt").to(model.device)
    # Disable gradient tracking for inference
    with torch.no_grad():
        outputs = model(**batch)

    # Node Main Embedding: extract from <|repo_name|> position
    node_id = tokenizer.encode(node_delimiter, add_special_tokens=False)[0]
    q_emb_node = extract_specific_token(outputs, batch, node_id)
    q_emb_node = F.normalize(q_emb_node, p=2, dim=-1)
    return q_emb_node


# --- Example ---
query = "Who is the protagonist?"

# 1) Encode Query (Node Token)
q_emb_node = encode_node_query([get_query_prompt(query)])

# 2) Encode Entity Candidate
entity_text = "Harry Potter : The main protagonist of the series..."
n_emb = encode_chunk([entity_text])

# 3) Score
score = q_emb_node @ n_emb.T
print(f"Node Similarity: {score.item():.4f}")
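In a dual-granularity setup, chunk and node hits can be merged into one ranked context list. The fusion rule below (score-descending union with per-text deduplication) is an illustrative assumption, not the method from the paper:

```python
def merge_results(chunk_hits, node_hits, top_k=4):
    """Merge (text, score) hits from chunk and node retrieval,
    keeping the highest score per unique text."""
    best = {}
    for text, score in chunk_hits + node_hits:
        if text not in best or score > best[text]:
            best[text] = score
    # Sort the deduplicated pool by score, descending
    ranked = sorted(best.items(), key=lambda kv: kv[1], reverse=True)
    return ranked[:top_k]

# Toy scores standing in for the similarity values computed above
chunk_hits = [("Harry looked at the scar on his forehead.", 0.62)]
node_hits = [("Harry Potter : The main protagonist of the series...", 0.71)]

merged = merge_results(chunk_hits, node_hits)
print(merged)
```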

πŸ“œ Citation

If you find this work useful, please cite:

@misc{li2025mindscapeawareretrievalaugmentedgeneration,
      title={Mindscape-Aware Retrieval Augmented Generation for Improved Long Context Understanding}, 
      author={Yuqing Li and Jiangnan Li and Zheng Lin and Ziyan Zhou and Junjie Wu and Weiping Wang and Jie Zhou and Mo Yu},
      year={2025},
      eprint={2512.17220},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2512.17220}, 
}
