| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import BertModel |
| | from transformers.models.clip.modeling_clip import CLIPTextModel |
| | from transformers.models.mpnet.modeling_mpnet import MPNetModel |
| | from transformers.trainer import logger |
| |
|
| | from .align_transformers import build_align_transformer |
| | from .common_layers import BasePreTrainedModel |
| | from .configuration_radzero import CxrAlignConfig |
| | from .losses import KeyPhraseAlignmentLoss |
| | from .text_encoders import build_text_encoder |
| | from .vision_encoders import Dinov2Model, build_vision_encoder |
| |
|
| |
|
class CxrAlignModel(BasePreTrainedModel):
    """Vision/text alignment model.

    Combines a vision encoder, a text encoder (with an optional linear text
    projector) and an align transformer applied on top of the vision tokens.
    Losses are built from ``config.kwargs["loss"]``; :meth:`forward` and
    :meth:`compute_logits` currently only support ``KeyPhraseAlignmentLoss``.
    The ``build_*`` methods can be overridden to customise sub-module creation.
    """

    config_class = CxrAlignConfig

    # Text encoders whose ``config.hidden_size`` can be used as the input width
    # of the optional text projector.
    _SUPPORTED_TEXT_MODELS = (CLIPTextModel, MPNetModel, BertModel)

    # Loss names accepted in ``config.kwargs["loss"]["apply"]``. An explicit
    # registry replaces ``eval`` on config strings; subclasses may extend it.
    _LOSS_CLASSES = {"KeyPhraseAlignmentLoss": KeyPhraseAlignmentLoss}

    def build_vision_model(self, config: CxrAlignConfig):
        """Build the vision encoder from ``config.vision_config``.

        Note: this copies ``config.pretrained_dir`` onto ``config.vision_config``
        (the vision config object is modified in place).
        """
        vision_config = config.vision_config
        vision_config.pretrained_dir = config.pretrained_dir
        return build_vision_encoder(vision_config)

    def build_text_model(self, config: CxrAlignConfig):
        """Build the text encoder from ``config.text_config``."""
        return build_text_encoder(config.text_config)

    def build_align_transformer_model(self, config: CxrAlignConfig):
        """Build the align transformer from ``config.align_transformer_config``."""
        return build_align_transformer(config.align_transformer_config)

    def __init__(self, config: CxrAlignConfig):
        super().__init__(config)

        logger.info("Build vision model ...")
        self.vision_model = self.build_vision_model(config)

        logger.info("Build text model ...")
        self.text_model = self.build_text_model(config)

        self.hidden_size = config.align_transformer_config.hidden_size

        # Text tokens are projected to 2 * hidden_size, presumably to match the
        # width of ``image_features`` (CLS token concatenated with the mean
        # patch token in ``forward_vision_model``) — TODO confirm.
        if config.text_config.use_text_projection:
            self.text_projector = nn.Linear(
                self._text_hidden_size(), 2 * self.hidden_size
            )
        else:
            self.text_projector = None

        logger.info("Build align transformer model ...")
        self.align_transformer = self.build_align_transformer_model(config)

        logger.info("Build loss functions ...")
        self.loss_ratio = dict()
        self.loss_fns = nn.ModuleDict()
        self._build_losses(config.kwargs["loss"])

        self.compute_logits_type = config.kwargs.get("compute_logits_type")
        self.use_negative_logits = config.kwargs.get("use_negative_logits")
        self.module_to_update = config.kwargs.get("module_to_update")

    def _text_hidden_size(self) -> int:
        """Return the hidden width of the text encoder's token embeddings.

        Raises:
            NotImplementedError: if the text encoder type is not supported.
        """
        if not isinstance(self.text_model, self._SUPPORTED_TEXT_MODELS):
            raise NotImplementedError(
                "Text projection is not supported for "
                f"{type(self.text_model).__name__}"
            )
        return self.text_model.config.hidden_size

    def _build_losses(self, loss_cfg) -> None:
        """Instantiate each loss in ``loss_cfg["apply"]`` and record its ratio.

        ``loss_cfg[<loss name>]`` (optional, may be ``None``) holds the keyword
        arguments for that loss. When torch.distributed is initialised, the
        current ``rank`` and ``world_size`` are added to those arguments.

        Raises:
            ValueError: if a loss name is not in ``_LOSS_CLASSES``.
        """
        loss_types = loss_cfg["apply"]
        ratios = loss_cfg["ratio"]
        if len(loss_types) != len(ratios):
            # zip() below silently drops the unmatched tail; make it visible.
            logger.warning(
                f"loss.apply has {len(loss_types)} entries but loss.ratio has "
                f"{len(ratios)}; unmatched entries are ignored."
            )

        distributed = (
            torch.distributed.is_available() and torch.distributed.is_initialized()
        )
        for loss_type, ratio in zip(loss_types, ratios):
            logger.info(f"Build {loss_type} loss function ...")
            loss_cls = self._LOSS_CLASSES.get(loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported loss type {loss_type!r}; expected one of "
                    f"{sorted(self._LOSS_CLASSES)}"
                )
            # Work on a copy so the shared config is not mutated; writing the
            # process-specific rank/world_size back into config.kwargs could
            # leak them into a saved config.
            loss_kwargs = dict(loss_cfg.get(loss_type) or {})
            if distributed:
                loss_kwargs["rank"] = torch.distributed.get_rank()
                loss_kwargs["world_size"] = torch.distributed.get_world_size()
            self.loss_fns[loss_type] = loss_cls(**loss_kwargs)
            self.loss_ratio[loss_type] = ratio

    def forward_vision_model(self, pixel_values):
        """Encode images and pool them into a single feature vector.

        Returns:
            dict with
              - ``vision_tokens``: align-transformer output (index 0 on dim 1
                is treated as the CLS token),
              - ``image_cls_token``: ``vision_tokens[:, 0]``,
              - ``image_patch_tokens``: ``vision_tokens[:, 1:]``,
              - ``image_features``: L2-normalised (dim 1) concatenation of the
                CLS token and the mean of the patch tokens.

        Raises:
            NotImplementedError: if the vision encoder is not a ``Dinov2Model``.
        """
        if not isinstance(self.vision_model, Dinov2Model):
            raise NotImplementedError(
                f"Unsupported vision model {type(self.vision_model).__name__}"
            )
        # Presumably (batch, 1 + num_patches, hidden) — TODO confirm.
        encoder_tokens = self.vision_model(pixel_values)["last_hidden_state"]
        vision_tokens = self.align_transformer(encoder_tokens)

        cls_token = vision_tokens[:, 0]
        patch_tokens = vision_tokens[:, 1:]
        image_features = F.normalize(
            torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1), p=2, dim=1
        )

        return {
            "vision_tokens": vision_tokens,
            "image_cls_token": cls_token,
            "image_patch_tokens": patch_tokens,
            "image_features": image_features,
        }

    def forward_text_model(self, encoded_input):
        """Encode tokenized text into pooled text features.

        Args:
            encoded_input: mapping with ``input_ids`` and ``attention_mask``.

        Returns:
            dict with ``text_features_wo_l2_norm`` (pooled features) and
            ``text_features`` (the same features L2-normalised along dim 1).
            Pooling uses the first token when
            ``config.text_config.use_cls_token`` is set, otherwise an
            attention-mask-weighted mean over tokens.

        Raises:
            NotImplementedError: if the text encoder is not an ``MPNetModel``.
        """
        if not isinstance(self.text_model, MPNetModel):
            raise NotImplementedError(
                f"Unsupported text model {type(self.text_model).__name__}"
            )

        attention_mask = encoded_input["attention_mask"]
        model_output = self.text_model(
            input_ids=encoded_input["input_ids"],
            attention_mask=attention_mask,
        )
        # The first output element holds the per-token embeddings.
        token_embeddings = model_output[0]
        if self.text_projector is not None:
            token_embeddings = self.text_projector(token_embeddings)

        if self.config.text_config.use_cls_token:
            text_features = token_embeddings[:, 0, :]
        else:
            # Masked mean pooling; the clamp avoids division by zero for rows
            # whose attention mask is all zeros.
            mask = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            text_features = torch.sum(token_embeddings * mask, 1) / torch.clamp(
                mask.sum(1), min=1e-9
            )

        return {
            "text_features_wo_l2_norm": text_features,
            "text_features": F.normalize(text_features, p=2, dim=1),
        }

    def forward(
        self,
        pixel_values,
        encoded_key_phrases=None,
        return_loss=True,
        **kwargs,
    ):
        """Run the vision branch and, optionally, the weighted training losses.

        Returns:
            The dict from :meth:`forward_vision_model`. When ``return_loss`` is
            true it also contains ``losses``: every individual loss reported by
            the loss functions plus their ratio-weighted sum under ``"loss"``.

        Raises:
            NotImplementedError: for a configured loss that is not a
                ``KeyPhraseAlignmentLoss``.
        """
        outputs = dict(self.forward_vision_model(pixel_values))

        if return_loss:
            total_loss = 0
            losses = {}
            for loss_type, loss_fn in self.loss_fns.items():
                if not isinstance(loss_fn, KeyPhraseAlignmentLoss):
                    raise NotImplementedError(
                        f"Unsupported loss function {type(loss_fn).__name__}"
                    )
                loss_outputs = loss_fn(
                    encoded_key_phrases,
                    outputs["vision_tokens"],
                    self.forward_text_model,
                )
                kp_losses = loss_outputs["losses"]
                # The loss's own "loss" entry is renamed so it does not clash
                # with the overall weighted total; other entries are kept as-is.
                losses["key_phrase_alignment_loss"] = kp_losses.pop("loss")
                losses.update(kp_losses)
                loop_loss = losses["key_phrase_alignment_loss"]
                total_loss += loop_loss * self.loss_ratio[loss_type]

            losses["loss"] = total_loss
            outputs["losses"] = losses

        return outputs

    def compute_logits(
        self,
        pixel_values,
        encoded_key_phrases,
        **kwargs,
    ):
        """Compute image/key-phrase logits for inference.

        Only ``compute_logits_type == "key_phrase_alignment"`` is supported.
        Each row of ``encoded_key_phrases[0]`` is passed to the key-phrase
        alignment loss as its own single-row group.

        Returns:
            dict with the loss-function outputs plus
              - ``similarity_scores``: mean of the stacked
                ``t2i_attn_weights`` over the stacking dimension, with the
                first position of the last axis dropped when the loss uses a
                vision CLS token,
              - ``logits``: ``t2i_logits`` transposed and divided by
                ``exp(loss_temperature)``.

        Raises:
            NotImplementedError: for any other ``compute_logits_type``.
        """
        if self.compute_logits_type != "key_phrase_alignment":
            raise NotImplementedError(
                f"Unsupported compute_logits_type: {self.compute_logits_type!r}"
            )

        vision_outputs = self.forward_vision_model(pixel_values)
        kp_loss = self.loss_fns["KeyPhraseAlignmentLoss"]

        phrases = encoded_key_phrases[0]
        splited_key_phrases = [
            {
                "input_ids": phrases["input_ids"][i : i + 1],
                "attention_mask": phrases["attention_mask"][i : i + 1],
            }
            for i in range(phrases["input_ids"].size(0))
        ]

        loss_outputs = kp_loss(
            splited_key_phrases,
            vision_outputs["vision_tokens"],
            self.forward_text_model,
            ddp_gather=False,
            need_attn_weights=True,
            compute_loss=False,
        )
        outputs = dict(loss_outputs)

        similarity_scores = torch.mean(
            torch.stack(loss_outputs["t2i_attn_weights"]), dim=0
        )
        if kp_loss.use_vision_cls_token:
            # Drop the CLS position so scores cover patch tokens only.
            similarity_scores = similarity_scores[:, :, 1:]
        outputs["similarity_scores"] = similarity_scores

        # Transposed, presumably to put images on the first axis — TODO confirm.
        logits = loss_outputs["t2i_logits"].T
        outputs["logits"] = logits / kp_loss.loss_temperature.exp()
        return outputs
| |
|