SFT-Emb-8B
This repository provides the inference implementation for SFT-Emb, a supervised fine-tuned embedding model serving as a baseline retriever in the MiA-RAG framework.
Unlike MiA-Emb, which conditions on both the query and a global summary (Mindscape), SFT-Emb operates on the query alone, without any global summary or residual connection. This makes it a standard retrieval baseline that does not leverage document-level semantic scaffolding.
Key Features
Standard Query-Only Retrieval
Encodes queries without any global summary, serving as a strong SFT baseline for comparison with Mindscape-aware models.
Dual-Granularity Retrieval
- Chunk Retrieval for narrative passages (standard RAG)
- Node Retrieval for knowledge graph entities (GraphRAG-style)
Same Architecture, Simpler Input
Built on the same Qwen3-Embedding-8B backbone and LoRA fine-tuning as MiA-Emb, but without the Mindscape summary injection or residual embedding mechanism.
Usage
Installation
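# Quote the version specifier so the shell does not treat ">=" as a redirect
pip install torch "transformers>=4.53.0"
# The flash-attn package is also needed for attn_implementation="flash_attention_2" used below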
1) Initialization
SFT-Emb-8B is initialized from Qwen3-Embedding-8B.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inference Parameters
node_delimiter = "<|repo_name|>" # Special token for Node tasks
# Load Tokenizer (base)
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    trust_remote_code=True,
    padding_side="left"
)
# Load Model
model = AutoModel.from_pretrained(
    "MindscapeRAG/SFT-Emb-8B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": 0}
)
2) Chunk Retrieval
Use this mode to retrieve narrative text chunks. The query is encoded without any global summary.
def get_query_prompt(query):
    """Construct input prompt (query-only, no summary)."""
    task_desc = "Given a search query, retrieve relevant chunks or helpful entities summaries from the given context that answer the query"
    return (
        f"Instruct: {task_desc}\n"
        f"Query: {query}{node_delimiter}"
    )
def last_token_pool(last_hidden_states, attention_mask):
    """Extract the last non-padding token embedding."""
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
@torch.no_grad()  # inference only; avoids storing activations for backprop
def encode_chunk(texts):
    batch = tokenizer(
        texts,
        max_length=4096,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model(**batch)
    # Embedding (Last Token)
    emb = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
    # L2-normalize so that the dot product equals cosine similarity
    emb = F.normalize(emb, p=2, dim=-1)
    return emb
# --- Example ---
query = "Who is the protagonist?"
chunk = "Harry looked at the scar on his forehead."
# Encode
q_emb = encode_chunk([get_query_prompt(query)])
c_emb = encode_chunk([chunk])
# Score
score = q_emb @ c_emb.T
print(f"Chunk Similarity: {score.item():.4f}")
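The padding branch in last_token_pool is easy to misread, so here is a minimal sketch that exercises it on toy tensors instead of real model outputs. The function body mirrors the one above; the toy tensors and the names hidden, left_mask, and right_mask are illustrative assumptions, not part of the repository.

```python
import torch

def last_token_pool(last_hidden_states, attention_mask):
    """Return the hidden state of the last non-padding token per sample."""
    # If every sample has a real token in the final position, the batch is left-padded
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    # Otherwise index each sample at (number of real tokens - 1)
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size), sequence_lengths]

# Toy hidden states: batch=2, seq_len=4, dim=1; each value equals its position index
hidden = torch.arange(4, dtype=torch.float32).view(1, 4, 1).repeat(2, 1, 1)

left_mask = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])   # padding on the left
right_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # padding on the right

pooled_left = last_token_pool(hidden, left_mask).squeeze(-1)
pooled_right = last_token_pool(hidden, right_mask).squeeze(-1)

print(pooled_left.tolist())   # position 3 for both samples
print(pooled_right.tolist())  # position 3 for the first sample, position 1 for the second
```

With the left-padded tokenizer configured in the initialization step, the first branch is the one taken in practice; the second branch only matters if the padding side is changed.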
3) Node Retrieval
SFT-Emb can retrieve knowledge graph entities (Nodes). This mode extracts embeddings from the <|repo_name|> token position.
Candidate format:
Entity Name : Entity Description
Example:
Mary Campbell Smith : Mary Campbell Smith is mentioned as the translator...
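A small helper can keep candidates in this format consistent. The sketch below assumes the separator is exactly " : " (space, colon, space), as in the example above; format_entity is a hypothetical name and is not part of the repository.

```python
def format_entity(name, description):
    """Build a node candidate string as 'Entity Name : Entity Description'."""
    # Strip stray whitespace so the separator is always ' : '
    return f"{name.strip()} : {description.strip()}"

# Hypothetical entities, for illustration only
entities = [
    ("Harry Potter", "The main protagonist of the series..."),
    (" Mary Campbell Smith ", "Mary Campbell Smith is mentioned as the translator..."),
]
candidates = [format_entity(name, desc) for name, desc in entities]

for text in candidates:
    print(text)
```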
def extract_specific_token(outputs, batch, token_id):
    """Extract embedding at the position of a specific token."""
    input_ids = batch["input_ids"]
    hidden = outputs.last_hidden_state
    mask = (input_ids == token_id)
    # Take the last occurrence of the token for each sample
    positions = mask.long().cumsum(dim=1).eq(mask.long().sum(dim=1, keepdim=True)) & mask
    return hidden[positions]
@torch.no_grad()  # inference only; avoids storing activations for backprop
def encode_node_query(texts, node_delimiter="<|repo_name|>"):
    batch = tokenizer(texts, padding=True, return_tensors="pt").to(model.device)
    outputs = model(**batch)
    # Node Main Embedding: extract from <|repo_name|> position
    node_id = tokenizer.encode(node_delimiter, add_special_tokens=False)[0]
    q_emb_node = extract_specific_token(outputs, batch, node_id)
    q_emb_node = F.normalize(q_emb_node, p=2, dim=-1)
    return q_emb_node
# --- Example ---
query = "Who is the protagonist?"
# 1) Encode Query (Node Token)
q_emb_node = encode_node_query([get_query_prompt(query)])
# 2) Encode Entity Candidate
entity_text = "Harry Potter : The main protagonist of the series..."
n_emb = encode_chunk([entity_text])
# 3) Score
score = q_emb_node @ n_emb.T
print(f"Node Similarity: {score.item():.4f}")
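The cumulative-sum mask used in extract_specific_token to select the last occurrence of the delimiter token can be checked on toy tensors. The sketch below uses a simplified signature (plain tensors instead of the outputs and batch objects) and a made-up token id of 9; both are assumptions for illustration, not the repository's API.

```python
import torch

def extract_last_occurrence(hidden, input_ids, token_id):
    """Return the hidden state at the last position where input_ids == token_id."""
    mask = input_ids == token_id
    running = mask.long().cumsum(dim=1)            # occurrences seen so far, per position
    total = mask.long().sum(dim=1, keepdim=True)   # occurrences in the whole sample
    # The last occurrence is where the running count reaches the total on a matching position
    positions = running.eq(total) & mask
    return hidden[positions]

TOKEN_ID = 9  # hypothetical id standing in for <|repo_name|>

input_ids = torch.tensor([
    [5, 9, 0, 0],  # one occurrence, at position 1
    [4, 9, 6, 9],  # two occurrences; the last is at position 3
])
# Toy hidden states: the value at (b, t) is 10*b + t, so positions are easy to read off
hidden = (10 * torch.arange(2).view(2, 1) + torch.arange(4).view(1, 4)).float().unsqueeze(-1)

picked = extract_last_occurrence(hidden, input_ids, TOKEN_ID).squeeze(-1)
print(picked.tolist())

# A sample that lacks the token contributes no row to the output
no_token = extract_last_occurrence(hidden[:1], torch.tensor([[5, 6, 7, 8]]), TOKEN_ID)
print(no_token.shape[0])
```

Because a sample without the delimiter produces no row, the output can have fewer rows than the batch. It is worth making sure each query prompt still ends with the delimiter, as get_query_prompt arranges.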
Citation
If you find this work useful, please cite:
@misc{li2025mindscapeawareretrievalaugmentedgeneration,
  title={Mindscape-Aware Retrieval Augmented Generation for Improved Long Context Understanding},
  author={Yuqing Li and Jiangnan Li and Zheng Lin and Ziyan Zhou and Junjie Wu and Weiping Wang and Jie Zhou and Mo Yu},
  year={2025},
  eprint={2512.17220},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2512.17220},
}