import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline


class TextGenerationPipeline(Pipeline):
    """Generation pipeline for ChatNT: an English question plus optional DNA
    sequences in, a decoded English answer out."""

    def __init__(self, model, **kwargs):
        super().__init__(model=model, **kwargs)

        # ChatNT uses two tokenizers: one for the English prompt and one for
        # the DNA sequences.
        model_name = "InstaDeepAI/ChatNT"
        self.english_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="english_tokenizer"
        )
        self.bio_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="bio_tokenizer"
        )

    def _sanitize_parameters(self, **kwargs: dict) -> tuple[dict, dict, dict]:
        # Route user kwargs to the preprocess / forward / postprocess stages.
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}

        if "max_num_tokens_to_decode" in kwargs:
            forward_kwargs["max_num_tokens_to_decode"] = kwargs[
                "max_num_tokens_to_decode"
            ]
        if "english_tokens_max_length" in kwargs:
            preprocess_kwargs["english_tokens_max_length"] = kwargs[
                "english_tokens_max_length"
            ]
        if "bio_tokens_max_length" in kwargs:
            preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        inputs: dict,
        english_tokens_max_length: int = 512,
        bio_tokens_max_length: int = 512,
    ) -> dict:
        english_sequence = inputs["english_sequence"]
        dna_sequences = inputs["dna_sequences"]

        # Wrap the user question in the conversation template expected by ChatNT.
        context = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
        space = "" if english_sequence.endswith(" ") else " "
        english_sequence = context + english_sequence + space + "ASSISTANT:"

        # Pad the English prompt to a fixed length; the pad positions are
        # filled in one by one during greedy decoding in _forward.
        english_tokens = self.english_tokenizer(
            english_sequence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=english_tokens_max_length,
        ).input_ids
        if len(dna_sequences) == 0:
            bio_tokens = None
        else:
            # Shape: (batch=1, num_sequences, bio_tokens_max_length).
            bio_tokens = self.bio_tokenizer(
                dna_sequences,
                return_tensors="pt",
                padding="max_length",
                max_length=bio_tokens_max_length,
                truncation=True,
            ).input_ids.unsqueeze(0)

        return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}

    def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
        english_tokens = model_inputs["english_tokens"].clone()
        bio_tokens = model_inputs["bio_tokens"]
        if bio_tokens is not None:
            bio_tokens = bio_tokens.clone()
        projected_bio_embeddings = None

        actual_num_steps = 0
        with torch.no_grad():
            for _ in range(max_num_tokens_to_decode):
                # Stop if the buffer is full: no pad slot left to write into.
                if (
                    self.english_tokenizer.pad_token_id
                    not in english_tokens[0].cpu().numpy()
                ):
                    break

                # Forward pass; the projected DNA embeddings are computed on
                # the first step and fed back in on later steps.
                outs = self.model(
                    multi_omics_tokens_ids=(english_tokens, bio_tokens),
                    projection_english_tokens_ids=english_tokens,
                    projected_bio_embeddings=projected_bio_embeddings,
                )
                projected_bio_embeddings = outs["projected_bio_embeddings"]
                logits = outs["logits"].detach().cpu().numpy()

                # Greedy decoding: the logits just before the first pad slot
                # predict the next token, which is written into that slot.
                first_idx_pad_token = np.where(
                    english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
                )[0][0]
                predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])

                # Stop at the end-of-sequence token.
                if predicted_token == self.english_tokenizer.eos_token_id:
                    break
                english_tokens[0, first_idx_pad_token] = predicted_token
                actual_num_steps += 1

        # Generation began at the first pad position of the original prompt.
        idx_begin_generation = np.where(
            model_inputs["english_tokens"][0].cpu()
            == self.english_tokenizer.pad_token_id
        )[0][0]

        generated_tokens = english_tokens[
            0, idx_begin_generation : idx_begin_generation + actual_num_steps
        ]

        return {"generated_tokens": generated_tokens}

    def postprocess(self, model_outputs: dict) -> str:
        # Decode the generated token ids back into an English answer.
        generated_tokens = model_outputs["generated_tokens"]
        generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
        return generated_sequence
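

# A minimal usage sketch, not part of the original module. It assumes the
# InstaDeepAI/ChatNT checkpoint registers this class as a custom pipeline, so
# the `pipeline` factory can load it with `trust_remote_code=True`; the
# question and DNA sequence below are illustrative placeholders, with <DNA>
# marking where a sequence is inserted into the prompt.
if __name__ == "__main__":
    from transformers import pipeline

    pipe = pipeline(model="InstaDeepAI/ChatNT", trust_remote_code=True)
    output = pipe(
        {
            "english_sequence": "Is there any evidence of a promoter in this sequence: <DNA> ?",
            "dna_sequences": ["ATCGGAAAAAGATCCAGAAAGTTATACCAGGCCAATGGGAATCACCTATTACG"],
        },
        max_num_tokens_to_decode=100,
    )
    print(output)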