# BroteinShake / scripts/generator.py
# Uploaded by 42Cummer ("Upload generator.py", commit 05055dd, verified)
import os
import json
import subprocess
import warnings
import sys
import shutil
def sync_protein_metadata(jsonl_path, dict_path):
    """
    Automated metadata sanitization to prevent KeyError (e.g., 'seq_chain_C').

    Prunes chain IDs that have no parsed protein sequence from a ProteinMPNN
    chain-assignment dictionary, then rewrites ``dict_path`` in place.

    Args:
        jsonl_path: JSONL file with one parsed structure per line. Each entry
            must contain ``name`` plus one ``seq_chain_<ID>`` key per chain
            that has sequence data.
        dict_path: JSON file mapping structure name ->
            ``[redesign_chain_list, fixed_chain_list]``.

    Does nothing if either file is missing. Structures in the dictionary
    that do not appear in the JSONL are left untouched.
    """
    if not os.path.exists(jsonl_path) or not os.path.exists(dict_path):
        return

    prefix = "seq_chain_"

    # 1. Map each parsed structure name to the chain IDs that have sequence data.
    valid_chains_map = {}
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank/trailing lines instead of crashing json.loads
            entry = json.loads(line)
            # Slice off the exact prefix so chain IDs containing '_' stay intact.
            valid_chains_map[entry["name"]] = {
                key[len(prefix):] for key in entry if key.startswith(prefix)
            }

    # 2. Filter both the redesign and the fixed lists for every known structure.
    with open(dict_path, "r", encoding="utf-8") as f:
        chain_id_dict = json.load(f)

    for pdb_name, configs in chain_id_dict.items():
        valid = valid_chains_map.get(pdb_name)
        if valid is None:
            continue
        redesign, fixed = configs[0], configs[1]
        chain_id_dict[pdb_name] = [
            [c for c in redesign if c in valid],
            [c for c in fixed if c in valid],
        ]
        # Diagnostic feedback
        removed = (set(redesign) | set(fixed)) - valid
        if removed:
            print(f"🧹 Sanitizer: Pruned non-protein chains from metadata: {removed}")

    # 3. Overwrite with cleaned metadata for ProteinMPNN
    with open(dict_path, "w", encoding="utf-8") as f:
        json.dump(chain_id_dict, f)
def _prepare_multichain_inputs(pdb_path, pdb_name, output_dir, proteinmpnn_dir,
                               fixed_chains, variable_chains):
    """Build the parsed-PDB JSONL and chain-assignment JSON for a multi-chain run.

    Returns:
        ``(jsonl_path, chain_id_json)``, both located inside ``output_dir``.

    Raises:
        subprocess.CalledProcessError: if ProteinMPNN's parser script fails.
        RuntimeError: if the parser produced no entry for ``pdb_path``.
    """
    import tempfile

    jsonl_path = os.path.join(output_dir, "parsed_pdbs.jsonl")
    parse_script = os.path.join(proteinmpnn_dir, "helper_scripts", "parse_multiple_chains.py")

    # Step A: Parse PDB to JSONL. The parser consumes every '*.pdb' in
    # --input_path, so stage a copy of ONLY this structure in an isolated
    # folder. Otherwise sibling PDBs in the source directory end up in the
    # JSONL and its first line may be a different protein.
    with tempfile.TemporaryDirectory() as staging_dir:
        shutil.copy(pdb_path, os.path.join(staging_dir, f"{pdb_name}.pdb"))
        subprocess.run(
            [sys.executable, "-W", "ignore", parse_script,
             f"--input_path={os.path.join(staging_dir, '')}",
             f"--output_path={jsonl_path}"],
            check=True,
        )

    with open(jsonl_path, "r", encoding="utf-8") as f:
        first_line = f.readline().strip()
    if not first_line:
        raise RuntimeError(f"ProteinMPNN parser produced no entries for {pdb_path}")

    # Step B: the JSONL name must match the chain-dictionary key below.
    entry = json.loads(first_line)
    entry["name"] = f"{pdb_name}_clones"
    with open(jsonl_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")

    chain_id_json = os.path.join(output_dir, "chain_id_dict.json")
    chain_id_dict = {entry["name"]: [list(variable_chains), list(fixed_chains)]}
    with open(chain_id_json, "w", encoding="utf-8") as f:
        json.dump(chain_id_dict, f)

    # Step C: drop chain IDs the parser found no sequence for (e.g. ghost chain 'C').
    sync_protein_metadata(jsonl_path, chain_id_json)
    return jsonl_path, chain_id_json


def run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=0.1):
    """Run ProteinMPNN sequence design for one PDB and write results to ./generated/<name>.

    Args:
        pdb_path: Path to the input structure.
        fixed_chains: Iterable of chain IDs to keep fixed. A string such as
            "AB" is treated as the chains 'A' and 'B'. If empty, a
            single-chain run designs only the first of ``variable_chains``
            (or chain "A" when that is also empty).
        variable_chains: Iterable of chain IDs to redesign.
        num_seqs: Passed to ProteinMPNN as ``--num_seq_per_target``.
        temp: Passed to ProteinMPNN as ``--sampling_temp``.

    Clones ProteinMPNN next to this script's parent directory if it is
    missing. Runs use a fixed ``--seed 42``.

    Raises:
        subprocess.CalledProcessError: if git, the parser or ProteinMPNN fails.
            ProteinMPNN's stderr is echoed before re-raising.
        RuntimeError: if the multi-chain parser yields no entries.
        FileNotFoundError: if ``pdb_path`` does not exist (multi-chain path).
    """
    # 1. Setup identifiers and directories
    pdb_name = os.path.basename(pdb_path).split('.')[0]
    output_dir = f"./generated/{pdb_name}"
    os.makedirs(output_dir, exist_ok=True)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    proteinmpnn_dir = os.path.join(project_root, "ProteinMPNN")
    if not os.path.exists(proteinmpnn_dir):
        print("ProteinMPNN not found, cloning repository...")
        subprocess.run(["git", "clone", "https://github.com/dauparas/ProteinMPNN.git"],
                       cwd=project_root, check=True)
    mpnn_script = os.path.join(proteinmpnn_dir, "protein_mpnn_run.py")

    # Argument lists (no shell) keep paths with spaces/metacharacters intact.
    # sys.executable runs ProteinMPNN under this interpreter, not whatever
    # 'python' is first on PATH.
    base_cmd = [sys.executable, "-W", "ignore", mpnn_script]
    common_args = [
        "--out_folder", output_dir,
        "--num_seq_per_target", str(num_seqs),
        "--sampling_temp", str(temp),
        "--seed", "42",
    ]

    # 2. Handle Single vs Multi-Chain Logic
    if not fixed_chains:
        chain_to_design = variable_chains[0] if variable_chains else "A"
        mpnn_cmd = base_cmd + [
            "--pdb_path", pdb_path, "--pdb_path_chains", chain_to_design,
            *common_args, "--batch_size", "1",
        ]
        print(f"🚀 Designing {pdb_name} (Single-chain: {chain_to_design})...")
    else:
        jsonl_path, chain_id_json = _prepare_multichain_inputs(
            pdb_path, pdb_name, output_dir, proteinmpnn_dir, fixed_chains, variable_chains
        )
        mpnn_cmd = base_cmd + [
            "--jsonl_path", jsonl_path, "--chain_id_jsonl", chain_id_json,
            *common_args,
        ]
        print(f"🚀 Designing {pdb_name}... (Fixed: {fixed_chains} | Redesign: {variable_chains})")

    # 3. Execute with suppressed warnings. stderr is captured rather than
    # discarded so a failing run still reports why.
    env = os.environ.copy()
    env['PYTHONWARNINGS'] = 'ignore'
    try:
        subprocess.run(mpnn_cmd, check=True, env=env,
                       stderr=subprocess.PIPE, text=True, errors="replace")
    except subprocess.CalledProcessError as err:
        if err.stderr:
            sys.stderr.write(err.stderr)
        raise
    print(f"✅ Success! Design complete for {pdb_name}.")
if __name__ == "__main__":
    # CLI: <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]
    cli_args = sys.argv
    if len(cli_args) < 4:
        print("Usage: python scripts/generator.py <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]")
        sys.exit(1)
    pdb_arg, fixed_arg, variable_arg = cli_args[1], cli_args[2], cli_args[3]
    n_seqs = 20
    if len(cli_args) > 4:
        n_seqs = int(cli_args[4])
    sampling_temp = 0.1
    if len(cli_args) > 5:
        sampling_temp = float(cli_args[5])
    run_broteinshake_generator(pdb_arg, fixed_arg, variable_arg, n_seqs, sampling_temp)