2026-rff_mp/shahovaa/zadanie 2/scripts/run_experiments.py
2026-05-19 22:39:51 +03:00

195 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import csv
import statistics
import sys
from collections import defaultdict
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from maze_solver import ( # noqa: E402
AStarStrategy,
BFSStrategy,
DFSStrategy,
DijkstraStrategy,
MazeSolver,
TextFileMazeBuilder,
)
from scripts.generate_mazes import generate_all # noqa: E402
# Benchmark inputs as (key, display name, file path) tuples. The key is used
# to name the generated chart files; the display name is the label written to
# the CSV report and used in chart titles.
MAZES = [
    ("small", "Маленький 10x10", ROOT / "data" / "mazes" / "small.txt"),
    ("medium", "Средний 50x50", ROOT / "data" / "mazes" / "medium.txt"),
    ("large", "Большой 100x100", ROOT / "data" / "mazes" / "large.txt"),
    ("empty", "Пустой 50x50", ROOT / "data" / "mazes" / "empty.txt"),
    ("no_exit", "Без пути 30x30", ROOT / "data" / "mazes" / "no_exit.txt"),
]
# Strategy classes under comparison; a fresh instance is created for every run.
STRATEGIES = [BFSStrategy, DFSStrategy, AStarStrategy, DijkstraStrategy]
# Output locations: results.csv is written to REPORTS_DIR, SVG charts to CHARTS_DIR.
REPORTS_DIR = ROOT / "reports"
CHARTS_DIR = REPORTS_DIR / "charts"
def main(runs: int = 10) -> None:
    """Generate the mazes, benchmark all strategies and write the reports.

    Args:
        runs: How many times each strategy is executed on each maze.
    """
    generate_all()

    # Make sure both output directories exist before anything is written.
    for directory in (REPORTS_DIR, CHARTS_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    results = _run_experiments(runs)
    _write_csv(results)
    _write_charts(results)

    csv_location = REPORTS_DIR / 'results.csv'
    print(f"Wrote {csv_location}")
    print(f"Wrote SVG charts to {CHARTS_DIR}")
def _run_experiments(runs: int) -> list[dict[str, object]]:
    """Benchmark every strategy on every maze and summarise the results.

    Each maze listed in ``MAZES`` is loaded once and then solved ``runs``
    times with a fresh instance of each class in ``STRATEGIES``.

    Args:
        runs: Number of solves per (maze, strategy) pair; must be at least 1.

    Returns:
        One dict per (maze, strategy) pair, ordered by ``MAZES`` and then by
        ``STRATEGIES``. Keys: ``"key"`` (maze identifier), ``"лабиринт"``
        (display name), ``"стратегия"`` (``strategy_name`` reported by the
        last run), ``"время_мс"``, ``"посещено_клеток"`` and ``"длина_пути"``
        (means over all runs, formatted as strings), ``"путь_найден"``
        (``"да"`` or ``"нет"``) and ``"запусков"`` (the ``runs`` argument).

    Raises:
        ValueError: If ``runs`` is less than 1.
    """
    # Validate before doing any work: with no measurements the means below
    # cannot be computed and there is no "last run" to read the name from.
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs!r}")

    builder = TextFileMazeBuilder()
    rows: list[dict[str, object]] = []
    for maze_key, maze_name, maze_path in MAZES:
        maze = builder.build_from_file(maze_path)
        for strategy_type in STRATEGIES:
            measurements = [
                MazeSolver(maze, strategy_type()).solve() for _ in range(runs)
            ]
            last_run = measurements[-1]

            avg_time = statistics.fmean(item.time_ms for item in measurements)
            avg_visited = statistics.fmean(
                item.visited_cells for item in measurements
            )
            avg_path = statistics.fmean(item.path_length for item in measurements)
            # A path counts as found when the last run reports a non-empty path.
            found = last_run.path_length > 0

            rows.append(
                {
                    "key": maze_key,
                    "лабиринт": maze_name,
                    "стратегия": last_run.strategy_name,
                    "время_мс": f"{avg_time:.4f}",
                    "посещено_клеток": f"{avg_visited:.1f}",
                    "длина_пути": f"{avg_path:.1f}",
                    "путь_найден": "да" if found else "нет",
                    "запусков": runs,
                }
            )
    return rows
def _write_csv(rows: list[dict[str, object]]) -> None:
    """Write the summary rows to ``results.csv`` inside ``REPORTS_DIR``.

    Only the report columns listed below are written, in that order; any
    other key present in a row (such as ``"key"``) is left out. The file is
    UTF-8 encoded and starts with a header line.

    Args:
        rows: Summary dicts; each must contain every report column.
    """
    columns = [
        "лабиринт",
        "стратегия",
        "время_мс",
        "посещено_клеток",
        "длина_пути",
        "путь_найден",
        "запусков",
    ]
    target = REPORTS_DIR / "results.csv"
    with target.open("w", encoding="utf-8", newline="") as handle:
        report = csv.DictWriter(handle, fieldnames=columns)
        report.writeheader()
        # Project each row onto the report columns before handing it over.
        report.writerows(
            {column: record[column] for column in columns} for record in rows
        )
def _write_charts(rows: list[dict[str, object]]) -> None:
    """Write two SVG bar charts per maze into ``CHARTS_DIR``.

    Rows are grouped by their ``"key"`` value, keeping first-seen order. For
    each group a ``<key>_time.svg`` chart (mean time) and a
    ``<key>_visited.svg`` chart (visited cells) are produced, titled with the
    ``"лабиринт"`` value of the group's first row.

    Args:
        rows: Summary dicts as produced by the experiment run.
    """
    by_maze: dict[str, list[dict[str, object]]] = {}
    for record in rows:
        by_maze.setdefault(str(record["key"]), []).append(record)

    for maze_key, maze_rows in by_maze.items():
        title = str(maze_rows[0]["лабиринт"])
        # (destination, chart title, metric column, bar colour) per chart.
        chart_specs = (
            (
                CHARTS_DIR / f"{maze_key}_time.svg",
                f"{title}: среднее время, мс",
                "время_мс",
                "#2f6fbb",
            ),
            (
                CHARTS_DIR / f"{maze_key}_visited.svg",
                f"{title}: посещенные клетки",
                "посещено_клеток",
                "#2f8f5b",
            ),
        )
        for destination, chart_title, metric_name, fill in chart_specs:
            _write_bar_chart(
                destination,
                title=chart_title,
                rows=maze_rows,
                metric=metric_name,
                color=fill,
            )
def _write_bar_chart(
    path: Path,
    title: str,
    rows: list[dict[str, object]],
    metric: str,
    color: str,
) -> None:
    """Render one metric of ``rows`` as a standalone SVG bar chart.

    The chart has a fixed 780x360 canvas with a title, two axes, five
    horizontal grid lines with value labels, and one labelled bar per row.
    Bars are scaled relative to the largest plotted value. An empty ``rows``
    list produces a chart with the title, axes and grid but no bars.

    Args:
        path: Destination file; overwritten with UTF-8 encoded SVG markup.
        title: Heading drawn above the plot area (XML-escaped before use).
        rows: One mapping per bar. Each must provide ``metric`` (a value
            accepted by ``float``) and ``"стратегия"`` (the bar's label).
        metric: Key of the value to plot.
        color: Fill colour inserted verbatim into each bar's ``fill``
            attribute.

    Raises:
        KeyError: If a row lacks ``metric`` or ``"стратегия"``.
        ValueError: If a metric value is a string ``float`` cannot parse.
    """
    width = 780
    height = 360
    left = 72
    right = 28
    top = 54
    bottom = 58
    chart_width = width - left - right
    chart_height = height - top - bottom
    baseline = height - bottom

    # Parse each metric once; the same numbers drive the scale and the bars.
    values = [float(row[metric]) for row in rows]
    # Use 1.0 as the scale when there is nothing to plot or the maximum is
    # zero, so the ratio computed for each bar never divides by zero.
    max_value = max(values, default=0.0) or 1.0
    # Same precaution for the per-bar slot width when ``rows`` is empty.
    bar_area = chart_width / max(len(rows), 1)
    bar_width = min(96, bar_area * 0.58)

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
        '<rect width="100%" height="100%" fill="#ffffff"/>',
        f'<text x="{left}" y="30" font-family="Arial" font-size="18" font-weight="700" fill="#1f2933">{_escape(title)}</text>',
        f'<line x1="{left}" y1="{baseline}" x2="{width - right}" y2="{baseline}" stroke="#9aa5b1"/>',
        f'<line x1="{left}" y1="{top}" x2="{left}" y2="{baseline}" stroke="#9aa5b1"/>',
    ]

    # Grid lines and axis labels at 0%, 25%, 50%, 75% and 100% of the scale.
    for tick in range(5):
        ratio = tick / 4
        y = baseline - ratio * chart_height
        tick_value = max_value * ratio
        parts.append(
            f'<line x1="{left - 5}" y1="{y:.1f}" x2="{width - right}" y2="{y:.1f}" stroke="#edf0f2"/>'
        )
        parts.append(
            f'<text x="{left - 10}" y="{y + 4:.1f}" text-anchor="end" font-family="Arial" font-size="11" fill="#52616b">{tick_value:.2f}</text>'
        )

    for index, (row, value) in enumerate(zip(rows, values)):
        bar_height = value / max_value * chart_height
        # Centre the bar inside its slot.
        x = left + index * bar_area + (bar_area - bar_width) / 2
        y = baseline - bar_height
        label = str(row["стратегия"])
        parts.append(
            f'<rect x="{x:.1f}" y="{y:.1f}" width="{bar_width:.1f}" height="{bar_height:.1f}" fill="{color}" rx="3"/>'
        )
        parts.append(
            f'<text x="{x + bar_width / 2:.1f}" y="{y - 8:.1f}" text-anchor="middle" font-family="Arial" font-size="12" fill="#1f2933">{value:.2f}</text>'
        )
        parts.append(
            f'<text x="{x + bar_width / 2:.1f}" y="{baseline + 24}" text-anchor="middle" font-family="Arial" font-size="12" fill="#1f2933">{_escape(label)}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
def _escape(value: str) -> str:
    """Return ``value`` with ``&``, ``<``, ``>`` and ``"`` replaced by XML entities.

    Args:
        value: Text destined for SVG element content.

    Returns:
        The text with the four characters replaced by ``&amp;``, ``&lt;``,
        ``&gt;`` and ``&quot;``; all other characters are unchanged.
    """
    # A single translation pass maps each source character independently, so
    # the ampersands introduced by the entities are never escaped again.
    entities = str.maketrans(
        {
            "&": "&amp;",
            "<": "&lt;",
            ">": "&gt;",
            '"': "&quot;",
        }
    )
    return value.translate(entities)
# Run the full benchmark with the default number of runs when this file is
# executed as a script; importing the module does not trigger it.
if __name__ == "__main__":
    main()