forked from UNN/2026-rff_mp
195 lines
6.6 KiB
Python
195 lines
6.6 KiB
Python
from __future__ import annotations
|
||
|
||
import csv
|
||
import statistics
|
||
import sys
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
|
||
|
||
# Project root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(ROOT) for entry in sys.path):
    # Put the project root first on the import path so that the
    # project-level packages imported below (``maze_solver``, ``scripts``)
    # resolve when this file is executed directly as a script.
    sys.path.insert(0, str(ROOT))
|
||
|
||
from maze_solver import ( # noqa: E402
|
||
AStarStrategy,
|
||
BFSStrategy,
|
||
DFSStrategy,
|
||
DijkstraStrategy,
|
||
MazeSolver,
|
||
TextFileMazeBuilder,
|
||
)
|
||
from scripts.generate_mazes import generate_all # noqa: E402
|
||
|
||
|
||
# Directory holding the maze text files that are benchmarked.
_MAZE_DIR = ROOT / "data" / "mazes"

# Benchmark inputs as (short key used in chart file names, human-readable
# label used in reports and chart titles, path to the maze file).
MAZES = [
    ("small", "Маленький 10x10", _MAZE_DIR / "small.txt"),
    ("medium", "Средний 50x50", _MAZE_DIR / "medium.txt"),
    ("large", "Большой 100x100", _MAZE_DIR / "large.txt"),
    ("empty", "Пустой 50x50", _MAZE_DIR / "empty.txt"),
    ("no_exit", "Без пути 30x30", _MAZE_DIR / "no_exit.txt"),
]

# Strategy classes to compare; each is instantiated with no arguments.
STRATEGIES = [BFSStrategy, DFSStrategy, AStarStrategy, DijkstraStrategy]

# Output locations: CSV summary in REPORTS_DIR, SVG charts in CHARTS_DIR.
REPORTS_DIR = ROOT / "reports"
CHARTS_DIR = REPORTS_DIR / "charts"
|
||
|
||
|
||
def main(runs: int = 10) -> None:
    """Run the full benchmark pipeline and write the CSV and SVG reports.

    Calls ``generate_all()`` first (presumably (re)creates the maze files
    listed in ``MAZES`` — confirm in ``scripts.generate_mazes``), ensures the
    report directories exist, benchmarks every strategy on every maze, then
    writes ``results.csv`` and the per-maze charts.

    Args:
        runs: How many times each (maze, strategy) pair is solved; the
            reported figures are averages over these runs.
    """
    generate_all()
    for directory in (REPORTS_DIR, CHARTS_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    results = _run_experiments(runs)
    _write_csv(results)
    _write_charts(results)

    csv_file = REPORTS_DIR / "results.csv"
    print(f"Wrote {csv_file}")
    print(f"Wrote SVG charts to {CHARTS_DIR}")
|
||
|
||
|
||
def _run_experiments(runs: int) -> list[dict[str, object]]:
|
||
builder = TextFileMazeBuilder()
|
||
rows: list[dict[str, object]] = []
|
||
|
||
for maze_key, maze_name, maze_path in MAZES:
|
||
maze = builder.build_from_file(maze_path)
|
||
for strategy_type in STRATEGIES:
|
||
measurements = []
|
||
for _ in range(runs):
|
||
stats = MazeSolver(maze, strategy_type()).solve()
|
||
measurements.append(stats)
|
||
|
||
avg_time = statistics.fmean(item.time_ms for item in measurements)
|
||
avg_visited = statistics.fmean(item.visited_cells for item in measurements)
|
||
avg_path = statistics.fmean(item.path_length for item in measurements)
|
||
found = measurements[-1].path_length > 0
|
||
rows.append(
|
||
{
|
||
"key": maze_key,
|
||
"лабиринт": maze_name,
|
||
"стратегия": measurements[-1].strategy_name,
|
||
"время_мс": f"{avg_time:.4f}",
|
||
"посещено_клеток": f"{avg_visited:.1f}",
|
||
"длина_пути": f"{avg_path:.1f}",
|
||
"путь_найден": "да" if found else "нет",
|
||
"запусков": runs,
|
||
}
|
||
)
|
||
|
||
return rows
|
||
|
||
|
||
def _write_csv(rows: list[dict[str, object]]) -> None:
    """Write the experiment summary to ``REPORTS_DIR / "results.csv"``.

    Only the report columns listed below are written, in that order. Any
    other entries in a row (such as the internal ``"key"`` used for chart
    grouping) are dropped. A row missing one of these columns raises
    ``KeyError``. Column names and values contain Cyrillic text, and the
    file is written as UTF-8.
    """
    columns = [
        "лабиринт",
        "стратегия",
        "время_мс",
        "посещено_клеток",
        "длина_пути",
        "путь_найден",
        "запусков",
    ]
    target = REPORTS_DIR / "results.csv"
    # newline="" lets the csv module control line endings itself.
    with target.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({name: row[name] for name in columns} for row in rows)
|
||
|
||
|
||
def _write_charts(rows: list[dict[str, object]]) -> None:
    """Render two SVG bar charts per maze: average time and visited cells.

    Rows are grouped by their ``"key"`` value in first-seen order. Each group
    yields ``<key>_time.svg`` followed by ``<key>_visited.svg`` in
    ``CHARTS_DIR``. Chart titles are prefixed with the group's ``"лабиринт"``
    label.
    """
    # (file-name suffix, title suffix, metric column, bar colour)
    chart_specs = (
        ("time", "среднее время, мс", "время_мс", "#2f6fbb"),
        ("visited", "посещенные клетки", "посещено_клеток", "#2f8f5b"),
    )

    by_maze: dict[str, list[dict[str, object]]] = defaultdict(list)
    for row in rows:
        by_maze[str(row["key"])].append(row)

    for maze_key, maze_rows in by_maze.items():
        maze_title = str(maze_rows[0]["лабиринт"])
        for suffix, caption, metric, color in chart_specs:
            _write_bar_chart(
                CHARTS_DIR / f"{maze_key}_{suffix}.svg",
                title=f"{maze_title}: {caption}",
                rows=maze_rows,
                metric=metric,
                color=color,
            )
|
||
|
||
|
||
def _write_bar_chart(
    path: Path,
    title: str,
    rows: list[dict[str, object]],
    metric: str,
    color: str,
) -> None:
    """Write a single-series vertical bar chart as a standalone SVG file.

    One bar is drawn per row. Each bar is labelled underneath with the row's
    ``"стратегия"`` value and annotated above with its ``metric`` value to two
    decimals. The y axis spans 0 to the largest value (1.0 when there are no
    rows or every value is 0), with five evenly spaced gridlines. An empty
    ``rows`` list produces a chart with axes and gridlines but no bars.

    Args:
        path: Destination ``.svg`` file; overwritten if it exists.
        title: Chart heading; passed through ``_escape`` before embedding.
        rows: Report rows; ``row[metric]`` must be convertible with ``float``.
        metric: Column whose values determine bar heights.
        color: SVG fill colour for the bars (e.g. ``"#2f6fbb"``).
    """
    # Canvas size and plot-area margins, in SVG user units.
    width = 780
    height = 360
    left = 72
    right = 28
    top = 54
    bottom = 58
    chart_width = width - left - right
    chart_height = height - top - bottom
    values = [float(row[metric]) for row in rows]
    # Fall back to 1.0 for an empty or all-zero series so the y-scale
    # (value / max_value) never divides by zero.
    max_value = max(values) if values else 1.0
    max_value = max_value or 1.0
    # Horizontal slot per bar. With no rows the bar loop below draws nothing,
    # but computing the slot must not raise ZeroDivisionError.
    bar_area = chart_width / max(len(rows), 1)
    bar_width = min(96, bar_area * 0.58)

    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
        '<rect width="100%" height="100%" fill="#ffffff"/>',
        f'<text x="{left}" y="30" font-family="Arial" font-size="18" font-weight="700" fill="#1f2933">{_escape(title)}</text>',
        f'<line x1="{left}" y1="{height - bottom}" x2="{width - right}" y2="{height - bottom}" stroke="#9aa5b1"/>',
        f'<line x1="{left}" y1="{top}" x2="{left}" y2="{height - bottom}" stroke="#9aa5b1"/>',
    ]

    # Gridlines and y-axis labels at 0%, 25%, 50%, 75% and 100% of max_value.
    for tick in range(5):
        ratio = tick / 4
        y = height - bottom - ratio * chart_height
        value = max_value * ratio
        parts.append(
            f'<line x1="{left - 5}" y1="{y:.1f}" x2="{width - right}" y2="{y:.1f}" stroke="#edf0f2"/>'
        )
        parts.append(
            f'<text x="{left - 10}" y="{y + 4:.1f}" text-anchor="end" font-family="Arial" font-size="11" fill="#52616b">{value:.2f}</text>'
        )

    # One bar per row, centred in its slot, with the value above and the
    # strategy name below the x axis.
    for index, (row, value) in enumerate(zip(rows, values)):
        ratio = value / max_value
        bar_height = ratio * chart_height
        x = left + index * bar_area + (bar_area - bar_width) / 2
        y = height - bottom - bar_height
        label = str(row["стратегия"])
        parts.append(
            f'<rect x="{x:.1f}" y="{y:.1f}" width="{bar_width:.1f}" height="{bar_height:.1f}" fill="{color}" rx="3"/>'
        )
        parts.append(
            f'<text x="{x + bar_width / 2:.1f}" y="{y - 8:.1f}" text-anchor="middle" font-family="Arial" font-size="12" fill="#1f2933">{value:.2f}</text>'
        )
        parts.append(
            f'<text x="{x + bar_width / 2:.1f}" y="{height - bottom + 24}" text-anchor="middle" font-family="Arial" font-size="12" fill="#1f2933">{_escape(label)}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
|
||
|
||
|
||
def _escape(value: str) -> str:
|
||
return (
|
||
value.replace("&", "&")
|
||
.replace("<", "<")
|
||
.replace(">", ">")
|
||
.replace('"', """)
|
||
)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the whole benchmark with main()'s default run count.
    main()
|