from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
import torch.nn as nn


class CustomLEDForQAModel(LEDPreTrainedModel):
    config_class = LEDConfig

    def __init__(self, config: LEDConfig, checkpoint):
        super().__init__(config)
        # Two labels: one logit per token for the start position and one for the end position.
        config.num_labels = 2
        self.num_labels = config.num_labels

        # Use only the LED encoder, either initialised from a pretrained checkpoint or from scratch.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()

        # Linear head that maps each token's hidden state to start/end logits.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, attention_mask=None, global_attention_mask=None,
                start_positions=None, end_positions=None):
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )

        # Project each token's hidden state to two logits, then split into start and end scores.
        logits = self.qa_outputs(outputs.last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()

            # If the positions arrive with an extra dimension (e.g. shape [batch, 1]), flatten them.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Average the cross-entropy losses of the start and end predictions.
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {
            'loss': total_loss,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }
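

# A minimal usage sketch for the class above, assuming the "allenai/led-base-16384"
# checkpoint and tokenizer are available; the question/context strings and the
# first-token global attention are illustrative choices, not requirements.
import torch
from transformers import AutoTokenizer

if __name__ == "__main__":
    checkpoint = "allenai/led-base-16384"  # assumed checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    config = LEDConfig.from_pretrained(checkpoint)
    model = CustomLEDForQAModel(config, checkpoint)
    model.eval()

    question = "What does LED stand for?"
    context = "LED (Longformer Encoder-Decoder) is a transformer model for long documents."
    inputs = tokenizer(question, context, return_tensors="pt")

    # LED uses a global attention mask; giving the first token global attention
    # is a common convention.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
        )

    # Greedy span decoding: take the argmax of the start and end logits and
    # decode the tokens between them.
    start_idx = int(outputs["start_logits"].argmax(dim=-1))
    end_idx = int(outputs["end_logits"].argmax(dim=-1))
    answer_ids = inputs["input_ids"][0, start_idx:end_idx + 1]
    print(tokenizer.decode(answer_ids, skip_special_tokens=True))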