Instructions to use MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| library_name: transformers | |
| tags: [] | |
| # Model Card for MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction | |
| ProtST model (ESM-1b backbone) for binary protein localization prediction. | |
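| The following script shows how to fine-tune ProtST on Habana Gaudi (HPU) using `optimum.habana`. In the diff below, lines prefixed with `+` are the Gaudi-specific additions and lines prefixed with `-` are the lines of the CPU script that they replace. | |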
| ## Running script | |
| ```diff | |
| from transformers import AutoModel, AutoTokenizer, HfArgumentParser, TrainingArguments, Trainer | |
| from transformers.data.data_collator import DataCollatorWithPadding | |
| from transformers.trainer_pt_utils import get_parameter_names | |
| from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS | |
| from datasets import load_dataset | |
| import functools | |
| import numpy as np | |
| from sklearn.metrics import accuracy_score, matthews_corrcoef | |
| import sys | |
| import torch | |
| import logging | |
| import datasets | |
| import transformers | |
| + import habana_frameworks.torch | |
| + from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
def create_optimizer(opt_model, lr_ratio=0.1, training_args=None):
    """Build the optimizer for fine-tuning the classification head of ``opt_model``.

    Every parameter whose name does not contain ``"classifier"`` is frozen
    (``requires_grad = False``) as a side effect. The remaining parameters are
    split into four groups, in this order:

    1. head parameters with weight decay,
    2. backbone parameters with weight decay,
    3. head parameters without weight decay,
    4. backbone parameters without weight decay.

    Because all non-classifier parameters are frozen above, the two backbone
    groups are always empty here and ``lr_ratio`` has no effect; both are kept
    so the optimizer retains the same four-group layout.

    Args:
        opt_model: Model exposing ``named_parameters()``.
        lr_ratio: Multiplier applied to the learning rate of the backbone groups.
        training_args: Object providing ``learning_rate`` and ``weight_decay``
            and accepted by ``GaudiTrainer.get_optimizer_cls_and_kwargs``. When
            omitted, the module-level ``training_args`` created in the
            ``__main__`` section is used, which is what the original code relied on.

    Returns:
        The optimizer instance built from the grouped parameters.

    Raises:
        ValueError: If a classifier parameter does not require gradients.
    """
    if training_args is None:
        # Backward compatibility: fall back to the script's global configuration.
        training_args = globals()["training_args"]

    # Single pass: collect the head parameter names and freeze everything else.
    head_names = set()
    for name, param in opt_model.named_parameters():
        if "classifier" in name:
            if not param.requires_grad:
                raise ValueError(f"classifier parameter {name!r} must require gradients")
            head_names.add(name)
        else:
            param.requires_grad = False

    # Trainable parameters outside the head (empty after the freeze above).
    backbone_names = {
        name
        for name, param in opt_model.named_parameters()
        if name not in head_names and param.requires_grad
    }

    # Weight-decay policy follows the Hugging Face Trainer, see
    # https://github.com/huggingface/transformers/blob/50573c648ae953dcc1b94d663651f07fb02268f4/src/transformers/trainer.py#L947
    # Layer-norm parameters and biases are excluded from weight decay.
    decay_names = {
        name
        for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        if "bias" not in name
    }

    def select(names, with_decay):
        """Return trainable parameters from ``names`` with the given decay policy."""
        return [
            param
            for name, param in opt_model.named_parameters()
            if name in names and (name in decay_names) == with_decay and param.requires_grad
        ]

    head_lr = training_args.learning_rate
    backbone_lr = training_args.learning_rate * lr_ratio
    optimizer_grouped_parameters = [
        {"params": select(head_names, True), "weight_decay": training_args.weight_decay, "lr": head_lr},
        {"params": select(backbone_names, True), "weight_decay": training_args.weight_decay, "lr": backbone_lr},
        {"params": select(head_names, False), "weight_decay": 0.0, "lr": head_lr},
        {"params": select(backbone_names, False), "weight_decay": 0.0, "lr": backbone_lr},
    ]

    # Gaudi variant; the CPU version of this script called
    # Trainer.get_optimizer_cls_and_kwargs(training_args) instead.
    optimizer_cls, optimizer_kwargs = GaudiTrainer.get_optimizer_cls_and_kwargs(training_args)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
def create_scheduler(training_args, optimizer):
    """Create the learning-rate scheduler described by ``training_args``.

    Args:
        training_args: Object providing ``lr_scheduler_type``, ``max_steps``
            and ``get_warmup_steps``.
        optimizer: Optimizer whose learning rate the scheduler drives.

    Returns:
        The scheduler returned by ``transformers.optimization.get_scheduler``.
    """
    from transformers.optimization import get_scheduler

    # NOTE(review): the step counts come straight from ``max_steps``, which the
    # script in this file sets to -1 together with a "constant" schedule.
    # Presumably a schedule that depends on the total number of steps would
    # need a real step count here -- confirm before changing lr_scheduler_type.
    return get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(training_args.max_steps),
        num_training_steps=training_args.max_steps,
    )
def compute_metrics(eval_preds):
    """Compute accuracy and Matthews correlation coefficient for an evaluation run.

    Args:
        eval_preds: Pair ``(probs, labels)``; the predicted class is the
            argmax of ``probs`` over the last axis.

    Returns:
        Dict with the keys ``"accuracy"`` and ``"mcc"``.
    """
    class_scores, references = eval_preds
    predicted = np.argmax(class_scores, axis=-1)
    return {
        "accuracy": accuracy_score(references, predicted),
        "mcc": matthews_corrcoef(references, predicted),
    }
def preprocess_logits_for_metrics(logits, labels):
    """Convert raw logits into probabilities over the last dimension.

    Args:
        logits: Tensor of raw model outputs.
        labels: Accepted for interface compatibility; not used.

    Returns:
        Tensor of the same shape as ``logits`` with a softmax applied along
        the last dimension.
    """
    class_probabilities = torch.nn.functional.softmax(logits, dim=-1)
    return class_probabilities
if __name__ == "__main__":
    # Fine-tune ProtST on the binary localization dataset using the Gaudi
    # (HPU) code path. Comments mention what the CPU version of this script
    # used wherever the two differ.

    # CPU version: torch.device("cpu")
    device = torch.device("hpu")

    raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
    model = AutoModel.from_pretrained(
        "Jiqing/protst-esm1b-for-sequential-classification",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

    # NOTE(review): machine-specific absolute path kept from the original
    # script; point it at a writable directory before running elsewhere.
    output_dir = "/home/jiqingfe/protst/protst_2/ProtST-HuggingFace/output_dir/ProtSTModel/default/ESM-1b_PubMedBERT-abs/240123_015856"

    training_kwargs = {
        "output_dir": output_dir,
        "overwrite_output_dir": True,
        "do_train": True,
        "per_device_train_batch_size": 32,
        "gradient_accumulation_steps": 1,
        "learning_rate": 5e-05,
        "weight_decay": 0,
        "num_train_epochs": 100,
        "max_steps": -1,
        "lr_scheduler_type": "constant",
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "per_device_eval_batch_size": 32,
        "logging_strategy": "epoch",
        "save_strategy": "epoch",
        "save_steps": 820,
        "dataloader_num_workers": 0,
        "run_name": "downstream_esm1b_localization_fix",
        "optim": "adamw_torch",
        "resume_from_checkpoint": False,
        "label_names": ["labels"],
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        "bf16": True,
        "save_total_limit": 3,
        # Gaudi-only switches; the CPU version of the dict stopped at save_total_limit.
        "use_habana": True,
        "use_lazy_mode": True,
        "use_hpu_graphs_for_inference": True,
    }
    # CPU version: HfArgumentParser(TrainingArguments)
    # ``training_args`` must stay a module-level name: create_optimizer reads it.
    training_args = HfArgumentParser(GaudiTrainingArguments).parse_dict(training_kwargs, allow_extra_keys=False)[0]

    def tokenize_protein(example, tokenizer=None):
        """Tokenize one protein sequence and copy its localization label to ``labels``."""
        encoded = tokenizer(example["prot_seq"], add_special_tokens=True)
        example["input_ids"] = encoded["input_ids"]
        example["attention_mask"] = encoded["attention_mask"]
        example["labels"] = example["localization"]
        return example

    func_tokenize_protein = functools.partial(tokenize_protein, tokenizer=tokenizer)

    for split in ["train", "validation", "test"]:
        raw_dataset[split] = raw_dataset[split].map(
            func_tokenize_protein,
            batched=False,
            remove_columns=["Unnamed: 0", "prot_seq", "localization"],
        )

    # CPU version: DataCollatorWithPadding(tokenizer=tokenizer)
    # NOTE(review): padding every batch to 1024 tokens presumably keeps input
    # shapes static on HPU; the tokenizer call above does not truncate, so
    # confirm that no sequence exceeds this length.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="max_length", max_length=1024)

    transformers.utils.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)

    optimizer = create_optimizer(model)
    scheduler = create_scheduler(training_args, optimizer)

    # Gaudi-only configuration (absent from the CPU version).
    gaudi_config = GaudiConfig()
    gaudi_config.use_fused_adam = True
    gaudi_config.use_fused_clip_norm = True

    # Build the trainer. CPU version: Trainer(...) without ``gaudi_config``.
    trainer = GaudiTrainer(
        model=model,
        gaudi_config=gaudi_config,
        args=training_args,
        train_dataset=raw_dataset["train"],
        eval_dataset=raw_dataset["validation"],
        data_collator=data_collator,
        optimizers=(optimizer, scheduler),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    train_result = trainer.train()
    trainer.save_model()
    # Save the tokenizer as well for easy upload.
    tokenizer.save_pretrained(training_args.output_dir)

    metrics = train_result.metrics
    metrics["train_samples"] = len(raw_dataset["train"])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # Report the final metrics on the test split, then on the validation split.
    for split, prefix in (("test", "test"), ("validation", "valid")):
        metric = trainer.evaluate(raw_dataset[split], metric_key_prefix=prefix)
        print(f"{prefix} metric: ", metric)
| ``` | |