postgraduate-prep/experiment/code/main.py

166 lines
5.6 KiB
Python

#!/usr/bin/env python3
"""
进程调度算法模拟器 - 主程序
运行方式:
python main.py # 默认小数据集测试
python main.py -n 100 # 生成 100 个随机进程
python main.py -n 1000 --seed 42 # 固定种子 42
python main.py -a P1,P2 # 指定运行的算法
python main.py --demo # 使用默认演示数据
参数说明:
-n, --num: 进程数量 (默认 5)
-s, --seed: 随机种子 (默认 42, 设为 None 则随机)
-a, --algo: 运行的算法 (默认全部)
--demo: 使用预设演示数据
"""
import argparse
from typing import List, Dict, Optional
from base import (
Process, ProcessScheduler,
generate_random_processes,
print_processes,
print_comparison
)
from fcfs import FCFSScheduler
from sjf import SJFScheduler
from rr import RoundRobinScheduler
from priority import PriorityScheduler
from mlfq import MLFQScheduler
# 预设演示数据
DEMO_PROCESSES = [
Process(pid='P1', arrival_time=0, burst_time=7, priority=3),
Process(pid='P2', arrival_time=2, burst_time=4, priority=1),
Process(pid='P3', arrival_time=4, burst_time=1, priority=4),
Process(pid='P4', arrival_time=5, burst_time=4, priority=2),
Process(pid='P5', arrival_time=6, burst_time=2, priority=3),
]
def get_algorithm_config() -> Dict:
"""算法配置:名称与调度器类的映射"""
return {
'FCFS': FCFSScheduler,
'SJF': lambda p: SJFScheduler(p, preemptive=False),
'SRTF': lambda p: SJFScheduler(p, preemptive=True),
'Priority': lambda p: PriorityScheduler(p, preemptive=False),
'RR': lambda p: RoundRobinScheduler(p, time_slice=4),
'MLFQ': lambda p: MLFQScheduler(p),
}
def run_all_algorithms(processes: List[Process], algorithms: Optional[List[str]] = None):
"""运行所有/指定调度算法并比较"""
# 深拷贝进程列表
original_processes = [Process(**p.__dict__) for p in processes]
# 默认运行所有算法
if algorithms is None:
algorithms = list(get_algorithm_config().keys())
results = {}
scheduler_map = get_algorithm_config()
for algo_name in algorithms:
if algo_name not in scheduler_map:
print(f"警告: 未知算法 '{algo_name}', 跳过")
continue
# 每次重新创建进程列表
process_list = [Process(**p.__dict__) for p in original_processes]
# 获取调度器
scheduler_class = scheduler_map[algo_name]
scheduler = scheduler_class(process_list)
# 运行调度
metrics = scheduler.schedule()
results[algo_name] = metrics
# 打印比较结果
print_comparison(results)
# 打印最优指标提示
print("\n" + "="*70)
print("指标分析")
print("="*70)
metrics_names = [
('avg_waiting', '平均等待时间'),
('avg_turnaround', '平均周转时间'),
('avg_weighted_turnaround', '平均带权周转时间'),
('avg_response', '平均响应时间'),
]
for metric_key, metric_name in metrics_names:
values = {name: m[metric_key] for name, m in results.items()}
best_algo = min(values, key=values.get)
print(f"{metric_name}: {best_algo} ({values[best_algo]:.2f}) 最优")
return results
def main():
parser = argparse.ArgumentParser(
description='进程调度算法模拟器',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__
)
parser.add_argument('-n', '--num', type=int, default=10,
help='随机生成的进程数量 (默认: 5)')
parser.add_argument('-s', '--seed', type=int, default=42,
help='随机种子 (默认: 42, 设为 None 则每次不同)')
parser.add_argument('-a', '--algo', type=str, default=None,
help='运行的算法,逗号分隔 (如: FCFS,SJF,RR)')
parser.add_argument('--demo', action='store_true',
help='使用预设演示数据')
parser.add_argument('--arrival', type=str, default='0,50',
help='到达时间范围 (格式: min,max)')
parser.add_argument('--burst', type=str, default='1,20',
help='服务时间范围 (格式: min,max)')
parser.add_argument('--priority', type=str, default='1,10',
help='优先级范围 (格式: min,max)')
args = parser.parse_args()
# 解析算法列表
algorithms = None
if args.algo:
algorithms = [a.strip() for a in args.algo.split(',')]
# 生成测试数据
if args.demo:
processes = DEMO_PROCESSES
print_processes(processes, "演示数据 (5 个进程)")
else:
# 解析范围
arrival_range = tuple(map(int, args.arrival.split(',')))
burst_range = tuple(map(int, args.burst.split(',')))
priority_range = tuple(map(int, args.priority.split(',')))
seed = args.seed if args.seed != -1 else None
processes = generate_random_processes(
n=args.num,
seed=seed,
arrival_range=arrival_range,
burst_range=burst_range,
priority_range=priority_range
)
title = f"随机测试数据 (n={args.num}, seed={args.seed if seed else 'None'})"
print_processes(processes, title)
# 运行算法
print(f"\n运行算法: {algorithms if algorithms else '全部'}")
run_all_algorithms(processes, algorithms)
if __name__ == "__main__":
main()