"""
岩屑图像切片工具 - 四点定位切分
按范围设置d值
"""

import os
import tkinter as tk
from tkinter import filedialog, messagebox
import pandas as pd
from PIL import Image, ImageDraw, ImageTk

class SliceToolApp:
    def __init__(self, root):
        """Initialise application state, build the UI and fill the depth grid.

        root: the Tk root window; it is titled and sized here.
        """
        self.root = root
        self.root.title("Rock Image Slice Tool")
        self.root.geometry("1500x950")

        # Currently loaded image (path + PIL image) and output settings
        self.current_image_path = None
        self.current_image = None
        self.output_dir = tk.StringVar(value="./output")
        self.slice_size = tk.IntVar(value=256)

        # User-clicked corner points (image coordinates) and the canvas ids of
        # their markers
        self.corner_points = []
        self.temp_marker_ids = []

        # Index of the corner being dragged, and canvas ids of preview lines
        self.dragging_point = None
        self.preview_line_ids = []

        # d ranges over cell numbers, half-open [start, end): the default
        # single range [1, 101) covers cells 1-100 with d = 1.0
        self.d_ranges = [{'start': 1, 'end': 101, 'd': 1.0}]

        # Depth value of each cell of the 10x10 grid
        self.depth_data = [[0] * 10 for _ in range(10)]

        # Canvas widget and the image-to-canvas transform
        self.canvas = None
        self.scale = 1.0
        self.offset_x = 0
        self.offset_y = 0
        self.display_w = 0
        self.display_h = 0

        self.setup_ui()
        self.calculate_all_depths()
    
    def setup_ui(self):
        """Build every widget of the main window.

        Layout: a title bar, then a horizontal PanedWindow holding the image
        canvas with its action buttons (left) and the control panels (right):
        corner-point progress, d-value ranges plus start depth, the clickable
        10x10 depth grid, slice size / preview toggle, and the XLSX tools.
        Widgets are packed/gridded in creation order, so statement order
        determines the on-screen layout.

        Also initialises the grid-selection state (select_start, select_end,
        selected_cells, selection_color).
        """
        # Title bar
        title_frame = tk.Frame(self.root, bg="#2c3e50", height=50)
        title_frame.pack(fill=tk.X)
        title_frame.pack_propagate(False)
        tk.Label(title_frame, text="Rock Image Slice Tool - Set d by Range",
                 font=("Arial", 14, "bold"), bg="#2c3e50", fg="white").pack(pady=12)

        # Main container: canvas on the left, controls on the right
        main_container = tk.PanedWindow(self.root, orient=tk.HORIZONTAL)
        main_container.pack(fill=tk.BOTH, expand=True)

        # Left pane
        left_frame = tk.Frame(main_container, width=1000)
        main_container.add(left_frame, width=1000)

        self.canvas = tk.Canvas(left_frame, bg="#1a1a2e", width=1000, height=850, cursor="crosshair")
        self.canvas.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
        self.canvas.bind("<Button-1>", self.on_canvas_click)

        # Action buttons below the canvas
        btn_frame = tk.Frame(left_frame)
        btn_frame.pack(fill=tk.X, padx=5, pady=5)

        tk.Button(btn_frame, text="Load Image", command=self.select_image,
                  bg="#3498db", fg="white", font=("Arial", 10)).pack(side=tk.LEFT, padx=3)
        tk.Button(btn_frame, text="Set Output", command=self.select_output_dir,
                  bg="#34495e", fg="white", font=("Arial", 10)).pack(side=tk.LEFT, padx=3)
        tk.Button(btn_frame, text="Clear Points", command=self.clear_points,
                  bg="#e74c3c", fg="white", font=("Arial", 10)).pack(side=tk.LEFT, padx=3)
        tk.Button(btn_frame, text="Execute Slice", command=self.execute_slice,
                  bg="#27ae60", fg="white", font=("Arial", 10)).pack(side=tk.LEFT, padx=3)

        self.image_info_label = tk.Label(left_frame, text="No image loaded",
                                         bg="#34495e", fg="#3498db", font=("Arial", 9), anchor=tk.W)
        self.image_info_label.pack(fill=tk.X, padx=5, pady=2)

        # Right pane
        right_frame = tk.Frame(main_container, bg="#2c3e50", width=400)
        main_container.add(right_frame, width=400)

        # Corner-point progress
        progress_group = tk.LabelFrame(right_frame, text="Corner Points",
                                       font=("Arial", 11), bg="#34495e", fg="#3498db", padx=10, pady=10)
        progress_group.pack(fill=tk.X, padx=10, pady=5)

        self.progress_label = tk.Label(progress_group, text="Click 4 corners: 0/4",
                                       bg="#2c3e50", fg="#f1c40f", font=("Arial", 12, "bold"), anchor=tk.W)
        self.progress_label.pack(fill=tk.X)

        tk.Label(progress_group, text="任意顺序点击4个点，自动验证平行四边形",
                 bg="#34495e", fg="#95a5a6", font=("Arial", 9)).pack(anchor=tk.W)

        # d-value ranges
        range_group = tk.LabelFrame(right_frame, text="Set d by Range (Cell Range)",
                                    font=("Arial", 11), bg="#34495e", fg="#27ae60", padx=10, pady=10)
        range_group.pack(fill=tk.X, padx=10, pady=5)

        tk.Label(range_group, text="Cell#: Start ~ End  |  d value",
                 bg="#34495e", fg="#ecf0f1", font=("Arial", 9)).pack(anchor=tk.W)

        # One entry row (start ~ end | d) per range in self.d_ranges
        self.range_frames = []
        self.range_start_entries = []
        self.range_end_entries = []
        self.range_d_entries = []

        for r in self.d_ranges:
            frame = tk.Frame(range_group, bg="#34495e")
            frame.pack(fill=tk.X, pady=2)

            start_entry = tk.Entry(frame, width=6, font=("Arial", 9), justify=tk.CENTER)
            start_entry.pack(side=tk.LEFT, padx=2)
            start_entry.insert(0, str(r['start']))

            tk.Label(frame, text="~", bg="#34495e", fg="white").pack(side=tk.LEFT)

            end_entry = tk.Entry(frame, width=6, font=("Arial", 9), justify=tk.CENTER)
            end_entry.pack(side=tk.LEFT, padx=2)
            # End values above 100 are displayed as 101 (exclusive end)
            end_entry.insert(0, str(r['end'] if r['end'] <= 100 else 101))

            tk.Label(frame, text="d=", bg="#34495e", fg="white").pack(side=tk.LEFT, padx=(10, 2))

            d_entry = tk.Entry(frame, width=8, font=("Arial", 9), justify=tk.CENTER)
            d_entry.pack(side=tk.LEFT, padx=2)
            d_entry.insert(0, str(r['d']))

            self.range_frames.append(frame)
            self.range_start_entries.append(start_entry)
            self.range_end_entries.append(end_entry)
            self.range_d_entries.append(d_entry)

        # Range buttons
        range_btn_frame = tk.Frame(range_group, bg="#34495e")
        range_btn_frame.pack(fill=tk.X, pady=5)

        tk.Button(range_btn_frame, text="+ Add Range", command=self.add_range,
                  bg="#3498db", fg="white", font=("Arial", 9)).pack(side=tk.LEFT, padx=2)
        tk.Button(range_btn_frame, text="- Remove", command=self.remove_range,
                  bg="#e74c3c", fg="white", font=("Arial", 9)).pack(side=tk.LEFT, padx=2)
        tk.Button(range_btn_frame, text="Calculate", command=self.calculate_all_depths,
                  bg="#27ae60", fg="white", font=("Arial", 9)).pack(side=tk.LEFT, padx=2)

        # Start depth
        start_depth_frame = tk.Frame(range_group, bg="#34495e")
        start_depth_frame.pack(fill=tk.X, pady=5)

        tk.Label(start_depth_frame, text="Start Depth:", bg="#34495e", fg="#ecf0f1",
                 font=("Arial", 9)).pack(side=tk.LEFT)
        self.base_depth_entry = tk.Entry(start_depth_frame, width=10, font=("Arial", 9))
        self.base_depth_entry.pack(side=tk.LEFT, padx=5)
        self.base_depth_entry.insert(0, "3660")
        tk.Button(start_depth_frame, text="Apply", command=self.calculate_all_depths,
                  bg="#3498db", fg="white", font=("Arial", 9)).pack(side=tk.LEFT, padx=2)

        # 10x10 depth grid
        table_group = tk.LabelFrame(right_frame, text="100 Grid (拖选设置d值)",
                                    font=("Arial", 11), bg="#34495e", fg="#27ae60", padx=10, pady=5)
        table_group.pack(fill=tk.BOTH, expand=True, padx=10, pady=5)

        table_container = tk.Frame(table_group, bg="#34495e")
        table_container.pack(fill=tk.BOTH, expand=True, pady=5)

        # Corner cell and column headers 1-10
        tk.Label(table_container, text="R\\C", bg="#2c3e50", fg="white",
                 font=("Arial", 7, "bold"), width=4, relief=tk.RIDGE).grid(row=0, column=0, padx=1, pady=1)
        for c in range(10):
            tk.Label(table_container, text=str(c+1), bg="#3498db", fg="white",
                     font=("Arial", 7, "bold"), width=5, relief=tk.RIDGE).grid(row=0, column=c+1, padx=1, pady=1)

        self.cell_labels = []  # 10x10 grid of cell Label widgets
        for r in range(10):
            # Row header
            tk.Label(table_container, text=str(r+1), bg="#3498db", fg="white",
                     font=("Arial", 7, "bold"), width=4, relief=tk.RIDGE).grid(row=r+1, column=0, padx=1, pady=1)

            row_labels = []
            for c in range(10):
                lbl = tk.Label(table_container, width=5, relief=tk.RIDGE,
                               bg="#1a1a2e", fg="#ecf0f1", font=("Arial", 7), cursor="hand1")
                lbl.grid(row=r+1, column=c+1, padx=1, pady=1)
                # r/c bound as defaults so each label reports its own position
                lbl.bind('<Button-1>', lambda e, r=r, c=c: self.on_cell_select_start(r, c))
                lbl.bind('<B1-Motion>', lambda e, r=r, c=c: self.on_cell_select_drag(r, c))
                lbl.bind('<ButtonRelease-1>', lambda e: self.on_cell_select_end())
                row_labels.append(lbl)
            self.cell_labels.append(row_labels)

        # d value input applied to the selected cells
        d_frame = tk.Frame(table_group, bg="#34495e")
        d_frame.pack(fill=tk.X, pady=5)

        tk.Label(d_frame, text="d值:", bg="#34495e", fg="#ecf0f1",
                 font=("Arial", 9)).pack(side=tk.LEFT)
        self.selected_d_entry = tk.Entry(d_frame, width=8, font=("Arial", 9), justify=tk.CENTER)
        self.selected_d_entry.pack(side=tk.LEFT, padx=3)
        self.selected_d_entry.insert(0, "1.0")
        tk.Button(d_frame, text="应用", command=self.apply_selected_d,
                  bg="#27ae60", fg="white", font=("Arial", 9)).pack(side=tk.LEFT, padx=3)

        tk.Button(d_frame, text="全部", command=self.apply_all_d,
                  bg="#3498db", fg="white", font=("Arial", 9)).pack(side=tk.LEFT, padx=3)
        tk.Button(d_frame, text="清除", command=self.clear_selection,
                  bg="#e74c3c", fg="white", font=("Arial", 9)).pack(side=tk.LEFT, padx=3)

        self.selection_info = tk.Label(table_group, text="点击两个单元格选择范围",
                                       bg="#34495e", fg="#95a5a6", font=("Arial", 8), anchor=tk.W)
        self.selection_info.pack(fill=tk.X, pady=2)

        # Selection state: two clicks define a cell-number range
        self.select_start = None  # first click (row, col, cell_num)
        self.select_end = None    # second click (row, col, cell_num)
        self.selected_cells = []  # currently selected (row, col) cells
        self.selection_color = "#3498db"  # highlight colour

        # Slice size
        size_frame = tk.Frame(right_frame, bg="#34495e")
        size_frame.pack(fill=tk.X, padx=10, pady=5)

        tk.Label(size_frame, text="切片大小:", bg="#34495e", fg="#ecf0f1",
                 font=("Arial", 9)).grid(row=0, column=0, sticky=tk.W)
        self.slice_size_entry = tk.Entry(size_frame, width=8, font=("Arial", 9))
        self.slice_size_entry.grid(row=0, column=1, padx=5)
        self.slice_size_entry.insert(0, "256")
        # Validate instead of calling int() directly in the callback, which
        # raised on bad text and left the entry out of sync with slice_size.
        self.slice_size_entry.bind('<FocusOut>', lambda e: self._commit_slice_size_entry())
        self.slice_size_entry.bind('<Return>', lambda e: self._commit_slice_size_entry())

        # Preview-line toggle
        self.show_preview_line = tk.BooleanVar(value=True)
        tk.Checkbutton(size_frame, text="显示预览线", variable=self.show_preview_line,
                       bg="#34495e", fg="#ecf0f1", font=("Arial", 9),
                       command=self.update_preview_lines).grid(row=1, column=0, columnspan=2, sticky=tk.W, pady=2)

        # XLSX tools
        xlsx_group = tk.LabelFrame(right_frame, text="XLSX to CSV",
                                   font=("Arial", 11), bg="#34495e", fg="#27ae60", padx=10, pady=10)
        xlsx_group.pack(fill=tk.X, padx=10, pady=5)

        tk.Button(xlsx_group, text="Select XLSX", command=self.select_xlsx_file,
                  bg="#27ae60", fg="white", font=("Arial", 9)).pack(fill=tk.X, pady=2)
        tk.Button(xlsx_group, text="Batch Convert All", command=self.batch_convert_xlsx,
                  bg="#8e44ad", fg="white", font=("Arial", 9)).pack(fill=tk.X, pady=2)

        self.xlsx_info_label = tk.Label(xlsx_group, text="", bg="#34495e", fg="#95a5a6",
                                        font=("Arial", 8), anchor=tk.W)
        self.xlsx_info_label.pack(fill=tk.X, pady=5)

    def _commit_slice_size_entry(self):
        """Copy the slice-size entry into ``self.slice_size`` if it is a positive int.

        On unparsable or non-positive text the entry is reset to the last
        accepted value rather than raising inside the Tk event callback.
        """
        try:
            size = int(self.slice_size_entry.get().strip())
        except ValueError:
            size = 0
        if size <= 0:
            self.slice_size_entry.delete(0, tk.END)
            self.slice_size_entry.insert(0, str(self.slice_size.get()))
            return
        self.slice_size.set(size)
    
    def add_range(self):
        """Append a new d range that starts where the last range ends.

        The text currently typed into the range rows is first copied into
        ``self.d_ranges``; ``update_range_display`` rebuilds the rows from
        ``self.d_ranges``, so skipping this step would silently discard unsaved
        edits. A row whose text does not parse keeps its previous values.

        The new range is the half-open interval ``[last_end, min(last_end + 10, 101))``
        with d = 1.0, or ``[1, 101)`` when there is no range yet.
        """
        for rng, start_e, end_e, d_e in zip(self.d_ranges, self.range_start_entries,
                                             self.range_end_entries, self.range_d_entries):
            try:
                parsed = {'start': int(start_e.get()), 'end': int(end_e.get()),
                          'd': float(d_e.get())}
            except ValueError:
                continue  # keep the last valid values for an unparsable row
            rng.update(parsed)

        if self.d_ranges:
            last_end = self.d_ranges[-1]['end']
            new_start = last_end  # half-open: next range begins at previous end
            new_end = min(last_end + 10, 101)  # default width 10, capped at 101
        else:
            new_start, new_end = 1, 101  # [1, 101) covers cells 1-100

        self.d_ranges.append({'start': new_start, 'end': new_end, 'd': 1.0})
        self.update_range_display()

    def remove_range(self):
        """Drop the last d range and destroy its entry row; the first range is always kept."""
        if len(self.d_ranges) <= 1:
            return
        self.d_ranges.pop()
        self.range_frames.pop().destroy()
        for entry_list in (self.range_start_entries, self.range_end_entries, self.range_d_entries):
            entry_list.pop()
    
    def calculate_all_depths(self):
        """Recompute the depth of every grid cell from the start depth and d ranges.

        First the range entry rows are copied into ``self.d_ranges`` row by
        row (zipped, so extra ranges without a row are left as they are). A row
        whose text does not parse keeps all of its previous values.

        Cells are numbered 1-100 row-major (cell = r * 10 + c + 1). Cell 1 gets
        the start depth (3660 if the entry is not a number) and each following
        cell adds the d of the previous cell. A cell's d is taken from the first
        range with ``start <= cell < end``, or 1.0 when no range covers it.
        """
        for rng, start_e, end_e, d_e in zip(self.d_ranges, self.range_start_entries,
                                             self.range_end_entries, self.range_d_entries):
            try:
                parsed = {'start': int(start_e.get()), 'end': int(end_e.get()),
                          'd': float(d_e.get())}
            except ValueError:
                continue  # keep the last valid values for an unparsable row
            rng.update(parsed)

        try:
            base_depth = float(self.base_depth_entry.get())
        except ValueError:
            base_depth = 3660

        current_depth = base_depth
        for i in range(100):
            cell_num = i + 1
            d = next((r['d'] for r in self.d_ranges if r['start'] <= cell_num < r['end']), 1.0)
            self.depth_data[i // 10][i % 10] = current_depth
            current_depth += d

        self.update_depth_display()

    def update_depth_display(self):
        """Write each cell's depth into its grid label.

        Whole-number depths are shown without decimals; any other value is
        shown with one decimal place.
        """
        for idx in range(100):
            row, col = divmod(idx, 10)
            value = self.depth_data[row][col]
            if value == int(value):
                shown = str(int(value))
            else:
                shown = f"{value:.1f}"
            self.cell_labels[row][col].config(text=shown)
    
    def on_cell_edit(self, row, col):
        """Store the number shown in cell (row, col)'s label into ``depth_data``.

        Text that is not a number is ignored and the stored depth is left
        unchanged. Only ``ValueError`` is caught, so unrelated errors are no
        longer hidden.
        """
        try:
            value = float(self.cell_labels[row][col].cget('text'))
        except ValueError:
            return
        self.depth_data[row][col] = value

    def on_cell_select_start(self, row, col):
        """Handle a click on grid cell (row, col) to build a cell-number range.

        - First click: selects just this cell.
        - Clicking the first cell again: clears the selection.
        - Any other click: becomes the range end, and every cell between the
          two cell numbers (row-major, inclusive) is selected.

        The first click goes through ``update_selection`` so that stale
        highlights are cleared and ``selected_cells`` holds the clicked cell.
        Previously ``selected_cells`` was not updated, so a one-cell selection
        could not be applied.
        """
        cell_num = row * 10 + col + 1
        if self.select_start is None:
            self.select_start = (row, col, cell_num)
            self.select_end = (row, col, cell_num)
            self.update_selection()
            self.selection_info.config(text=f"已选: #{cell_num}, 点击第二个单元格")
        elif self.select_start[2] == cell_num:
            # Clicking the first cell again cancels the selection
            self.clear_selection()
        else:
            self.select_end = (row, col, cell_num)
            self.update_selection()
            self.selection_info.config(text=f"序号范围: {len(self.selected_cells)} 个单元格，可设置d值")

    def merge_ranges(self, ranges):
        """Merge touching/overlapping half-open ranges that share the same d value.

        Args:
            ranges: iterable of dicts with 'start', 'end' (exclusive) and 'd'.

        Returns:
            A new list of new dicts sorted by 'start'. Neighbours merge when the
            next start is <= the previous end and their d values are equal.
            Ranges with different d are kept as they are, even if they overlap.
            The input dicts are never modified.
        """
        if not ranges:
            return []
        merged = []
        for rng in sorted(ranges, key=lambda x: x['start']):
            if merged and rng['start'] <= merged[-1]['end'] and rng['d'] == merged[-1]['d']:
                merged[-1]['end'] = max(merged[-1]['end'], rng['end'])
            else:
                merged.append(dict(rng))  # copy so callers' dicts stay untouched
        return merged

    def update_range_display(self):
        """Rebuild the range entry rows (start ~ end | d) from ``self.d_ranges``.

        The parent container is taken from the existing first row. If there are
        no rows, the lists are emptied and nothing is rebuilt. The new rows are
        packed just before the widget that used to follow the last old row (the
        range button row built in setup_ui). This keeps them in their original
        place instead of appending them below every other widget in the group.
        """
        parent = self.range_frames[0].master if self.range_frames else None

        # Find the sibling packed right after the last range row, to re-insert before it.
        anchor = None
        if parent is not None:
            slaves = parent.pack_slaves()
            last_row = self.range_frames[-1]
            if last_row in slaves:
                idx = slaves.index(last_row)
                if idx + 1 < len(slaves):
                    anchor = slaves[idx + 1]

        # Destroy the existing rows
        for frame in self.range_frames:
            frame.destroy()
        self.range_frames = []
        self.range_start_entries = []
        self.range_end_entries = []
        self.range_d_entries = []

        if parent is None:
            return

        pack_position = {'before': anchor} if anchor is not None else {}
        for r in self.d_ranges:
            frame = tk.Frame(parent, bg="#34495e")
            frame.pack(fill=tk.X, pady=2, **pack_position)

            start_entry = tk.Entry(frame, width=6, font=("Arial", 9), justify=tk.CENTER)
            start_entry.pack(side=tk.LEFT, padx=2)
            start_entry.insert(0, str(r['start']))

            tk.Label(frame, text="~", bg="#34495e", fg="white").pack(side=tk.LEFT)

            end_entry = tk.Entry(frame, width=6, font=("Arial", 9), justify=tk.CENTER)
            end_entry.pack(side=tk.LEFT, padx=2)
            # End values above 100 are displayed as 101 (exclusive end)
            end_entry.insert(0, str(r['end'] if r['end'] <= 100 else 101))

            tk.Label(frame, text="d=", bg="#34495e", fg="white").pack(side=tk.LEFT, padx=(10, 2))

            d_entry = tk.Entry(frame, width=8, font=("Arial", 9), justify=tk.CENTER)
            d_entry.pack(side=tk.LEFT, padx=2)
            d_entry.insert(0, str(r['d']))

            self.range_frames.append(frame)
            self.range_start_entries.append(start_entry)
            self.range_end_entries.append(end_entry)
            self.range_d_entries.append(d_entry)

    def on_cell_select_drag(self, row, col):
        """Motion handler for grid cells; intentionally a no-op (selection uses two clicks)."""
        return

    def update_selection(self):
        """Repaint the grid so only the selected cell-number range is highlighted.

        All 100 cells are first reset to the idle colour. If both endpoints are
        set, every cell whose number lies between them (row-major, inclusive,
        endpoints in either order) is highlighted and stored in
        ``selected_cells``. Otherwise ``selected_cells`` is left as it was.
        """
        for idx in range(100):
            self.cell_labels[idx // 10][idx % 10].config(bg="#1a1a2e")

        if not (self.select_start and self.select_end):
            return

        lo, hi = sorted((self.select_start[2], self.select_end[2]))
        self.selected_cells = []
        for number in range(lo, hi + 1):
            r, c = divmod(number - 1, 10)
            self.cell_labels[r][c].config(bg=self.selection_color)
            self.selected_cells.append((r, c))

    def on_cell_select_end(self):
        """Button-release handler for grid cells; keeps the selection so a d value can be applied."""
        return

    def apply_selected_d(self):
        """Apply the d value from ``selected_d_entry`` to the selected cell range.

        The selection is treated as the contiguous cell-number range
        [min, max] (cells 1-100, row-major); a single cell is just min == max.
        Steps:

        1. Copy the entry rows into ``d_ranges``.
        2. Split every existing range that overlaps the selection.
        3. Insert ``[min, max + 1)`` with the new d.
        4. Merge adjacent equal-d ranges.
        5. Rebuild the entry rows, then recompute depths.

        Fixes two bugs: the single-cell path raised NameError (``cell_nums``
        was undefined), and depths were recomputed before the rows were
        rebuilt, so stale rows overwrote the new ranges.
        """
        if not self.selected_cells:
            messagebox.showwarning("Warning", "请先选择单元格")
            return

        try:
            d_value = float(self.selected_d_entry.get())
        except ValueError:
            messagebox.showwarning("Warning", "请输入有效的d值")
            return

        # Keep any edits the user typed into the range rows (they are rebuilt below).
        for rng, start_e, end_e, d_e in zip(self.d_ranges, self.range_start_entries,
                                             self.range_end_entries, self.range_d_entries):
            try:
                parsed = {'start': int(start_e.get()), 'end': int(end_e.get()),
                          'd': float(d_e.get())}
            except ValueError:
                continue  # keep the last valid values for an unparsable row
            rng.update(parsed)

        cell_nums = [r * 10 + c + 1 for r, c in self.selected_cells]
        min_num = min(cell_nums)
        max_num = max(cell_nums)
        sel_end = max_num + 1  # exclusive end of the selected range

        new_ranges = []
        for rng in self.d_ranges:
            if rng['end'] <= min_num or rng['start'] >= sel_end:
                # Entirely outside the selection: keep as is
                new_ranges.append(dict(rng))
                continue
            # Overlaps: keep only the parts sticking out on either side
            if rng['start'] < min_num:
                new_ranges.append({'start': rng['start'], 'end': min_num, 'd': rng['d']})
            if rng['end'] > sel_end:
                new_ranges.append({'start': sel_end, 'end': rng['end'], 'd': rng['d']})

        new_ranges.append({'start': min_num, 'end': sel_end, 'd': d_value})
        self.d_ranges = self.merge_ranges(new_ranges)

        # Rebuild rows BEFORE recalculating: calculate_all_depths() copies the
        # entry rows back into d_ranges by index.
        self.update_range_display()
        self.calculate_all_depths()

        messagebox.showinfo("Success", f"已设置 #{min_num}-{max_num} 的d值为 {d_value}")

    def apply_all_d(self):
        """Set one d value (from ``selected_d_entry``) for all 100 cells.

        Highlights every cell, replaces ``d_ranges`` with the single range
        ``[1, 101)``, rebuilds the entry rows and recomputes depths. The rows
        are rebuilt first because ``calculate_all_depths`` copies the entry rows
        back into ``d_ranges``. Previously the stale first row overwrote the
        new d, and the rows were never refreshed.
        """
        try:
            d_value = float(self.selected_d_entry.get())
        except ValueError:
            messagebox.showwarning("Warning", "请输入有效的d值")
            return

        # Select and highlight every cell
        self.selected_cells = [(r, c) for r in range(10) for c in range(10)]
        for r, c in self.selected_cells:
            self.cell_labels[r][c].config(bg=self.selection_color)

        self.d_ranges = [{'start': 1, 'end': 101, 'd': d_value}]

        self.update_range_display()
        self.calculate_all_depths()
        messagebox.showinfo("Success", f"已设置所有单元格的d值为 {d_value}")

    def get_d_for_cell(self, row, col):
        """Return the d of the first range containing cell (row, col), or 1.0 if none does."""
        number = row * 10 + col + 1
        return next(
            (rng['d'] for rng in self.d_ranges if rng['start'] <= number < rng['end']),
            1.0,
        )

    def clear_selection(self):
        """Drop the current cell selection and restore the idle cell colour.

        NOTE(review): this class defines clear_selection twice with identical
        bodies. The later definition is the one Python binds, so this one is
        dead code and could be removed.
        """
        self.select_start = self.select_end = None
        self.selected_cells = []
        for idx in range(100):
            self.cell_labels[idx // 10][idx % 10].config(bg="#1a1a2e")
        self.selection_info.config(text="点击两个单元格选择范围")

    def update_preview_lines(self):
        """Redraw the 10x10 preview grid over the image from the four corners.

        Old preview lines are always removed first. Previously the method
        returned early when the "显示预览线" checkbox was unticked, so the
        lines never disappeared. New lines are drawn only when the toggle is
        on and exactly four corners exist.

        Grid points come from ``get_grid_points``. They are read as an 11x11
        lattice indexed ``i * 11 + j`` (assumed 121 points, i = row along the
        left edge, j = column along the top edge — confirm against
        get_grid_points). Horizontal segments are blue, vertical segments are
        green.
        """
        for line_id in getattr(self, 'preview_line_ids', []):
            self.canvas.delete(line_id)
        self.preview_line_ids = []

        if not self.show_preview_line.get() or len(self.corner_points) != 4:
            return

        grid_pts = self.get_grid_points(self.corner_points)

        def to_canvas(pt):
            # Image coordinates -> canvas coordinates
            return (int(pt[0] * self.scale) + self.offset_x,
                    int(pt[1] * self.scale) + self.offset_y)

        def draw_segment(p1, p2, colour):
            x1, y1 = to_canvas(p1)
            x2, y2 = to_canvas(p2)
            self.preview_line_ids.append(
                self.canvas.create_line(x1, y1, x2, y2, fill=colour, width=1))

        # Horizontal lines: along each lattice row
        for i in range(11):
            for j in range(10):
                draw_segment(grid_pts[i * 11 + j], grid_pts[i * 11 + j + 1], '#3498db')

        # Vertical lines: along each lattice column
        for i in range(11):
            for j in range(10):
                draw_segment(grid_pts[j * 11 + i], grid_pts[(j + 1) * 11 + i], '#27ae60')
    
    def clear_points(self):
        """Forget all corner points, remove their markers and preview lines, and redraw."""
        self.corner_points = []
        # Marker ids first, then preview-line ids (same deletion order as before)
        doomed = self.temp_marker_ids + getattr(self, 'preview_line_ids', [])
        for item_id in doomed:
            self.canvas.delete(item_id)
        self.temp_marker_ids = []
        self.preview_line_ids = []
        self.update_progress()
        self.draw_image_and_grid()

    def clear_selection(self):
        """Drop the current cell selection and restore the idle cell colour.

        This is the definition Python binds; an identical, earlier
        clear_selection in this class is shadowed by it.
        """
        self.select_start = None
        self.select_end = None
        self.selected_cells = []
        for label_row in self.cell_labels[:10]:
            for label in label_row[:10]:
                label.config(bg="#1a1a2e")
        self.selection_info.config(text="点击两个单元格选择范围")

    def update_progress(self):
        """Show how many of the four corner points have been placed."""
        placed = len(self.corner_points)
        self.progress_label.config(text=f"Corner Points: {placed}/4")

    def on_canvas_click(self, event):
        """Handle a left click on the canvas: start dragging a corner or add a new one.

        Clicks outside the displayed image are ignored. A click within 20
        canvas pixels (per axis) of an existing corner starts dragging that
        corner. Otherwise, while fewer than four corners exist, the click is
        converted to image coordinates and appended as the next corner,
        labelled in the order 1-TL, 2-TR, 3-BL, 4-BR. The preview grid is drawn
        once the fourth corner is placed.
        """
        if self.current_image is None:
            return

        x, y = event.x, event.y
        if x < self.offset_x or x > self.offset_x + self.display_w:
            return
        if y < self.offset_y or y > self.offset_y + self.display_h:
            return

        # A click on the far edge of the displayed image maps to img_w / img_h,
        # one past the last pixel; clamp so stored corners are valid pixels.
        img_w, img_h = self.current_image.size
        ix = min(int((x - self.offset_x) / self.scale), img_w - 1)
        iy = min(int((y - self.offset_y) / self.scale), img_h - 1)

        # Was an existing corner marker hit?
        hit_point = None
        for i, (px, py) in enumerate(self.corner_points):
            sx = int(px * self.scale) + self.offset_x
            sy = int(py * self.scale) + self.offset_y
            if abs(x - sx) < 20 and abs(y - sy) < 20:
                hit_point = i
                break

        if hit_point is not None:
            # Start dragging; on_canvas_drag_end removes these bindings
            self.dragging_point = hit_point
            self.canvas.bind('<B1-Motion>', self.on_canvas_drag)
            self.canvas.bind('<ButtonRelease-1>', self.on_canvas_drag_end)
            return

        if len(self.corner_points) >= 4:
            return

        self.corner_points.append((ix, iy))

        # Marker labels in click order: 1-top-left, 2-top-right, 3-bottom-left, 4-bottom-right
        labels = ["1-左上", "2-右上", "3-左下", "4-右下"]
        pid = self.canvas.create_oval(x-12, y-12, x+12, y+12,
                                      fill="#f1c40f", outline="#e67e22", width=3)
        self.temp_marker_ids.append(pid)
        pid = self.canvas.create_text(x, y-22, text=labels[len(self.corner_points)-1],
                                      fill="#f1c40f", font=("Arial", 11, "bold"))
        self.temp_marker_ids.append(pid)

        self.update_progress()

        if len(self.corner_points) == 4:
            self.update_preview_lines()

    def on_canvas_drag(self, event):
        """Move the corner being dragged to the pointer position.

        Does nothing unless a drag is in progress. Motion outside the displayed
        image is ignored. The position is converted to image coordinates,
        clamped to the last pixel (the far edge otherwise maps to
        img_w / img_h). The markers and preview grid are then redrawn.
        """
        if getattr(self, 'dragging_point', None) is None:
            return

        x, y = event.x, event.y

        # Ignore motion outside the displayed image
        if x < self.offset_x or x > self.offset_x + self.display_w:
            return
        if y < self.offset_y or y > self.offset_y + self.display_h:
            return

        img_w, img_h = self.current_image.size
        ix = min(int((x - self.offset_x) / self.scale), img_w - 1)
        iy = min(int((y - self.offset_y) / self.scale), img_h - 1)

        self.corner_points[self.dragging_point] = (ix, iy)

        self.draw_corner_markers()
        self.update_preview_lines()

    def draw_image_only(self):
        """Clear the canvas, redraw the fitted image and redraw the corner markers.

        The image is scaled to fit the canvas (a 1000x850 fallback is used
        while the canvas reports a size under 10 px) and centred. The
        resulting scale/offset/display size are stored for coordinate mapping.
        Everything else on the canvas, including preview lines, is deleted and
        not redrawn here.
        """
        if self.canvas is None or self.current_image is None:
            return

        cw, ch = self.canvas.winfo_width(), self.canvas.winfo_height()
        if cw < 10 or ch < 10:
            cw, ch = 1000, 850

        iw, ih = self.current_image.size
        self.scale = min(cw / iw, ch / ih)
        self.display_w = int(iw * self.scale)
        self.display_h = int(ih * self.scale)
        self.offset_x = (cw - self.display_w) // 2
        self.offset_y = (ch - self.display_h) // 2

        self.canvas.delete("all")
        self.temp_marker_ids = []

        fitted = self.current_image.copy().resize((self.display_w, self.display_h), Image.LANCZOS)
        photo = ImageTk.PhotoImage(fitted)
        self.canvas.create_image(self.offset_x, self.offset_y, anchor=tk.NW, image=photo)
        # Keep a reference so Tk does not lose the image to garbage collection
        self.canvas.image = photo

        self.draw_corner_markers()

    def draw_corner_markers(self):
        """Replace the corner markers (circle + label per point) at the current scale/offset."""
        for stale_id in self.temp_marker_ids:
            self.canvas.delete(stale_id)
        self.temp_marker_ids = []

        # Marker labels in click order: 1-top-left, 2-top-right, 3-bottom-left, 4-bottom-right
        names = ["1-左上", "2-右上", "3-左下", "4-右下"]
        for idx, (px, py) in enumerate(self.corner_points):
            cx = int(px * self.scale) + self.offset_x
            cy = int(py * self.scale) + self.offset_y
            circle = self.canvas.create_oval(cx - 12, cy - 12, cx + 12, cy + 12,
                                             fill="#f1c40f", outline="#e67e22", width=3)
            self.temp_marker_ids.append(circle)
            caption = self.canvas.create_text(cx, cy - 22, text=names[idx],
                                              fill="#f1c40f", font=("Arial", 11, "bold"))
            self.temp_marker_ids.append(caption)

    def on_canvas_drag_end(self, event):
        """Finish a corner drag and remove the temporary motion/release bindings."""
        self.dragging_point = None
        for sequence in ('<B1-Motion>', '<ButtonRelease-1>'):
            self.canvas.unbind(sequence)

    def validate_quadrilateral(self):
        """Check whether the four corner points form (roughly) a parallelogram.

        Points are read as p0..p3 (click order TL, TR, BL, BR). Two vectors
        count as parallel when |sin(angle)| = |cross| / (|a| * |b|) < 0.05;
        vectors shorter than 1 px never count as parallel. Returns True if
        (p0->p1 || p2->p3 and p0->p2 || p1->p3), or the alternative pairing
        (p0->p1 || p1->p3 and p0->p2 || p2->p3) holds. Returns False otherwise,
        or when there are not exactly four points.
        """
        pts = self.corner_points
        if len(pts) != 4:
            return False

        def edge(a, b):
            return (pts[b][0] - pts[a][0], pts[b][1] - pts[a][1])

        def parallel(u, w, tolerance=5):
            cross = abs(u[0] * w[1] - u[1] * w[0])
            norm_u = (u[0]**2 + u[1]**2) ** 0.5
            norm_w = (w[0]**2 + w[1]**2) ** 0.5
            if norm_u < 1 or norm_w < 1:
                return False
            return cross / (norm_u * norm_w) < tolerance / 100

        top, left = edge(0, 1), edge(0, 2)
        right, bottom = edge(1, 3), edge(2, 3)

        opposite_sides = parallel(top, bottom) and parallel(left, right)
        swapped_pairs = parallel(top, right) and parallel(left, bottom)
        return opposite_sides or swapped_pairs

    def sort_points_to_quadrilateral(self):
        """Reorder ``corner_points`` in place to TL, TR, BR, BL.

        The two smallest-x points form the left side and the rest the right
        side. Each side is ordered by y (smaller y first = top).

        NOTE(review): this output order (TL, TR, BR, BL) differs from the
        TL, TR, BL, BR order used by the click labels and by get_grid_points —
        confirm what the caller expects. Splitting by x alone also misorders
        strongly rotated quads.
        """
        by_x = sorted(self.corner_points, key=lambda p: p[0])
        lefts = sorted(by_x[:2], key=lambda p: p[1])
        rights = sorted(by_x[2:], key=lambda p: p[1])
        self.corner_points = [lefts[0], rights[0], rights[1], lefts[1]]

    def draw_image_and_grid(self):
        """Redraw the whole canvas: fitted image, border, corner markers and preview grid.

        Without an image, a placeholder message is shown instead. The image is
        scaled to fit (a 1000x850 fallback is used while the canvas reports
        under 10 px) and centred, and the scale/offset/display size are stored
        for coordinate mapping.

        After ``delete("all")`` every tracked canvas id is stale, so the marker
        and preview-line id lists are reset here. Previously the marker list
        kept growing with dead ids. Markers are drawn through
        ``draw_corner_markers``, and the preview grid is redrawn when enabled.
        """
        if self.canvas is None:
            return

        self.canvas.delete("all")
        self.temp_marker_ids = []
        self.preview_line_ids = []

        if self.current_image is None:
            self.canvas.create_text(500, 425, text="Load an image to start", fill="#95a5a6", font=("Arial", 16))
            return

        canvas_w = self.canvas.winfo_width()
        canvas_h = self.canvas.winfo_height()
        if canvas_w < 10 or canvas_h < 10:
            canvas_w, canvas_h = 1000, 850

        img_w, img_h = self.current_image.size
        scale = min(canvas_w / img_w, canvas_h / img_h)
        new_w = int(img_w * scale)
        new_h = int(img_h * scale)
        offset_x = (canvas_w - new_w) // 2
        offset_y = (canvas_h - new_h) // 2

        self.scale = scale
        self.offset_x = offset_x
        self.offset_y = offset_y
        self.display_w = new_w
        self.display_h = new_h

        display_img = self.current_image.copy().resize((new_w, new_h), Image.LANCZOS)
        photo = ImageTk.PhotoImage(display_img)
        self.canvas.create_image(offset_x, offset_y, anchor=tk.NW, image=photo)
        self.canvas.image = photo  # keep a reference so Tk keeps the image
        self.canvas.create_rectangle(offset_x, offset_y, offset_x + new_w, offset_y + new_h,
                                     outline="#3498db", width=4)

        self.draw_corner_markers()
        self.update_preview_lines()
        self.update_progress()
    
    def select_image(self):
        """Ask the user for an image file and load it if one was chosen."""
        chosen = filedialog.askopenfilename(
            title="Select Image",
            filetypes=[("Image files", "*.avif *.jpg *.jpeg *.png *.webp"), ("All files", "*.*")],
        )
        if not chosen:
            return
        self.load_image(chosen)
    
    def load_image(self, file_path):
        """Load an image file and make it the current working image.

        ``Image.open`` only reads the header, so the pixel data is decoded
        immediately here. Corrupt or unsupported files are then reported at
        load time instead of failing later while drawing or slicing. The tool's
        state (current path, image, clicked corner points) is replaced only
        after decoding succeeds. A failed load therefore keeps the previous
        image and its path consistent with each other.

        If the file's base name splits on '-' into exactly two parts and the
        first part is an integer (e.g. a base name like "1200-1210"), that
        integer is copied into the base-depth entry.

        Args:
            file_path: Path of the image to open.
        """
        try:
            image = Image.open(file_path)
            image.load()  # force full decode now; Image.open is lazy

            # Only commit new state once the image is known to be readable.
            self.current_image_path = file_path
            self.current_image = image
            self.corner_points = []
            self.temp_marker_ids = []

            filename = os.path.basename(file_path)
            width, height = image.size
            self.image_info_label.config(text=f"Loaded: {filename} | {width}x{height}")

            # Parse the start depth from the file name.
            parts = os.path.splitext(filename)[0].split('-')
            if len(parts) == 2:
                try:
                    start_depth = int(parts[0])
                except ValueError:
                    pass  # leading part is not an integer: keep the current entry value
                else:
                    self.base_depth_entry.delete(0, tk.END)
                    self.base_depth_entry.insert(0, str(start_depth))

            self.root.after(100, self.draw_image_and_grid)
            messagebox.showinfo("Success", "Image loaded!\n\nClick 4 corners to set region")
        except Exception as e:
            messagebox.showerror("Error", f"Cannot load image: {str(e)}")
    
    def select_output_dir(self):
        """Ask the user for an output folder and store it in ``self.output_dir``.

        The stored value is left untouched if the dialog is cancelled.
        """
        chosen_dir = filedialog.askdirectory(title="Select Output Directory")
        if not chosen_dir:
            return
        self.output_dir.set(chosen_dir)
    
    def get_grid_points(self, src_pts, divisions=10):
        """Compute the lattice of grid points inside a quadrilateral.

        Each side of the quadrilateral is split into ``divisions`` equal
        parts. Points are found by bilinear interpolation: first along the top
        edge (TL -> TR) and bottom edge (BL -> BR), then between those two
        points. With the default of 10 divisions this yields 11 x 11 = 121
        points bounding 10 x 10 cells.

        Args:
            src_pts: Sequence whose first four items are (x, y) corners ordered
                top-left, top-right, bottom-left, bottom-right.
            divisions: Number of cells along each side (default 10).

        Returns:
            List of ``(divisions + 1) ** 2`` (x, y) tuples in row-major order.
            Index ``row * (divisions + 1) + col``, where row runs top -> bottom
            and col runs left -> right.

        Raises:
            ValueError: If ``divisions`` is less than 1.
        """
        if divisions < 1:
            raise ValueError("divisions must be >= 1")

        tl, tr, bl, br = src_pts[0], src_pts[1], src_pts[2], src_pts[3]
        steps = divisions + 1
        grid_pts = []

        for i in range(steps):
            u = i / float(divisions)  # vertical weight: 0 = top edge, 1 = bottom edge
            for j in range(steps):
                v = j / float(divisions)  # horizontal weight: 0 = left, 1 = right

                # Linear interpolation along the top and bottom edges.
                top_x = (1 - v) * tl[0] + v * tr[0]
                top_y = (1 - v) * tl[1] + v * tr[1]
                bottom_x = (1 - v) * bl[0] + v * br[0]
                bottom_y = (1 - v) * bl[1] + v * br[1]

                # Interpolate between the two edge points.
                x = (1 - u) * top_x + u * bottom_x
                y = (1 - u) * top_y + u * bottom_y

                grid_pts.append((x, y))

        return grid_pts
    
    def execute_slice(self):
        """Cut the selected quadrilateral into 10 x 10 slices and save them.

        Requires a loaded image and exactly four corner points. The steps are:

        1. Each cell of the grid from ``get_grid_points`` is extracted with
           ``extract_quad``.
        2. The extracted cell is resized to ``slice_size`` x ``slice_size`` and
           saved as PNG in ``<output_dir>/<image name>_slices``.
        3. A preview with the grid drawn over the source image is saved to the
           same folder.

        Success or failure is reported through a message box.
        """
        if self.current_image is None:
            messagebox.showwarning("Warning", "Please load an image first")
            return

        if len(self.corner_points) != 4:
            messagebox.showwarning("Warning", f"Please click 4 corner points ({len(self.corner_points)}/4)")
            return

        grid_cells = 10
        stride = grid_cells + 1  # grid points per row

        try:
            slice_size = self.slice_size.get()
            if slice_size <= 0:
                # Reject before any directories or files are created.
                messagebox.showwarning("Warning", "Slice size must be a positive integer")
                return

            output_dir = self.output_dir.get()
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
            os.makedirs(output_dir, exist_ok=True)

            img_name = os.path.splitext(os.path.basename(self.current_image_path))[0]
            slice_dir = os.path.join(output_dir, f"{img_name}_slices")
            os.makedirs(slice_dir, exist_ok=True)

            grid_pts = self.get_grid_points(self.corner_points)

            count = 0
            for r in range(grid_cells):
                for c in range(grid_cells):
                    tl = r * stride + c
                    tr = tl + 1
                    bl = (r + 1) * stride + c
                    br = bl + 1

                    # Polygon order TL, TR, BR, BL so the outline does not self-intersect.
                    pts = [grid_pts[tl], grid_pts[tr], grid_pts[br], grid_pts[bl]]

                    slice_img = self.extract_quad(self.current_image, pts)
                    slice_img = slice_img.resize((slice_size, slice_size), Image.LANCZOS)

                    depth = self.depth_data[r][c]
                    # NOTE(review): the end depth assumes every cell spans 1.0,
                    # independent of the per-range d value — confirm against
                    # calculate_all_depths.
                    end_depth = depth + 1.0
                    # Row/column numbers keep names unique even when depths coincide.
                    slice_filename = f"slice_{r+1:02d}_{c+1:02d}_depth_{int(depth)}_{int(end_depth)}.png"
                    slice_img.save(os.path.join(slice_dir, slice_filename))
                    count += 1

            self._save_grid_preview(grid_pts, os.path.join(slice_dir, f"{img_name}_grid.png"), grid_cells)

            messagebox.showinfo("Success", f"Sliced {count} images!\nOutput: {slice_dir}")

        except Exception as e:
            messagebox.showerror("Error", f"Slice failed: {str(e)}")

    def _save_grid_preview(self, grid_pts, out_path, divisions=10):
        """Save a copy of the current image with corner dots and grid lines drawn on it.

        Args:
            grid_pts: Row-major grid points as returned by ``get_grid_points``.
            out_path: File path of the preview image to write.
            divisions: Number of cells per side that ``grid_pts`` was built with.
        """
        grid_img = self.current_image.copy().convert("RGB")
        draw = ImageDraw.Draw(grid_img)

        for x, y in self.corner_points:
            draw.ellipse([x - 5, y - 5, x + 5, y + 5], fill='red')

        stride = divisions + 1
        for i in range(stride):
            # Segments along grid row i (left -> right).
            for j in range(divisions):
                p1 = grid_pts[i * stride + j]
                p2 = grid_pts[i * stride + j + 1]
                draw.line([int(p1[0]), int(p1[1]), int(p2[0]), int(p2[1])], fill='blue', width=2)

            # Segments along grid column i (top -> bottom).
            for j in range(divisions):
                p1 = grid_pts[j * stride + i]
                p2 = grid_pts[(j + 1) * stride + i]
                draw.line([int(p1[0]), int(p1[1]), int(p2[0]), int(p2[1])], fill='green', width=2)

        grid_img.save(out_path)
    
    def extract_quad(self, img, pts):
        """Cut a quadrilateral region out of ``img``.

        The result is the axis-aligned bounding box of the polygon ``pts``,
        clipped to the image. Every pixel outside the polygon is set to black.
        No perspective correction is applied.

        Args:
            img: Source PIL image.
            pts: Four (x, y) vertices in drawing order (e.g. TL, TR, BR, BL).

        Returns:
            An RGB image the size of the clipped bounding box (at least 1 x 1).
        """
        # round() instead of int() truncation so the polygon does not shrink.
        int_pts = [(round(p[0]), round(p[1])) for p in pts]

        min_x = max(0, min(p[0] for p in int_pts))
        max_x = min(img.width, max(p[0] for p in int_pts) + 1)
        min_y = max(0, min(p[1] for p in int_pts))
        max_y = min(img.height, max(p[1] for p in int_pts) + 1)

        width = max(1, max_x - min_x)
        height = max(1, max_y - min_y)

        # Crop exactly width x height. When the quad lies outside the image the
        # clamped box has right < left (or bottom < top), which Image.crop
        # rejects. The positive-size box instead yields a zero-filled tile.
        crop = img.crop((min_x, min_y, min_x + width, min_y + height)).convert('RGB')

        mask = Image.new('L', (width, height), 0)
        local_pts = [(x - min_x, y - min_y) for x, y in int_pts]
        ImageDraw.Draw(mask).polygon(local_pts, fill=255)

        # Copy only the pixels inside the polygon onto a black background.
        result = Image.new('RGB', (width, height), (0, 0, 0))
        result.paste(crop, (0, 0), mask)
        return result
    
    def select_xlsx_file(self):
        """Let the user pick an Excel workbook and convert it to CSV.

        Cancelling the dialog does nothing.
        """
        chosen = filedialog.askopenfilename(
            title="Select XLSX",
            filetypes=[("Excel", "*.xlsx *.xls"), ("All", "*.*")],
        )
        if not chosen:
            return
        self.convert_xlsx_to_csv(chosen)
    
    def convert_xlsx_to_csv(self, xlsx_path):
        """Convert an Excel workbook to a UTF-8 (with BOM) CSV file.

        ``pd.read_excel`` is called with its defaults, so only the default
        (first) sheet is read. The CSV is written next to the workbook with the
        same base name. If that file already exists, a Unix-timestamp suffix is
        appended rather than overwriting it. The outcome is shown in the
        info label and a message box.

        Args:
            xlsx_path: Path of the ``.xlsx`` / ``.xls`` file to convert.
        """
        import time

        try:
            df = pd.read_excel(xlsx_path)
            # Swap only the final extension. str.replace would also rewrite
            # ".xls"/".xlsx" occurring in directory names and misses upper case.
            stem = os.path.splitext(xlsx_path)[0]
            csv_path = f"{stem}.csv"
            if os.path.exists(csv_path):
                csv_path = f"{stem}_{int(time.time())}.csv"
            df.to_csv(csv_path, index=False, encoding='utf-8-sig')
            self.xlsx_info_label.config(text=f"Converted: {os.path.basename(csv_path)}")
            messagebox.showinfo("Success", "CSV saved!")
        except Exception as e:
            messagebox.showerror("Error", f"Conversion failed: {str(e)}")
    
    def batch_convert_xlsx(self, root_dir="岩屑图像"):
        """Convert every Excel workbook one level below ``root_dir`` to CSV.

        Each immediate sub-folder of ``root_dir`` is scanned for ``.xlsx`` /
        ``.xls`` files (extension matched case-insensitively). For each one,
        ``<name>.csv`` is written beside it, overwriting any existing CSV of
        that name. Files that fail to convert are collected and reported
        rather than silently ignored.

        Args:
            root_dir: Folder to scan. Defaults to the relative folder "岩屑图像".
        """
        if not os.path.isdir(root_dir):
            messagebox.showwarning("Warning", "Directory not found")
            return

        converted = []
        failed = []
        for folder in os.listdir(root_dir):
            folder_path = os.path.join(root_dir, folder)
            if not os.path.isdir(folder_path):
                continue
            for filename in os.listdir(folder_path):
                stem, ext = os.path.splitext(filename)
                if ext.lower() not in ('.xlsx', '.xls'):
                    continue
                # Excel creates "~$name.xlsx" owner files while a workbook is
                # open; they are not readable workbooks.
                if filename.startswith('~$'):
                    continue
                xlsx_path = os.path.join(folder_path, filename)
                # Replace only the extension, never text inside the name.
                csv_path = os.path.join(folder_path, stem + '.csv')
                try:
                    df = pd.read_excel(xlsx_path)
                    df.to_csv(csv_path, index=False, encoding='utf-8-sig')
                except Exception as e:
                    # Keep converting the remaining files but remember this one.
                    failed.append(f"{xlsx_path}: {e}")
                else:
                    converted.append(csv_path)

        # Show at most 10 failures so the dialog stays readable.
        failure_note = ""
        if failed:
            failure_note = f"\n\nFailed {len(failed)} files:\n" + "\n".join(failed[:10])

        if converted:
            messagebox.showinfo("Done", f"Converted {len(converted)} files" + failure_note)
        elif failed:
            messagebox.showwarning("Warning", "No files converted" + failure_note)
        else:
            messagebox.showwarning("Warning", "No files found")


if __name__ == "__main__":
    # Script entry point: create the Tk root window, build the application
    # on it, then block in the Tk event loop until the window is closed.
    root = tk.Tk()
    app = SliceToolApp(root)
    root.mainloop()
