| import json |
| import sys |
| import random |
| from collections import defaultdict |
|
|
def collect_dataset_info(file_path):
    """Group the records of a JSONL file by dataset name.

    A record's dataset name is the part of its ``custom_id`` before the first
    ``'-'`` (the whole id when it contains no ``'-'``). Malformed lines are
    reported on stderr and skipped; blank lines are skipped silently.

    Args:
        file_path: path to a UTF-8 encoded JSONL file.

    Returns:
        A tuple ``(dataset_lines, order)``:
        ``dataset_lines`` is a ``defaultdict(list)`` mapping each dataset name
        to the ascending 1-based line numbers of its records, and ``order``
        lists the dataset names in order of first appearance.
    """
    dataset_lines = defaultdict(list)
    order = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            stripped = line.strip()
            if not stripped:
                # Empty/whitespace-only line (e.g. a trailing blank line): not a record.
                continue

            try:
                data = json.loads(stripped)
            except json.JSONDecodeError:
                print(f"Error: Invalid JSON at line {line_num}", file=sys.stderr)
                continue

            # Valid JSON that is not an object (list, number, null, ...) cannot
            # carry a custom_id; indexing it would raise TypeError.
            if not isinstance(data, dict):
                print(f"Error: Expected a JSON object at line {line_num}", file=sys.stderr)
                continue
            if 'custom_id' not in data:
                print(f"Error: Missing 'custom_id' at line {line_num}", file=sys.stderr)
                continue

            custom_id = data['custom_id']
            if not isinstance(custom_id, str):
                print(f"Error: Invalid custom_id format at line {line_num}", file=sys.stderr)
                continue

            dataset = custom_id.split('-', 1)[0]
            # A key exists in dataset_lines exactly when the dataset was seen
            # before (``in`` does not create keys on a defaultdict).
            if dataset not in dataset_lines:
                order.append(dataset)
            dataset_lines[dataset].append(line_num)

    return dataset_lines, order
|
|
# Every dataset is guaranteed at least this many samples in the output.
MIN_SAMPLES_PER_DATASET = 5


def _fail(message):
    """Print ``message`` to stderr and exit with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _allocate_extra(available, order, extra):
    """Split ``extra`` samples across datasets in proportion to ``available``.

    Largest-remainder method in exact integer arithmetic: each dataset first
    gets ``floor(extra * avail / total)``; the leftover units go one each to
    the datasets with the largest remainders, ties broken by position in
    ``order``.

    Args:
        available: mapping dataset -> surplus samples beyond the guaranteed minimum.
        order: dataset names, in the order allocations should be returned.
        extra: number of samples to distribute; expected ``0 <= extra <= sum(surplus)``.

    Returns:
        List of per-dataset extra counts aligned with ``order``.
    """
    total = sum(available[d] for d in order)
    if total == 0:
        # Only reachable with extra == 0 (every dataset has exactly the minimum).
        return [0] * len(order)

    quotients = []
    remainders = []
    for dataset in order:
        q, r = divmod(extra * available[dataset], total)
        quotients.append(q)
        remainders.append(r)

    # sum(remainders) / total is an integer smaller than len(order), so the
    # slice below never runs past the ranked list.
    leftover = extra - sum(quotients)
    ranked = sorted(range(len(order)), key=lambda i: (-remainders[i], i))
    for i in ranked[:leftover]:
        quotients[i] += 1
    return quotients


def _copy_selected_lines(input_file, output_file, selected_lines):
    """Copy the lines whose 1-based numbers are in ``selected_lines`` to ``output_file``.

    Lines are written in their original file order. Reading stops once every
    selected line has been copied.
    """
    wanted = set(selected_lines)
    remaining = len(wanted)
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for line_num, line in enumerate(infile, 1):
            if remaining == 0:
                break
            if line_num in wanted:
                outfile.write(line)
                remaining -= 1


def main(argv=None):
    """Sample N records from a JSONL file, stratified by dataset.

    Command line: ``<input.jsonl> <output.jsonl> <N> [seed]``.

    Records are grouped by dataset with ``collect_dataset_info``. Each dataset
    receives ``MIN_SAMPLES_PER_DATASET`` samples. The remaining
    ``N - MIN_SAMPLES_PER_DATASET * k`` samples are split in proportion to each
    dataset's surplus. The chosen lines are copied verbatim, in input order.

    Args:
        argv: command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``.

    Exits with status 1, after printing a message to stderr, on bad arguments
    or when the data cannot satisfy the request.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) not in (3, 4):
        _fail("Usage: python sample_datasets.py <input.jsonl> <output.jsonl> <N> [seed]")

    input_file, output_file = args[0], args[1]
    try:
        N = int(args[2])
    except ValueError:
        _fail("Error: N must be an integer.")

    # Optional seed makes the sample reproducible; without it the global
    # ``random`` module state is used, as before.
    seed = None
    if len(args) == 4:
        try:
            seed = int(args[3])
        except ValueError:
            _fail("Error: seed must be an integer.")

    dataset_info, dataset_order = collect_dataset_info(input_file)
    k = len(dataset_info)
    if k == 0:
        _fail("Error: No datasets found in the input file.")

    for dataset in dataset_order:
        if len(dataset_info[dataset]) < MIN_SAMPLES_PER_DATASET:
            _fail(f"Error: Dataset '{dataset}' has fewer than {MIN_SAMPLES_PER_DATASET} samples.")

    total_samples = sum(len(lines) for lines in dataset_info.values())
    min_samples = MIN_SAMPLES_PER_DATASET * k
    if not min_samples <= N <= total_samples:
        _fail(f"Error: N must be between {min_samples} and {total_samples}.")

    # N <= total_samples guarantees extra <= total surplus, so no separate
    # "cannot allocate" check is needed.
    available = {d: len(dataset_info[d]) - MIN_SAMPLES_PER_DATASET for d in dataset_order}
    extra = N - min_samples
    extras = _allocate_extra(available, dataset_order, extra)

    sample_counts = {}
    for dataset, alloc in zip(dataset_order, extras):
        # Defensive invariant check; the largest-remainder split cannot exceed
        # a dataset's surplus when extra <= total surplus.
        if alloc > available[dataset]:
            _fail(f"Error: Allocation for dataset '{dataset}' exceeds available samples.")
        sample_counts[dataset] = MIN_SAMPLES_PER_DATASET + alloc

    print("\nSampling Distribution:")
    for dataset in dataset_order:
        print(f" - {dataset}: {sample_counts[dataset]} samples")
    total_sampled = sum(sample_counts.values())
    print(f"Total samples: {total_sampled} (target: {N})")

    if total_sampled != N:
        _fail(f"Error: Total sampled count mismatch ({total_sampled} vs {N})")

    rng = random if seed is None else random.Random(seed)
    selected_lines = []
    for dataset in dataset_order:
        selected_lines.extend(rng.sample(dataset_info[dataset], sample_counts[dataset]))

    _copy_selected_lines(input_file, output_file, selected_lines)
    print(f"\nSuccessfully sampled {N} records to {output_file}.")
|
|
if __name__ == "__main__":
    # Script entry point: run the sampler only when executed directly,
    # never as a side effect of importing this module.
    main()