import os
import json
import subprocess
import warnings
import sys
import shutil


def sync_protein_metadata(jsonl_path, dict_path):
    """
    Prune chain IDs without parsed protein sequence from a ProteinMPNN chain dictionary.

    ProteinMPNN raises KeyError (e.g. 'seq_chain_C') when the chain-ID dictionary
    names a chain for which the parsed JSONL has no ``seq_chain_<ID>`` key. This
    rewrites ``dict_path`` in place so both chain lists of every entry only contain
    chains that carry sequence data in ``jsonl_path``.

    Args:
        jsonl_path: JSONL file, one JSON object per line, each with a 'name' key and
            'seq_chain_<ID>' keys (the format written by parse_multiple_chains.py).
        dict_path: JSON file mapping entry name -> [redesign_chains, fixed_chains].

    Does nothing if either file is missing. Dictionary entries whose name does not
    appear in the JSONL are left untouched. Blank JSONL lines are ignored.
    """
    if not os.path.exists(jsonl_path) or not os.path.exists(dict_path):
        return

    # 1. Per entry, collect the chain IDs that actually have a sequence.
    prefix = 'seq_chain_'
    valid_chains_map = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            entry = json.loads(line)
            valid_chains_map[entry['name']] = {
                k[len(prefix):] for k in entry if k.startswith(prefix)
            }

    # 2. Filter the redesign (configs[0]) and fixed (configs[1]) lists.
    with open(dict_path, 'r', encoding='utf-8') as f:
        chain_id_dict = json.load(f)

    for pdb_name, configs in chain_id_dict.items():
        valid = valid_chains_map.get(pdb_name)
        if valid is None:
            continue
        redesign = [c for c in configs[0] if c in valid]
        fixed = [c for c in configs[1] if c in valid]
        # Reassigning an existing key while iterating is safe (dict size is unchanged).
        chain_id_dict[pdb_name] = [redesign, fixed]

        removed = (set(configs[0]) | set(configs[1])) - valid
        if removed:
            print(f"🧹 Sanitizer: Pruned non-protein chains from metadata: {sorted(removed)}")
        if not redesign:
            print(f"⚠️ Sanitizer: no protein chains left to redesign for {pdb_name}")

    # 3. Overwrite in place so ProteinMPNN reads the cleaned metadata.
    with open(dict_path, 'w', encoding='utf-8') as f:
        json.dump(chain_id_dict, f)


def _normalize_chains(chains):
    """
    Convert a chain specification into a de-duplicated list of chain IDs.

    Accepts None/empty, an iterable of IDs, or a string. A string containing commas
    or whitespace is split on them ("A,B", "A B"); otherwise each character is one
    chain ID ("AB" -> ['A', 'B']), which is how a plain string was treated before.
    """
    if not chains:
        return []
    if isinstance(chains, str):
        if ',' in chains or any(ch.isspace() for ch in chains):
            items = chains.replace(',', ' ').split()
        else:
            items = list(chains)
    else:
        items = [str(c).strip() for c in chains]
    # dict.fromkeys keeps first-occurrence order while dropping duplicates.
    return list(dict.fromkeys(c for c in items if c))


def _ensure_proteinmpnn(project_root):
    """Return the ProteinMPNN checkout path under project_root, cloning it if missing."""
    proteinmpnn_dir = os.path.join(project_root, "ProteinMPNN")
    if not os.path.exists(proteinmpnn_dir):
        print("ProteinMPNN not found, cloning repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/dauparas/ProteinMPNN.git"],
            cwd=project_root,
            check=True,
        )
    return proteinmpnn_dir


def _run_quiet(cmd, env=None):
    """Run an argv list, hiding its stderr unless it fails (then echo it and re-raise)."""
    try:
        subprocess.run(
            cmd, check=True, env=env,
            stderr=subprocess.PIPE, text=True, errors="replace",
        )
    except subprocess.CalledProcessError as err:
        if err.stderr:
            sys.stderr.write(err.stderr)
        raise


def _prepare_multichain_inputs(pdb_path, pdb_name, fixed, variable, output_dir, proteinmpnn_dir):
    """Parse pdb_path to JSONL and write a sanitized chain-ID dict; return both paths."""
    import tempfile  # local: only needed to stage the input PDB

    jsonl_path = os.path.join(output_dir, "parsed_pdbs.jsonl")
    parse_script = os.path.join(proteinmpnn_dir, "helper_scripts", "parse_multiple_chains.py")

    # Step A: the parser processes every *.pdb in --input_path, so stage the target
    # alone (with a .pdb suffix); otherwise a sibling structure in the same folder
    # could become the first JSONL entry and be designed instead.
    with tempfile.TemporaryDirectory() as staging_dir:
        shutil.copy(pdb_path, os.path.join(staging_dir, f"{pdb_name}.pdb"))
        subprocess.run(
            [
                sys.executable, "-W", "ignore", parse_script,
                # Trailing separator: the parser concatenates the folder with '*.pdb'.
                f"--input_path={os.path.join(staging_dir, '')}",
                f"--output_path={jsonl_path}",
            ],
            check=True,
        )

    # Step B: rename the parsed entry so it matches the chain-dictionary key.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        entries = [json.loads(line) for line in f if line.strip()]
    if not entries:
        raise RuntimeError(f"ProteinMPNN parser produced no entries for {pdb_path}")
    entry_name = f"{pdb_name}_clones"
    entry = entries[0]
    entry['name'] = entry_name
    with open(jsonl_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(entry) + '\n')

    chain_id_json = os.path.join(output_dir, "chain_id_dict.json")
    with open(chain_id_json, 'w', encoding='utf-8') as f:
        json.dump({entry_name: [list(variable), list(fixed)]}, f)

    # Step C: drop chain IDs that have no protein sequence (e.g. ligand-only 'C').
    sync_protein_metadata(jsonl_path, chain_id_json)
    return jsonl_path, chain_id_json


def run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=0.1):
    """
    Generate sequence designs for a PDB structure with ProteinMPNN.

    Outputs go to ./generated/<pdb_name>. ProteinMPNN is cloned next to this
    script's parent directory on first use.

    Args:
        pdb_path: Path to the input structure.
        fixed_chains: Chain IDs kept fixed (list, or string like "AB" / "A,B").
            When empty, only the first variable chain is passed to ProteinMPNN via
            --pdb_path_chains (default "A" if no variable chains are given).
        variable_chains: Chain IDs to redesign (same accepted forms).
        num_seqs: Sequences to sample per target.
        temp: Sampling temperature.

    Raises:
        subprocess.CalledProcessError: if cloning, parsing or ProteinMPNN fails
            (ProteinMPNN's stderr is printed before re-raising).
        RuntimeError: if the PDB parser produced no entries.
    """
    fixed = _normalize_chains(fixed_chains)
    variable = _normalize_chains(variable_chains)

    # 1. Identifiers and directories.
    pdb_name = os.path.basename(pdb_path).split('.')[0]
    output_dir = f"./generated/{pdb_name}"
    os.makedirs(output_dir, exist_ok=True)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    proteinmpnn_dir = _ensure_proteinmpnn(project_root)
    mpnn_script = os.path.join(proteinmpnn_dir, "protein_mpnn_run.py")

    common_args = [
        "--out_folder", output_dir,
        "--num_seq_per_target", str(num_seqs),
        "--sampling_temp", str(temp),
        "--seed", "42",
    ]

    # 2. Single- vs multi-chain setup. argv lists (no shell) keep paths with spaces
    # or shell metacharacters intact; sys.executable keeps the current interpreter.
    if not fixed:
        chain_to_design = variable[0] if variable else "A"
        mpnn_cmd = [
            sys.executable, "-W", "ignore", mpnn_script,
            "--pdb_path", pdb_path,
            "--pdb_path_chains", chain_to_design,
            *common_args,
            "--batch_size", "1",
        ]
        print(f"🚀 Designing {pdb_name} (Single-chain: {chain_to_design})...")
    else:
        jsonl_path, chain_id_json = _prepare_multichain_inputs(
            pdb_path, pdb_name, fixed, variable, output_dir, proteinmpnn_dir
        )
        mpnn_cmd = [
            sys.executable, "-W", "ignore", mpnn_script,
            "--jsonl_path", jsonl_path,
            "--chain_id_jsonl", chain_id_json,
            *common_args,
        ]
        print(f"🚀 Designing {pdb_name}... (Fixed: {fixed_chains} | Redesign: {variable_chains})")

    # 3. Execute with warnings muted; stderr is only shown if the run fails.
    env = os.environ.copy()
    env['PYTHONWARNINGS'] = 'ignore'
    _run_quiet(mpnn_cmd, env=env)
    print(f"✅ Success! Design complete for {pdb_name}.")


if __name__ == "__main__":
    if len(sys.argv) < 4:
        # Pass "" as <fixed_chains> for single-chain design.
        print("Usage: python scripts/generator.py <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]")
        sys.exit(1)
    run_broteinshake_generator(
        sys.argv[1],
        sys.argv[2],
        sys.argv[3],
        int(sys.argv[4]) if len(sys.argv) > 4 else 20,
        float(sys.argv[5]) if len(sys.argv) > 5 else 0.1,
    )