# fish_tracker.py
import math
import warnings
from collections import defaultdict, deque

import cv2
import numpy as np
from deep_sort_realtime.deepsort_tracker import DeepSort
from ultralytics import YOLO


class FishTracker:
    """Detect fish with YOLO, track them with DeepSORT and flag suspicious behaviour.

    A confirmed track raises an alert when its average speed over the recent
    trajectory window is below ``speed_thresh`` px/s while its box centre lies
    below ``h * bottom_ratio`` (i.e. in the bottom band of the frame).
    """

    # Class index used by the original implementation. NOTE: in the 80-class
    # COCO label set of the stock ``yolov8n.pt`` weights, index 17 is "horse";
    # COCO has no "fish" class. Kept only as a backward-compatible fallback.
    LEGACY_CLASS_ID = 17

    def __init__(self, model_path="yolov8n.pt", conf_thresh=0.5, *,
                 fish_class_id=None, speed_thresh=15.0, bottom_ratio=0.8,
                 history_len=30):
        """Load the detector and create the tracker.

        Args:
            model_path: Weights path/name passed to ``YOLO``.
            conf_thresh: Detections must score strictly above this value.
            fish_class_id: Class index to track. When ``None`` it is looked up
                by the label "fish" in the model's class names; if no such
                label exists, ``LEGACY_CLASS_ID`` is used and a warning issued.
            speed_thresh: Average speed (px/s) below which a fish at the bottom
                is flagged.
            bottom_ratio: A centre with ``cy > h * bottom_ratio`` counts as
                lying on the bottom.
            history_len: Number of recent centre points kept per track.
        """
        self.model = YOLO(model_path)
        self.tracker = DeepSort(max_age=30, n_init=3, nn_budget=100)
        self.conf_thresh = conf_thresh
        self.speed_thresh = speed_thresh
        self.bottom_ratio = bottom_ratio
        # Most recent (cx, cy, t_seconds) points per track id.
        self.trajectories = defaultdict(lambda: deque(maxlen=history_len))
        self.alerts = []

        if fish_class_id is None:
            fish_class_id = self._resolve_class_id(getattr(self.model, "names", None))
            if fish_class_id is None:
                warnings.warn(
                    f"Model {model_path!r} has no class named 'fish'; falling back "
                    f"to class id {self.LEGACY_CLASS_ID} ('horse' in COCO). Pass "
                    "fish_class_id= or use fish-trained weights.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                fish_class_id = self.LEGACY_CLASS_ID
        self.fish_class_id = int(fish_class_id)

    @staticmethod
    def _resolve_class_id(names, target="fish"):
        """Return the index whose label equals ``target`` (case-insensitive), else ``None``.

        ``names`` may be an ``{index: label}`` mapping or a sequence of labels.
        """
        if not names:
            return None
        wanted = target.lower()
        pairs = names.items() if isinstance(names, dict) else enumerate(names)
        for idx, label in pairs:
            if str(label).strip().lower() == wanted:
                return int(idx)
        return None

    @staticmethod
    def _path_speed(traj):
        """Return the mean speed (px/s) along ``traj``, a sequence of (x, y, t_seconds).

        Speed is total path length divided by elapsed time between the first
        and last points. Returns 0.0 for fewer than two points or zero elapsed
        time. NOTE(review): path length also accumulates frame-to-frame box
        jitter, so a stationary fish may still report a non-zero speed.
        """
        points = list(traj)
        if len(points) < 2:
            return 0.0
        path_len = sum(
            math.hypot(x1 - x0, y1 - y0)
            for (x0, y0, _t0), (x1, y1, _t1) in zip(points, points[1:])
        )
        elapsed = points[-1][2] - points[0][2]
        return path_len / elapsed if elapsed > 0 else 0.0

    def _detect(self, frame):
        """Return fish detections for ``frame`` as DeepSORT ``([l, t, w, h], conf, cls)`` entries."""
        # Forward ``conf`` so thresholds below Ultralytics' 0.25 default take effect.
        results = self.model(frame, verbose=False, conf=self.conf_thresh)[0]
        detections = []
        for box in results.boxes:
            conf = float(box.conf.item())
            if int(box.cls.item()) != self.fish_class_id or conf <= self.conf_thresh:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            detections.append([[x1, y1, x2 - x1, y2 - y1], conf, "fish"])
        return detections

    def _prune_trajectories(self, tracks):
        """Drop trajectories of ids DeepSORT no longer tracks, so memory stays bounded."""
        alive = {track.track_id for track in tracks}
        for track_id in [tid for tid in self.trajectories if tid not in alive]:
            del self.trajectories[track_id]

    def process_frame(self, frame, timestamp=None):
        """Detect, track and evaluate the alert rule on one frame.

        Args:
            frame: Image array as read by ``cv2`` (``frame.shape[0]`` is the height).
            timestamp: Capture time of ``frame`` in seconds. For recorded
                footage pass the video position (e.g.
                ``cap.get(cv2.CAP_PROP_POS_MSEC) / 1000``) so speeds reflect
                video time; when ``None`` the wall-clock time is used.

        Returns:
            ``(output_frame, alerts)``: a copy of ``frame`` with boxes and ids
            drawn, and a fresh list of alert strings for this frame (also
            stored as ``self.alerts``).

        Raises:
            ValueError: If ``frame`` is ``None`` or empty.
        """
        if frame is None or frame.size == 0:
            raise ValueError("frame is empty; check the result of the video read")
        h = frame.shape[0]
        if timestamp is None:
            timestamp = cv2.getTickCount() / cv2.getTickFrequency()

        # 1. YOLO detection restricted to the fish class.
        detections = self._detect(frame)

        # 2. DeepSORT tracking.
        tracks = self.tracker.update_tracks(detections, frame=frame)

        # 3. Update trajectories, evaluate the rule and draw.
        output_frame = frame.copy()
        alerts = []
        for track in tracks:
            if not track.is_confirmed():
                continue
            track_id = track.track_id
            x1, y1, x2, y2 = map(int, track.to_ltrb())
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

            traj = self.trajectories[track_id]
            traj.append((cx, cy, timestamp))
            speed = self._path_speed(traj)

            # Anomaly rule: slow AND sitting in the bottom band of the frame.
            is_bottom = cy > h * self.bottom_ratio
            if speed < self.speed_thresh and is_bottom:
                alerts.append(f"⚠️ 鱼 {track_id}: 低速({speed:.1f}px/s) + 沉底")

            cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Clamp so the label stays visible for boxes touching the top edge.
            cv2.putText(output_frame, f"Fish-{track_id}", (x1, max(y1 - 10, 15)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        self._prune_trajectories(tracks)
        # Rebind to a new list so lists returned by earlier calls are never mutated.
        self.alerts = alerts
        return output_frame, alerts