# 2026-rff_mp/shahovaa/zadanie 2/scripts/run_experiments.py
# Repository web-view metadata: 195 lines, 6.6 KiB, Python;
# last modified 2026-05-19 19:39:51 +00:00.
from __future__ import annotations
import csv
import statistics
import sys
from collections import defaultdict
from pathlib import Path
# Project root: two levels up from this resolved file (the parent of the
# directory that contains this script).
ROOT = Path(__file__).resolve().parent.parent
# When this file is executed directly, only its own directory is on sys.path;
# prepend the project root so the `maze_solver` and `scripts.*` imports below
# resolve.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from maze_solver import ( # noqa: E402
AStarStrategy,
BFSStrategy,
DFSStrategy,
DijkstraStrategy,
MazeSolver,
TextFileMazeBuilder,
)
from scripts.generate_mazes import generate_all # noqa: E402
# Benchmark inputs as (key, human-readable label, maze file). The key is also
# the file stem under data/mazes and becomes the prefix of the chart files.
MAZES = [
    (key, label, ROOT / "data" / "mazes" / f"{key}.txt")
    for key, label in (
        ("small", "Маленький 10x10"),
        ("medium", "Средний 50x50"),
        ("large", "Большой 100x100"),
        ("empty", "Пустой 50x50"),
        ("no_exit", "Без пути 30x30"),
    )
]
# Strategy classes; a fresh instance is created for every solver run.
STRATEGIES = [BFSStrategy, DFSStrategy, AStarStrategy, DijkstraStrategy]
# Output locations for the CSV report and the per-maze SVG charts.
REPORTS_DIR = ROOT / "reports"
CHARTS_DIR = REPORTS_DIR.joinpath("charts")
def main(runs: int = 10) -> None:
    """Benchmark every strategy on every maze and write CSV + SVG reports.

    Calls ``generate_all()`` (from ``scripts.generate_mazes``) first,
    presumably to (re)create the maze files that are read afterwards. Then it
    makes sure the report directories exist, runs the experiments, writes
    ``results.csv`` and the charts, and prints where the output went.

    Args:
        runs: number of solver runs averaged for each (maze, strategy) pair.
    """
    generate_all()
    for output_dir in (REPORTS_DIR, CHARTS_DIR):
        output_dir.mkdir(parents=True, exist_ok=True)

    results = _run_experiments(runs)
    _write_csv(results)
    _write_charts(results)

    print(f"Wrote {REPORTS_DIR / 'results.csv'}")
    print(f"Wrote SVG charts to {CHARTS_DIR}")
def _run_experiments(runs: int) -> list[dict[str, object]]:
builder = TextFileMazeBuilder()
rows: list[dict[str, object]] = []
for maze_key, maze_name, maze_path in MAZES:
maze = builder.build_from_file(maze_path)
for strategy_type in STRATEGIES:
measurements = []
for _ in range(runs):
stats = MazeSolver(maze, strategy_type()).solve()
measurements.append(stats)
avg_time = statistics.fmean(item.time_ms for item in measurements)
avg_visited = statistics.fmean(item.visited_cells for item in measurements)
avg_path = statistics.fmean(item.path_length for item in measurements)
found = measurements[-1].path_length > 0
rows.append(
{
"key": maze_key,
"лабиринт": maze_name,
"стратегия": measurements[-1].strategy_name,
"время_мс": f"{avg_time:.4f}",
"посещено_клеток": f"{avg_visited:.1f}",
"длина_пути": f"{avg_path:.1f}",
"путь_найден": "да" if found else "нет",
"запусков": runs,
}
)
return rows
def _write_csv(rows: list[dict[str, object]]) -> None:
    """Write the report columns of ``rows`` to ``REPORTS_DIR / "results.csv"``.

    A header line comes first, followed by one line per row in a fixed column
    order. The internal ``"key"`` field is not written. The file is UTF-8 and
    is overwritten if it already exists.

    Raises:
        KeyError: if a row lacks one of the report columns.
    """
    columns = (
        "лабиринт",
        "стратегия",
        "время_мс",
        "посещено_клеток",
        "длина_пути",
        "путь_найден",
        "запусков",
    )
    target = REPORTS_DIR / "results.csv"
    # newline="" lets the csv module control line endings itself.
    with target.open("w", encoding="utf-8", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(columns)
        out.writerows([row[column] for column in columns] for row in rows)
def _write_charts(rows: list[dict[str, object]]) -> None:
    """Write a time chart and a visited-cells chart for every maze.

    Rows are bucketed by their ``"key"`` value in first-appearance order. Each
    bucket yields ``<key>_time.svg`` and ``<key>_visited.svg`` in
    ``CHARTS_DIR``, titled with the bucket's ``"лабиринт"`` label.
    """
    # (file suffix, title caption, metric column, bar colour)
    chart_specs = (
        ("time", "среднее время, мс", "время_мс", "#2f6fbb"),
        ("visited", "посещенные клетки", "посещено_клеток", "#2f8f5b"),
    )

    by_maze: dict[str, list[dict[str, object]]] = defaultdict(list)
    for row in rows:
        by_maze[str(row["key"])].append(row)

    for maze_key, maze_rows in by_maze.items():
        maze_label = str(maze_rows[0]["лабиринт"])
        for suffix, caption, metric, color in chart_specs:
            _write_bar_chart(
                CHARTS_DIR / f"{maze_key}_{suffix}.svg",
                title=f"{maze_label}: {caption}",
                rows=maze_rows,
                metric=metric,
                color=color,
            )
def _write_bar_chart(
    path: Path,
    title: str,
    rows: list[dict[str, object]],
    metric: str,
    color: str,
) -> None:
    """Render one metric of ``rows`` as a vertical SVG bar chart at ``path``.

    Each row becomes one bar labelled with its ``"стратегия"`` value. Bar
    heights are scaled against the largest ``metric`` value. The y axis gets
    five evenly spaced grid lines from 0 up to that maximum. An empty ``rows``
    list produces a chart with axes and grid only.

    Args:
        path: destination SVG file (overwritten, UTF-8).
        title: chart heading; XML-escaped before being embedded.
        rows: result rows; ``row[metric]`` must be convertible by ``float``.
        metric: key of the value to plot.
        color: SVG fill colour for the bars.
    """
    width = 780
    height = 360
    left = 72
    right = 28
    top = 54
    bottom = 58
    chart_width = width - left - right
    chart_height = height - top - bottom
    baseline = height - bottom  # y coordinate of the x axis

    values = [float(row[metric]) for row in rows]
    # Fall back to 1.0 when there are no rows or every value is 0, so the
    # scaling below never divides by zero.
    max_value = max(values, default=0.0) or 1.0
    # max(..., 1): an empty row list previously raised ZeroDivisionError here,
    # even though the max_value fallback above already allows empty input.
    bar_area = chart_width / max(len(rows), 1)
    bar_width = min(96, bar_area * 0.58)

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
        '<rect width="100%" height="100%" fill="#ffffff"/>',
        f'<text x="{left}" y="30" font-family="Arial" font-size="18" font-weight="700" fill="#1f2933">{_escape(title)}</text>',
        f'<line x1="{left}" y1="{baseline}" x2="{width - right}" y2="{baseline}" stroke="#9aa5b1"/>',
        f'<line x1="{left}" y1="{top}" x2="{left}" y2="{baseline}" stroke="#9aa5b1"/>',
    ]

    # Horizontal grid lines with value labels at 0%, 25%, 50%, 75%, 100%.
    for tick in range(5):
        ratio = tick / 4
        y = baseline - ratio * chart_height
        tick_value = max_value * ratio
        parts.append(
            f'<line x1="{left - 5}" y1="{y:.1f}" x2="{width - right}" y2="{y:.1f}" stroke="#edf0f2"/>'
        )
        parts.append(
            f'<text x="{left - 10}" y="{y + 4:.1f}" text-anchor="end" font-family="Arial" font-size="11" fill="#52616b">{tick_value:.2f}</text>'
        )

    # One centred bar per row, with its value above it and its label below.
    for index, (row, value) in enumerate(zip(rows, values)):
        bar_height = value / max_value * chart_height
        x = left + index * bar_area + (bar_area - bar_width) / 2
        y = baseline - bar_height
        label = str(row["стратегия"])
        parts.append(
            f'<rect x="{x:.1f}" y="{y:.1f}" width="{bar_width:.1f}" height="{bar_height:.1f}" fill="{color}" rx="3"/>'
        )
        parts.append(
            f'<text x="{x + bar_width / 2:.1f}" y="{y - 8:.1f}" text-anchor="middle" font-family="Arial" font-size="12" fill="#1f2933">{value:.2f}</text>'
        )
        parts.append(
            f'<text x="{x + bar_width / 2:.1f}" y="{baseline + 24}" text-anchor="middle" font-family="Arial" font-size="12" fill="#1f2933">{_escape(label)}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
def _escape(value: str) -> str:
return (
value.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
)
if __name__ == "__main__":
    # argparse is only needed when run as a script, so it is imported here.
    import argparse

    _parser = argparse.ArgumentParser(
        description="Benchmark maze-solving strategies and write CSV/SVG reports.",
    )
    _parser.add_argument(
        "--runs",
        type=int,
        default=10,
        help="solver runs averaged per maze/strategy pair (default: 10)",
    )
    # With no arguments this is equivalent to the previous plain main() call.
    main(runs=_parser.parse_args().runs)