From ae9d25225597d97f548c86d2bf28729aa75f2eaa Mon Sep 17 00:00:00 2001 From: panda_home <1415243231@qq.com> Date: Tue, 23 Dec 2025 09:30:56 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=82=E8=B5=8F=E9=B1=BC=E5=81=A5=E5=BA=B7?= =?UTF-8?q?=E7=B3=BB=E7=BB=9Fv1.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fish_tracker.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 fish_tracker.py diff --git a/fish_tracker.py b/fish_tracker.py new file mode 100644 index 0000000..4eb4772 --- /dev/null +++ b/fish_tracker.py @@ -0,0 +1,73 @@ +# fish_tracker.py +import cv2 +import numpy as np +from collections import defaultdict, deque +from ultralytics import YOLO +from deep_sort_realtime.deepsort_tracker import DeepSort + +class FishTracker: + def __init__(self, model_path="yolov8n.pt", conf_thresh=0.5): + self.model = YOLO(model_path) + self.tracker = DeepSort(max_age=30, n_init=3, nn_budget=100) + self.conf_thresh = conf_thresh + self.trajectories = defaultdict(lambda: deque(maxlen=30)) # 存最近30帧轨迹 + self.alerts = [] + + def process_frame(self, frame): + self.alerts.clear() + h, w = frame.shape[:2] + + # 1. YOLO 检测(只检测 "fish" 类别,COCO 中类别 ID=17) + results = self.model(frame, verbose=False)[0] + detections = [] + for box in results.boxes: + if int(box.cls.item()) == 17 and box.conf.item() > self.conf_thresh: # COCO: fish=17 + x1, y1, x2, y2 = map(int, box.xyxy[0]) + conf = float(box.conf.item()) + detections.append([[x1, y1, x2 - x1, y2 - y1], conf, "fish"]) + + # 2. DeepSORT 追踪 + tracks = self.tracker.update_tracks(detections, frame=frame) + + # 3. 更新轨迹 & 绘制 + output_frame = frame.copy() + for track in tracks: + if not track.is_confirmed(): + continue + track_id = track.track_id + ltrb = track.to_ltrb() + x1, y1, x2, y2 = map(int, ltrb) + cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 + + # 记录轨迹 + self.trajectories[track_id].append((cx, cy, cv2.getTickCount())) + + # 计算速度(像素/秒) + traj = self.trajectories[track_id] + if len(traj) >= 2: + total_dist = 0 + total_time = 0 + for i in range(1, len(traj)): + dx = traj[i][0] - traj[i-1][0] + dy = traj[i][1] - traj[i-1][1] + dist = np.sqrt(dx*dx + dy*dy) + dt = (traj[i][2] - traj[i-1][2]) / cv2.getTickFrequency() + total_dist += dist + total_time += dt + speed = total_dist / total_time if total_time > 0 else 0 + else: + speed = 0 + + # 判断是否沉底(底部 20% 区域) + is_bottom = cy > h * 0.8 + + # 异常规则 + if speed < 15 and is_bottom: + self.alerts.append(f"⚠️ 鱼 {track_id}: 低速({speed:.1f}px/s) + 沉底") + + # 绘制框和ID + cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2) + cv2.putText(output_frame, f"Fish-{track_id}", (x1, y1 - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) + + return output_frame, self.alerts \ No newline at end of file