| import torch |
| from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer |
| from transformers.tools import PipelineTool |
|
|
|
|
class TextPairClassificationTool(PipelineTool):
    """Tool that decides whether two English texts are paraphrases of each other.

    Wraps a sequence-classification checkpoint (BERT fine-tuned on MRPC by
    default): the tokenizer encodes the sentence pair, the model scores it, and
    ``decode`` turns the logits into a ``{"label": ..., "score": ...}`` dict.
    """

    # Checkpoint used when the caller does not supply a model explicitly.
    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    # "{labels}" is a placeholder filled in by post_init() once the model
    # config (and therefore the real label set) is known.
    description = (
        "classifies if two texts in English are similar or not using the labels {labels}. It takes two inputs named "
        "`text` and `second_text` which should be in English and returns a dictionary with two keys named 'label' "
        "(the predicted label ) and 'score' (the probability associated to it)."
    )

    def post_init(self):
        """Fill the ``{labels}`` placeholder in ``description`` from the model config.

        Raises:
            ValueError: if the config declares fewer than two labels.
        """
        # The model may still be an un-instantiated checkpoint name at this
        # point; load only its config to read the label mapping.
        if isinstance(self.model, str):
            config = AutoConfig.from_pretrained(self.model)
        else:
            config = self.model.config

        labels = list(config.label2id.keys())
        if len(labels) <= 1:
            raise ValueError("Not enough labels.")

        # Render the labels as an English list: "'a', 'b', and 'c'".
        quoted = [f"'{label}'" for label in labels]
        labels_string = f"{', '.join(quoted[:-1])}, and {quoted[-1]}"
        self.description = self.description.replace("{labels}", labels_string)

    def encode(self, text, second_text):
        """Tokenize the sentence pair into PyTorch tensors for the model."""
        return self.pre_processor(text, second_text, return_tensors="pt")

    def decode(self, outputs):
        """Convert model outputs into the predicted label and its probability.

        Returns:
            dict: ``{"label": <str>, "score": <float probability>}``.
        """
        logits = outputs.logits
        # Softmax over the class dimension gives per-label probabilities.
        probabilities = logits.softmax(dim=-1)
        best = logits[0].argmax().item()
        return {"label": self.model.config.id2label[best], "score": probabilities[0][best].item()}
|
|