| import torch |
| from torch import nn |
| from tqdm import tqdm |
| from torch.nn import functional as F |
| from transformers import ( |
| set_seed, pipeline, AutoTokenizer, AutoModelForCausalLM |
| ) |
|
|
# System prompt used when a conversation is built purely for embedding
# extraction (see SearchR3.format): it frames the task as producing a
# semantic vector for the user text.
EMBEDDING = """
You are a helpful AI assistant. Your task is to analyze input text and create a high-quality semantic vector embedding, which represents key concepts, relationships, and semantic meaning.
"""
# System prompt used by SearchR3.generate: instructs the model to enrich the
# user input with related concepts/terms and to terminate the reply with the
# literal marker <|embed_token|>, which encode() later locates to read out
# the embedding hidden state.
GENERATION = """
You are a helpful AI assistant. Your task is to enrich user input for more effective embedding representation by adding semantic depth.

For each input, briefly enhance the content by:
1. Identifying core concepts and their relationships.
2. Including key terminology with essential definitions.
3. Adding contextually relevant synonyms and related terms.
4. Connecting to related topics and common applications without excessive elaboration.

To represent the final embedding, you MUST end every response with <|embed_token|>.
"""
|
|
|
|
class SearchR3(nn.Module):
    """Causal-LM wrapper with two roles:

    * :meth:`generate` — enrich raw query strings with semantic context
      (driven by the ``GENERATION`` system prompt), returning full chat
      transcripts whose assistant turn ends in ``<|embed_token|>``.
    * :meth:`encode` — run the model over such transcripts and read the
      last-layer hidden state at the ``<|embed_token|>`` position as an
      L2-normalized embedding vector.
    """

    def __init__(self,
                 path: str,
                 max_length: int,
                 batch_size: int):
        """Load the causal LM and its tokenizer.

        Args:
            path: Hugging Face model id or local checkpoint path.
            max_length: total token budget (prompt + generated) per sample.
            batch_size: maximum samples per forward pass; larger inputs are
                chunked recursively by generate()/encode().
        """
        nn.Module.__init__(self)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype='auto', device_map='auto'
        )
        # Left-side truncation and padding keep the *end* of each sequence
        # intact — that is where <|embed_token|> must survive for encode().
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, truncation_side='left', padding_side='left'
        )
        # Token id of the embedding marker.
        # NOTE(review): takes element [0] of the encoded text — assumes the
        # tokenizer prepends no special token (e.g. BOS) and that the marker
        # encodes to a single id; confirm for this checkpoint's tokenizer.
        self.embed_token = self.tokenizer.encode('<|embed_token|>')[0]
        self.max_length = max_length
        self.batch_size = batch_size

    @property
    def device(self):
        """Device of the model parameters (taken from the first parameter)."""
        return next(self.model.parameters()).device

    @torch.no_grad()
    def generate(self, batch: list[str]) -> list[list[dict[str, str]]]:
        """Enrich each input string via the GENERATION system prompt.

        Args:
            batch: list (or tuple) of raw query strings.

        Returns:
            One chat transcript per input: ``[system, user, assistant]``
            message dicts, where the assistant content is the model's
            enrichment (expected to end with ``<|embed_token|>``).

        Raises:
            ValueError: if ``batch`` is not a list/tuple of strings.
        """
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
        if any(not isinstance(v, str) for v in batch):
            raise ValueError('batch item type is incorrect')

        # Oversized batches: split into batch_size chunks and recurse,
        # with tqdm reporting per-chunk progress.
        if len(batch) > self.batch_size:
            outputs = []
            for i in tqdm(
                range(0, len(batch), self.batch_size)
            ):
                outputs.extend(
                    self.generate(
                        batch[i:i + self.batch_size]
                    )
                )
            return outputs

        # Build one chat prompt per input and render it with the model's
        # chat template, leaving the assistant turn open for generation.
        messages = [
            [
                {'role': 'system', 'content': GENERATION.strip()},
                {'role': 'user', 'content': item}
            ]
            for item in batch
        ]
        context = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Cap the prompt at half the budget so at least max_length/2 tokens
        # remain for the generated continuation below.
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length // 2
        )
        # With padding_side='left' every row shares this padded width, so
        # slicing [:, prompt_length:] later isolates the new tokens.
        prompt_length = inputs['input_ids'].size(-1)

        # Generate up to the remaining token budget.
        self.model.eval()
        outputs = self.model.generate(
            **inputs.to(device=self.device),
            max_new_tokens=self.max_length - prompt_length
        )
        # Keep special tokens in the decoded text so <|embed_token|> is
        # preserved; other specials are stripped manually below.
        outputs = self.tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=False
        )

        # Remove every special token except the embed marker, then append
        # the cleaned reply as the assistant turn of each transcript.
        for special_token in self.tokenizer.all_special_tokens:
            if special_token == '<|embed_token|>':
                continue
            outputs = [
                item.replace(special_token, '') for item in outputs
            ]
        messages = [
            item + [
                {'role': 'assistant', 'content': outputs[i].strip()}
            ]
            for i, item in enumerate(messages)
        ]
        return messages

    def format(self, batch: list[str]) -> list[list[dict[str, str]]]:
        """Wrap raw strings into the canonical embedding transcript.

        The fixed assistant turn ends with ``<|embed_token|>`` so encode()
        can locate the embedding position.

        Raises:
            RuntimeError: if any item is not a string.
                NOTE(review): generate() raises ValueError for the same
                condition — inconsistent, but callers may rely on either.
        """
        if any(not isinstance(v, str) for v in batch):
            raise RuntimeError('batch type is incorrect')
        return [
            [
                {'role': 'system', 'content': EMBEDDING.strip()},
                {'role': 'user', 'content': item},
                {'role': 'assistant', 'content': 'The embedding is: <|embed_token|>'}
            ]
            for item in batch
        ]

    @torch.no_grad()
    def encode(self, batch: list[str] | list[list[dict[str, str]]]) -> torch.Tensor:
        """Embed a batch of raw strings or chat transcripts.

        Args:
            batch: either raw strings (wrapped via format()) or transcripts
                like those produced by generate()/format(), ending with an
                assistant turn that contains ``<|embed_token|>``.

        Returns:
            Float32 CPU tensor of L2-normalized embeddings, one row per
            matched embed-token position.

        Raises:
            ValueError: if ``batch`` is not a list/tuple.
            RuntimeError: if transcript roles are malformed or the embed
                token cannot be established.
        """
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')

        # Oversized batches: encode chunk by chunk, concatenate the rows.
        if len(batch) > self.batch_size:
            outputs = [
                self.encode(
                    batch[i:i + self.batch_size]
                )
                for i in tqdm(
                    range(0, len(batch), self.batch_size)
                )
            ]
            return torch.cat(outputs, dim=0)

        # Raw strings -> canonical embedding transcripts.
        if all(isinstance(v, str) for v in batch):
            batch = self.format(batch=batch)

        # Transcripts must end [..., user, assistant].
        if any(
            m[-1]['role'] != 'assistant' for m in batch
        ):
            raise RuntimeError('unexpected role')
        if any(
            m[-2]['role'] != 'user' for m in batch
        ):
            raise RuntimeError('unexpected role')

        # Any transcript whose assistant turn lacks the embed marker is
        # rebuilt from its user content via format(); if the marker is still
        # missing afterwards, fail loudly.
        batch = [
            m if '<|embed_token|>' in m[-1]['content']
            else self.format([m[-2]['content']])[0]
            for m in batch
        ]
        if any(
            '<|embed_token|>' not in m[-1]['content'] for m in batch
        ):
            raise RuntimeError('unexpected embed token')

        # Render the complete (closed) conversation; no generation prompt.
        context = self.tokenizer.apply_chat_template(
            batch, tokenize=False, add_generation_prompt=False
        )
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length
        )

        # Single forward pass; we only need the last layer's hidden states.
        self.model.eval()
        outputs = self.model(
            **inputs.to(device=self.device),
            return_dict=True, output_hidden_states=True
        )
        hidden_state = outputs['hidden_states'][-1]

        # Select positions that are (a) within the last 4 slots of the
        # padded sequence (index > length - 5) and (b) equal to the embed
        # token id. Left padding/truncation keeps the marker near the end.
        # NOTE(review): assumes exactly one embed token per row lands in
        # that window — zero or multiple matches would change the number of
        # returned rows relative to the batch size; confirm with the chat
        # template's trailing tokens.
        length = inputs['input_ids'].size(-1)
        valid_mask = torch.arange(length, device=self.device)
        valid_mask = torch.where(
            valid_mask.unsqueeze(0) > length - 5, True, False
        )
        embed_mask = torch.where(
            inputs['input_ids'] == self.embed_token, True, False
        )
        embed_mask = embed_mask.logical_and(valid_mask)
        # Gather the hidden vectors at the matched positions and return
        # unit-norm float32 embeddings on CPU.
        return F.normalize(
            hidden_state[embed_mask].cpu().float(), dim=-1
        )
|
|
|
|
def main():
    """Demo driver: plain chat generation, query enrichment, then
    embedding-based retrieval by L2 distance."""
    from pprint import pprint

    set_seed(42)

    # 1) Ordinary chat generation through the high-level pipeline API.
    chat = pipeline(
        task='text-generation',
        model='ytgui/Search-R3.0-Small',
        torch_dtype='auto', device_map='auto'
    )
    chat_input = [{'role': 'user', 'content': 'Who are you?'}]
    pprint(chat(chat_input, max_new_tokens=256))

    # 2) Enrich a raw query into a reasoning transcript.
    searcher = SearchR3(
        'ytgui/Search-R3.0-Small', max_length=1024, batch_size=8
    )
    enriched = searcher.generate(
        batch=['what python library is useful for data analysis?']
    )
    pprint(enriched)

    # 3) Embed corpus and enriched query; smaller distance = better match
    # (the pandas-library document should beat the panda-bear one).
    corpus = [
        'pandas is a fast, powerful, flexible and easy to use open source data analysis and manipulation tool, built on top of the Python programming language.',
        'The giant panda (Ailuropoda melanoleuca), also known as the panda bear or simply panda, is a bear species endemic to China. It is characterised by its white coat with black patches around the eyes, ears, legs and shoulders.',
    ]
    doc_vectors = searcher.encode(batch=corpus)
    query_vectors = searcher.encode(batch=enriched)
    print('distance:', torch.cdist(query_vectors, doc_vectors, p=2.0))
|
|
|
# Script entry point: run the demo only when executed directly.
if __name__ == '__main__':
    main()
|
|