# fish_tracker.py
import cv2
import numpy as np
from collections import defaultdict, deque
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort


class FishTracker:
    """Detect fish with YOLO, track them with DeepSORT, and flag anomalies.

    A fish is flagged when it is both slow and near the bottom of the frame:
    - slow: its average speed over the recent trajectory is below
      ``slow_speed_thresh`` px/s;
    - near the bottom: its box centre lies below ``bottom_ratio`` of the
      frame height.
    """

    def __init__(
        self,
        model_path="yolov8n.pt",
        conf_thresh=0.5,
        *,
        fish_class_id=17,
        slow_speed_thresh=15.0,
        bottom_ratio=0.8,
        trajectory_len=30,
    ):
        """Create the detector, the tracker and empty per-track state.

        Args:
            model_path: YOLO weights path/name passed to ``YOLO()``.
            conf_thresh: Detections with confidence <= this are discarded.
            fish_class_id: Model class index treated as "fish". NOTE: the
                default 17 is kept for backward compatibility only. In the
                standard 80-class COCO mapping used by Ultralytics, index 17
                is "horse", and COCO has no fish class. Real fish detection
                needs a custom-trained model and that model's fish index.
            slow_speed_thresh: Average speeds (px/s) below this count as slow.
            bottom_ratio: Centres with ``cy > bottom_ratio * frame_height``
                count as being at the bottom.
            trajectory_len: Number of most recent centre points kept per track.
        """
        self.model = YOLO(model_path)
        self.tracker = DeepSort(max_age=30, n_init=3, nn_budget=100)
        self.conf_thresh = conf_thresh
        self.fish_class_id = fish_class_id
        self.slow_speed_thresh = slow_speed_thresh
        self.bottom_ratio = bottom_ratio
        # track_id -> deque of (cx, cy, tick) for the most recent frames.
        self.trajectories = defaultdict(lambda: deque(maxlen=trajectory_len))
        # Alerts produced by the most recent process_frame() call.
        self.alerts = []

    def process_frame(self, frame, timestamp=None):
        """Detect, track and annotate fish in one frame.

        Args:
            frame: Image passed to YOLO and DeepSORT and copied for drawing.
                Its height is read from ``frame.shape[0]`` (presumably an
                OpenCV BGR image; confirm with the caller).
            timestamp: Optional capture time of this frame, in seconds. For a
                video file, pass e.g. ``cap.get(cv2.CAP_PROP_POS_MSEC) / 1000``
                so speeds are measured in video time. When omitted, the
                wall-clock time of this call is used, which makes speeds
                depend on processing throughput. Use one mode consistently
                for the lifetime of the tracker.

        Returns:
            ``(output_frame, alerts, track_stats)``:
            - ``output_frame``: a copy of ``frame`` with boxes and IDs drawn.
            - ``alerts``: a new list of alert strings for this frame, also
              stored in ``self.alerts``.
            - ``track_stats``: one dict per confirmed track, with keys "id",
              "speed" (px/s), "is_bottom", "is_slow" and "distance_delta"
              (pixel length of the latest step).
        """
        # Build a fresh list each call. Returning a shared list that is
        # cleared on the next call would silently wipe alerts a caller kept.
        alerts = []
        self.alerts = alerts
        frame_height = frame.shape[0]

        tick_freq = cv2.getTickFrequency()
        # One timestamp per frame, stored in tick units like the history.
        now = cv2.getTickCount() if timestamp is None else timestamp * tick_freq

        # 1. Detection, 2. tracking, 3. forget tracks the tracker dropped.
        detections = self._detect(frame)
        tracks = self.tracker.update_tracks(detections, frame=frame)
        self._prune_trajectories(tracks)

        # 4. Update trajectories, evaluate the anomaly rule, and draw.
        output_frame = frame.copy()
        track_stats = []
        for track in tracks:
            if not track.is_confirmed():
                continue
            track_id = track.track_id
            x1, y1, x2, y2 = map(int, track.to_ltrb())
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

            traj = self.trajectories[track_id]
            traj.append((cx, cy, now))
            speed, step_dist = self._trajectory_speed(traj, tick_freq)

            is_bottom = cy > frame_height * self.bottom_ratio
            # A single point gives no speed estimate. Treating it as "slow"
            # would raise a spurious alert for every newly confirmed fish
            # near the bottom.
            is_slow = len(traj) >= 2 and speed < self.slow_speed_thresh
            if is_slow and is_bottom:
                alerts.append(f"⚠️ 鱼 {track_id}: 低速({speed:.1f}px/s) + 沉底")

            track_stats.append({
                "id": track_id,
                "speed": speed,
                "is_bottom": is_bottom,
                "is_slow": is_slow,
                "distance_delta": step_dist,
            })

            cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(output_frame, f"Fish-{track_id}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        return output_frame, alerts, track_stats

    def _detect(self, frame):
        """Run YOLO on ``frame`` and return detections in DeepSORT input form.

        Each detection is ``[[left, top, width, height], confidence, "fish"]``.
        Only boxes of class ``self.fish_class_id`` with confidence strictly
        above ``self.conf_thresh`` are kept.
        """
        results = self.model(frame, verbose=False)[0]
        detections = []
        for box in results.boxes:
            conf = float(box.conf.item())
            if int(box.cls.item()) != self.fish_class_id or conf <= self.conf_thresh:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            detections.append([[x1, y1, x2 - x1, y2 - y1], conf, "fish"])
        return detections

    def _prune_trajectories(self, tracks):
        """Forget trajectories whose track ID is absent from ``tracks``.

        Without this, ``self.trajectories`` keeps a key for every track ever
        confirmed and grows without bound over a long stream.
        """
        live_ids = {track.track_id for track in tracks}
        for track_id in list(self.trajectories):
            if track_id not in live_ids:
                del self.trajectories[track_id]

    @staticmethod
    def _trajectory_speed(traj, tick_freq):
        """Return ``(average_speed_px_per_s, last_step_px)`` for ``traj``.

        ``traj`` holds ``(x, y, tick)`` points in time order, and
        ``tick_freq`` converts ticks to seconds. With fewer than two points,
        or a non-positive time span, the speed is 0. The last step is 0 when
        there are fewer than two points.
        """
        points = list(traj)
        if len(points) < 2:
            return 0, 0
        total_dist = 0.0
        step_dist = 0.0
        for (xa, ya, _), (xb, yb, _) in zip(points, points[1:]):
            step_dist = float(np.hypot(xb - xa, yb - ya))
            total_dist += step_dist
        # The per-step time deltas telescope to last - first.
        total_time = (points[-1][2] - points[0][2]) / tick_freq
        speed = total_dist / total_time if total_time > 0 else 0
        return speed, step_dist