观赏鱼健康系统v1.0
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
# fish_tracker.py
|
||||
import cv2
|
||||
import numpy as np
|
||||
from collections import defaultdict, deque
|
||||
from ultralytics import YOLO
|
||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||
|
||||
class FishTracker:
|
||||
def __init__(self, model_path="yolov8n.pt", conf_thresh=0.5):
|
||||
self.model = YOLO(model_path)
|
||||
self.tracker = DeepSort(max_age=30, n_init=3, nn_budget=100)
|
||||
self.conf_thresh = conf_thresh
|
||||
self.trajectories = defaultdict(lambda: deque(maxlen=30)) # 存最近30帧轨迹
|
||||
self.alerts = []
|
||||
|
||||
def process_frame(self, frame):
|
||||
self.alerts.clear()
|
||||
h, w = frame.shape[:2]
|
||||
|
||||
# 1. YOLO 检测(只检测 "fish" 类别,COCO 中类别 ID=17)
|
||||
results = self.model(frame, verbose=False)[0]
|
||||
detections = []
|
||||
for box in results.boxes:
|
||||
if int(box.cls.item()) == 17 and box.conf.item() > self.conf_thresh: # COCO: fish=17
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||
conf = float(box.conf.item())
|
||||
detections.append([[x1, y1, x2 - x1, y2 - y1], conf, "fish"])
|
||||
|
||||
# 2. DeepSORT 追踪
|
||||
tracks = self.tracker.update_tracks(detections, frame=frame)
|
||||
|
||||
# 3. 更新轨迹 & 绘制
|
||||
output_frame = frame.copy()
|
||||
for track in tracks:
|
||||
if not track.is_confirmed():
|
||||
continue
|
||||
track_id = track.track_id
|
||||
ltrb = track.to_ltrb()
|
||||
x1, y1, x2, y2 = map(int, ltrb)
|
||||
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
|
||||
|
||||
# 记录轨迹
|
||||
self.trajectories[track_id].append((cx, cy, cv2.getTickCount()))
|
||||
|
||||
# 计算速度(像素/秒)
|
||||
traj = self.trajectories[track_id]
|
||||
if len(traj) >= 2:
|
||||
total_dist = 0
|
||||
total_time = 0
|
||||
for i in range(1, len(traj)):
|
||||
dx = traj[i][0] - traj[i-1][0]
|
||||
dy = traj[i][1] - traj[i-1][1]
|
||||
dist = np.sqrt(dx*dx + dy*dy)
|
||||
dt = (traj[i][2] - traj[i-1][2]) / cv2.getTickFrequency()
|
||||
total_dist += dist
|
||||
total_time += dt
|
||||
speed = total_dist / total_time if total_time > 0 else 0
|
||||
else:
|
||||
speed = 0
|
||||
|
||||
# 判断是否沉底(底部 20% 区域)
|
||||
is_bottom = cy > h * 0.8
|
||||
|
||||
# 异常规则
|
||||
if speed < 15 and is_bottom:
|
||||
self.alerts.append(f"⚠️ 鱼 {track_id}: 低速({speed:.1f}px/s) + 沉底")
|
||||
|
||||
# 绘制框和ID
|
||||
cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(output_frame, f"Fish-{track_id}", (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
||||
|
||||
return output_frame, self.alerts
|
||||
Reference in New Issue
Block a user