Files
fish_monitor/fish_tracker.py
T
2025-12-23 09:30:56 +08:00

73 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# fish_tracker.py
import warnings
from collections import defaultdict, deque

import cv2
import numpy as np
from deep_sort_realtime.deepsort_tracker import DeepSort
from ultralytics import YOLO


class FishTracker:
    """Detect fish with YOLO, track them with DeepSORT, and flag slow fish near the bottom.

    For every confirmed track the last ``history_len`` box-centre positions are
    kept together with their timestamps. An alert is raised for a track whose
    average path speed over that history is below ``speed_thresh`` px/s AND
    whose centre lies below ``bottom_ratio * frame_height`` (the bottom band of
    the frame). That combination is the anomaly rule of this monitor.
    """

    def __init__(self, model_path="yolov8n.pt", conf_thresh=0.5, *,
                 target_class="fish", class_ids=None, history_len=30,
                 speed_thresh=15.0, bottom_ratio=0.8):
        """Load the detector and create the tracker.

        Args:
            model_path: Ultralytics YOLO weights to load.
            conf_thresh: a detection is kept only if its confidence is
                strictly greater than this value.
            target_class: class name (compared case-insensitively) looked up
                in the model's ``names`` mapping to decide which detections
                are fish. Ignored when ``class_ids`` is given.
            class_ids: optional explicit iterable of model class indices to
                keep; overrides ``target_class``.
            history_len: number of most recent centre points kept per track.
            speed_thresh: average speed (px/s) below which a fish is "slow".
            bottom_ratio: fraction of the frame height; a centre whose y is
                greater than ``bottom_ratio * height`` counts as at the bottom.

        Warns:
            RuntimeWarning: if no model class matches ``target_class`` (for
                example the stock COCO ``yolov8n.pt`` weights, which have no
                fish class). No detections are kept in that case.
        """
        self.model = YOLO(model_path)
        self.tracker = DeepSort(max_age=30, n_init=3, nn_budget=100)
        self.conf_thresh = conf_thresh
        self.speed_thresh = speed_thresh
        self.bottom_ratio = bottom_ratio
        self.class_ids = self._resolve_class_ids(target_class, class_ids)
        if not self.class_ids:
            warnings.warn(
                f"Model {model_path!r} has no class named {target_class!r}; "
                "no detections will be tracked. Use fish-trained weights or "
                "pass class_ids explicitly.",
                RuntimeWarning,
                stacklevel=2,
            )
        # Per-track ring buffer of (cx, cy, timestamp_seconds).
        self.trajectories = defaultdict(lambda: deque(maxlen=history_len))
        # Alerts produced by the most recent process_frame() call.
        self.alerts = []

    def _resolve_class_ids(self, target_class, class_ids):
        """Return the set of model class indices that count as fish detections."""
        if class_ids is not None:
            return {int(c) for c in class_ids}
        names = getattr(self.model, "names", None) or {}
        if not isinstance(names, dict):
            # Some model wrappers expose names as a plain sequence.
            names = dict(enumerate(names))
        wanted = str(target_class).casefold()
        return {int(idx) for idx, name in names.items()
                if str(name).casefold() == wanted}

    def _detect(self, frame):
        """Run YOLO on *frame* and return detections in DeepSORT's input format.

        Each entry is ``[[left, top, width, height], confidence, "fish"]``.
        """
        result = self.model(frame, verbose=False)[0]
        detections = []
        for box in result.boxes:
            conf = float(box.conf.item())
            if int(box.cls.item()) in self.class_ids and conf > self.conf_thresh:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                detections.append([[x1, y1, x2 - x1, y2 - y1], conf, "fish"])
        return detections

    def _prune_trajectories(self, live_ids):
        """Drop stored history for track ids not in *live_ids*, bounding memory use."""
        for stale_id in [tid for tid in self.trajectories if tid not in live_ids]:
            del self.trajectories[stale_id]

    @staticmethod
    def _path_speed(traj):
        """Return the mean speed along *traj* in pixels/second.

        *traj* is a sequence of ``(x, y, t_seconds)`` points. The result is the
        summed step length divided by the elapsed time between the first and
        last points. It is 0 when there are fewer than two points or no time
        has elapsed.
        """
        if len(traj) < 2:
            return 0
        points = list(traj)
        distance = sum(
            float(np.hypot(xb - xa, yb - ya))
            for (xa, ya, _), (xb, yb, _) in zip(points, points[1:])
        )
        # Sum of per-step time deltas telescopes to last - first.
        elapsed = points[-1][2] - points[0][2]
        return distance / elapsed if elapsed > 0 else 0

    def process_frame(self, frame, timestamp=None):
        """Detect, track and annotate one frame.

        Args:
            frame: image array (height and width taken from ``frame.shape[:2]``)
                as accepted by the YOLO model and OpenCV drawing calls.
            timestamp: optional capture time of *frame* in seconds (for a
                video file, e.g. ``cap.get(cv2.CAP_PROP_POS_MSEC) / 1000``).
                When omitted, the wall-clock processing time is used. That
                yields true px/s only if frames are processed in real time.

        Returns:
            ``(output_frame, alerts)``: a copy of *frame* with boxes and track
            ids drawn, and a new list of alert strings for this frame. The same
            list is also stored in ``self.alerts``.
        """
        h = frame.shape[0]
        if timestamp is None:
            timestamp = cv2.getTickCount() / cv2.getTickFrequency()

        # 1. YOLO detection, restricted to the fish class id(s).
        detections = self._detect(frame)

        # 2. DeepSORT tracking. The returned list is assumed to be the tracker's
        #    live (non-deleted) tracks, as deep_sort_realtime does. History for
        #    any other id is discarded so it cannot accumulate forever.
        tracks = self.tracker.update_tracks(detections, frame=frame)
        self._prune_trajectories({track.track_id for track in tracks})

        # 3. Update trajectories, apply the anomaly rule, and draw.
        output_frame = frame.copy()
        alerts = []
        for track in tracks:
            if not track.is_confirmed():
                continue
            track_id = track.track_id
            x1, y1, x2, y2 = map(int, track.to_ltrb())
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

            traj = self.trajectories[track_id]
            traj.append((cx, cy, timestamp))
            speed = self._path_speed(traj)

            # "At the bottom" = centre lies in the lowest (1 - bottom_ratio) of the frame.
            is_bottom = cy > h * self.bottom_ratio
            if speed < self.speed_thresh and is_bottom:
                alerts.append(f"⚠️ 鱼 {track_id}: 低速({speed:.1f}px/s) + 沉底")

            cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Clamp so the label stays inside the image when the box touches the top.
            cv2.putText(output_frame, f"Fish-{track_id}", (x1, max(y1 - 10, 10)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        # Fresh list per call: a list returned for an earlier frame is never mutated.
        self.alerts = alerts
        return output_frame, alerts