1509 lines
65 KiB
Python
1509 lines
65 KiB
Python
import flask
|
|
from flask import Response, request, jsonify, render_template, send_from_directory, send_file
|
|
import os
|
|
import time
|
|
import json
|
|
import threading
|
|
import uuid
|
|
import logging
|
|
import cv2
|
|
import numpy as np
|
|
import base64
|
|
from collections import Counter, defaultdict
|
|
from skimage import io as skio
|
|
import itertools
|
|
import random
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import atexit
|
|
import settings_manager
|
|
import config
|
|
import database
|
|
import file_storage
|
|
import background_tasks
|
|
import ai_models
|
|
from bbox_writer import validate_bboxes_text, convert_text_to_rects_and_labels, extract_labels
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from colorama import Fore, Style, init
|
|
|
|
|
|
class ColoredFormatter(logging.Formatter):
    """Formatter that colorizes the timestamp, level, thread name and message.

    Every record is rendered as
    ``<time> - <level> - [<thread>] - <message>`` with colorama escape codes
    around each field. The layout is assembled directly in :meth:`format`,
    so the ``fmt`` argument accepted by the constructor does not influence
    the produced text.
    """

    # String keys: colors for fixed fields and the fallbacks.
    # Integer keys (logging levels): colors for the level name and message.
    COLORS = {
        'TIMESTAMP': Fore.WHITE + Style.DIM,
        'THREAD': Fore.CYAN,
        'LEVEL_DEFAULT': Fore.WHITE,
        'MESSAGE_DEFAULT': Fore.WHITE + Style.NORMAL,

        logging.DEBUG: {'level': Fore.MAGENTA, 'message': Fore.MAGENTA},
        logging.INFO: {'level': Fore.GREEN + Style.BRIGHT, 'message': Fore.WHITE},
        logging.WARNING: {'level': Fore.YELLOW + Style.BRIGHT, 'message': Fore.YELLOW},
        logging.ERROR: {'level': Fore.RED + Style.BRIGHT, 'message': Fore.RED},
        logging.CRITICAL: {'level': Fore.RED + Style.BRIGHT, 'message': Fore.RED + Style.BRIGHT},
    }

    def __init__(self, fmt=None, datefmt=None, style='%'):
        """Forward all arguments unchanged to ``logging.Formatter``."""
        super().__init__(fmt, datefmt, style)

    def format(self, record):
        """Return the colorized one-string rendering of ``record``.

        When the record carries exception info, the formatted traceback is
        appended on a new line using the same color as the message.
        """
        palette = self.COLORS
        fallback_colors = {
            'level': palette['LEVEL_DEFAULT'],
            'message': palette['MESSAGE_DEFAULT'],
        }
        level_colors = palette.get(record.levelno, fallback_colors)
        reset = Style.RESET_ALL

        timestamp = self.formatTime(record, self.datefmt)
        time_part = "{}{}{}".format(palette['TIMESTAMP'], timestamp, reset)
        # Level names are left-aligned in an 8 character column.
        level_part = "{}{:<8}{}".format(level_colors['level'], record.levelname, reset)
        thread_part = "{}[{}]{}".format(palette['THREAD'], record.threadName, reset)

        text = record.getMessage()
        message_part = "{}{}{}".format(level_colors['message'], text, reset)

        if record.exc_info:
            traceback_text = self.formatException(record.exc_info)
            message_part += "\n{}{}{}".format(level_colors['message'], traceback_text, reset)

        return " - ".join((time_part, level_part, thread_part, message_part))
|
|
|
|
# Configure the root logger before anything gets a chance to log. The
# module-level logging helpers (logging.error, ...) implicitly call
# basicConfig() with defaults when the root logger has no handler, which
# would turn a later explicit basicConfig(...) into a no-op and silently
# drop the INFO level and the format below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(threadName)s - %(message)s')

# PyYAML is optional at import time; `yaml` is set to None when it is missing
# so other code can test for its availability.
try:
    import yaml
except ImportError:
    logging.error("PyYAML is not installed! Dataset export will fail. Please run 'pip install pyyaml'.")
    yaml = None

app = flask.Flask(__name__)
# A fresh random key is generated on every process start.
app.secret_key = os.urandom(24)
# Shared pool used to run prototype rebuilds in the background.
prototype_executor = ThreadPoolExecutor(max_workers=8)

# One-time database and storage initialization inside an application context.
with app.app_context():
    database.init_db()
    database.migrate_db()
    file_storage.init_storage()
|
|
|
|
|
|
def validate_description(desc, existing_descriptions):
    """Check that a description has a valid length and is not already used.

    Args:
        desc: The candidate description; ``None`` is treated as invalid.
        existing_descriptions: Container of descriptions already in use.

    Returns:
        A ``(is_valid, message)`` tuple. ``message`` is an empty string when
        the description is valid, otherwise it explains the rejection.
    """
    max_length = config.MAX_DESCRIPTION_LENGTH
    # A missing description (None) must be rejected, not crash on len().
    if desc is None or not (1 <= len(desc) <= max_length):
        # Report the configured limit rather than a hard-coded number.
        return False, f"Description must be between 1 and {max_length} characters."
    if desc in existing_descriptions:
        return False, "Description is a duplicate."
    return True, ""
|
|
|
|
|
|
def sanitize_dict(d):
    """Return ``d`` as-is.

    Currently an identity pass-through: the very same object is handed back
    without being copied or modified.
    """
    return d
|
|
|
|
|
|
def string_to_color_bgr(s):
    """Derive a deterministic BGR color tuple of three ints from a string.

    A rolling hash over the characters selects the hue (``hash % 180``); it
    is combined with a fixed saturation and value of 200 and converted from
    HSV to BGR with OpenCV.
    """
    rolling = 0
    for code_point in map(ord, s):
        # Same recurrence as ord(c) + ((h << 5) - h), i.e. h * 31 + ord(c).
        rolling = rolling * 31 + code_point

    hsv_pixel = np.uint8([[[rolling % 180, 200, 200]]])
    bgr_pixel = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2BGR)
    return tuple(int(channel) for channel in bgr_pixel[0][0])
|
|
|
|
|
|
def calculate_iou(boxA, boxB):
    """Compute the intersection-over-union of two boxes.

    Both boxes are indexed as ``[x1, y1, x2, y2]``. Returns ``0.0`` when the
    boxes do not overlap (including when they merely touch), otherwise the
    overlap area divided by the area of the union.
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    overlap = max(0, right - left) * max(0, bottom - top)
    if overlap == 0:
        return 0.0

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    union = _area(boxA) + _area(boxB) - overlap
    return overlap / float(union)
|
|
|
|
|
|
def generate_mosaic_previews(sample_pool, selected_video_uuid, selected_frame_number):
    """Build six randomized mosaic previews with their bounding boxes drawn.

    Args:
        sample_pool: Sequence of dicts with ``video_uuid`` and
            ``frame_number`` keys. Pools shorter than four entries are padded
            by repetition; the caller's list is left untouched.
        selected_video_uuid: UUID of the frame that should be part of every
            mosaic, provided it is in the pool and has labels.
        selected_frame_number: Frame number of that frame.

    Returns:
        A list of six ``data:image/jpeg;base64,...`` strings on success, or a
        ``(flask.Response, 400)`` tuple when fewer than four labeled frames
        could be collected from the pool.
    """
    # Pad a private copy so the caller's list is never mutated.
    pool = list(sample_pool)
    if len(pool) < 4:
        pool.extend(pool * (4 - len(pool)))

    all_labels = database.get_all_class_labels()
    class_map = {name: i for i, name in enumerate(all_labels)}

    image_infos = []
    conn = database.get_db_connection()
    try:
        for sample in pool:
            video_info = database.get_video_entity(sample['video_uuid'])
            frame_info = conn.execute(
                'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                (sample['video_uuid'], sample['frame_number'])
            ).fetchone()

            # Only frames that exist and carry label text are usable.
            if video_info and frame_info and frame_info['bboxes_text']:
                image_infos.append({
                    "video_uuid": sample['video_uuid'],
                    "frame_number": sample['frame_number'],
                    "bboxes_text": frame_info['bboxes_text'],
                    "width": video_info['width'],
                    "height": video_info['height']
                })
    finally:
        # Release the connection even when a lookup raises.
        conn.close()

    if len(image_infos) < 4:
        return jsonify({'success': False,
                        'message': 'Not enough labeled images in the sample pool to generate a mosaic preview.'}), 400

    selected_image_info = next(
        (info for info in image_infos
         if info['video_uuid'] == selected_video_uuid and info['frame_number'] == selected_frame_number),
        None)

    # Exclude only the selected entry itself (identity). Comparing by equality
    # would also discard padded duplicates of it and could leave fewer than
    # three companion images for the mosaic.
    companion_infos = [info for info in image_infos if info is not selected_image_info]

    previews = []
    for _ in range(6):
        other_images = list(companion_infos)
        random.shuffle(other_images)

        if selected_image_info:
            mosaic_set = [selected_image_info] + other_images[:3]
        else:
            mosaic_set = other_images[:4]
        random.shuffle(mosaic_set)

        mosaic_img, final_bboxes = file_storage.create_mosaic_image(mosaic_set, class_map)

        h, w, _ = mosaic_img.shape
        vis_image = mosaic_img.copy()
        for class_index, x_center, y_center, width_norm, height_norm in final_bboxes:
            color = string_to_color_bgr(all_labels[class_index])

            # Box values are scaled by the mosaic size, i.e. they are treated
            # as center/size fractions of the image.
            x1 = int((x_center - width_norm / 2) * w)
            y1 = int((y_center - height_norm / 2) * h)
            x2 = int((x_center + width_norm / 2) * w)
            y2 = int((y_center + height_norm / 2) * h)
            cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, 2)

        _, buffer = cv2.imencode('.jpg', vis_image)
        img_base64 = base64.b64encode(buffer).decode('utf-8')
        previews.append(f"data:image/jpeg;base64,{img_base64}")

    return previews
|
|
|
|
|
|
@app.route('/')
def index():
    """Render the landing page template ``root.html``."""
    template_context = {
        'limit_data': config.get_limit_data_for_render_template(),
        'tracker_fns': config.TRACKER_FNS,
    }
    return render_template('root.html', **template_context)
|
|
|
|
|
|
@app.route('/labelVideo')
def label_video():
    """Render the labeling page for the task given by ``?task_uuid=``.

    A task still in ``PENDING`` state is switched to ``IN_PROGRESS`` and
    reloaded before rendering. Responds with 400 when the parameter is
    missing and 404 when the task or its video cannot be found.
    """
    task_uuid = request.args.get('task_uuid')
    if not task_uuid:
        return "Task UUID is required.", 400

    task_entity = database.get_task_entity(task_uuid)
    if not task_entity:
        return "Annotation task not found.", 404

    # Opening a pending task marks it as started; re-read the updated row.
    if task_entity['status'] == 'PENDING':
        database.update_task_status(task_uuid, 'IN_PROGRESS')
        task_entity = database.get_task_entity(task_uuid)

    video_entity = database.get_video_entity(task_entity['video_uuid'])
    if not video_entity:
        return "Associated video not found.", 404

    first_frame_url = "/media/frames/{}/frame_{:05d}.jpg".format(
        video_entity['video_uuid'], task_entity['start_frame'])
    settings = settings_manager.load_settings()

    template_context = {
        'task_entity': sanitize_dict(task_entity),
        'video_entity': sanitize_dict(video_entity),
        'first_frame_url': first_frame_url,
        'settings': settings,
        'limit_data': config.get_limit_data_for_render_template(),
    }
    return render_template('labelVideo.html', **template_context)
|
|
|
|
|
|
@app.route('/media/<path:path>')
def send_media(path):
    """Serve the file at ``path`` relative to ``config.STORAGE_DIR``."""
    storage_root = config.STORAGE_DIR
    return send_from_directory(storage_root, path)
|
|
|
|
|
|
@app.route('/media/annotated_frame/<video_uuid>/<int:frame_number>.jpg')
def serve_annotated_frame(video_uuid, frame_number):
    """Return a frame as JPEG with its stored bounding boxes and labels drawn.

    Responds with 404 when the frame file does not exist and with 500 when
    the image cannot be read, encoded, or any other error occurs.
    """
    try:
        frame_path = file_storage.get_frame_path(video_uuid, frame_number)
        if not os.path.exists(frame_path):
            return "Frame not found", 404
        image = cv2.imread(frame_path)
        if image is None:
            return "Could not read frame image", 500

        conn = database.get_db_connection()
        try:
            frame_data = conn.execute(
                'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                (video_uuid, frame_number)
            ).fetchone()
        finally:
            # Always release the connection, even if the query raises.
            conn.close()

        bboxes_text = frame_data['bboxes_text'] if frame_data else None

        if bboxes_text and bboxes_text.strip():
            rects, labels, _ = convert_text_to_rects_and_labels(bboxes_text)
            for rect, label in zip(rects, labels):
                color = string_to_color_bgr(label)
                x1, y1, x2, y2 = rect
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                # Filled strip above the box as background for the label text.
                (text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                cv2.rectangle(image, (x1, y1 - text_height - 5), (x1 + text_width, y1), color, -1)
                cv2.putText(image, label, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        success, buffer = cv2.imencode('.jpg', image)
        if not success:
            return "Failed to encode image", 500

        return Response(buffer.tobytes(), mimetype='image/jpeg')

    except Exception as e:
        # Include the traceback so failures can be diagnosed from the log.
        logging.error(f"Error generating annotated frame for {video_uuid}/{frame_number}: {e}", exc_info=True)
        return "Internal server error", 500
|
|
|
|
|
|
@app.route('/listVideos', methods=['GET'])
def list_videos():
    """Return all videos plus the videos reported as ready for datasets."""
    all_videos = database.get_all_video_list()
    ready_videos = database.get_ready_videos_with_labels()
    payload = {
        'all_videos': list(map(sanitize_dict, all_videos)),
        'ready_videos_for_dataset': list(map(sanitize_dict, ready_videos)),
    }
    return jsonify(payload)
|
|
|
|
|
|
@app.route('/uploadVideo', methods=['POST'])
def upload_video():
    """Register an uploaded video and start frame extraction in a thread.

    Expects the form field ``description`` and the file ``video_file``.
    Responds with 400 when the description is invalid or the file is missing,
    otherwise with the UUID of the newly created video entry.
    """
    # A missing form field is treated like an empty description so it is
    # rejected by validation with a 400 instead of raising a TypeError.
    desc = request.form.get('description') or ''
    video_file = request.files.get('video_file')

    existing_descriptions = [v['description'] for v in database.get_all_video_list()]
    is_valid, message = validate_description(desc, existing_descriptions)
    if not is_valid:
        return jsonify({'success': False, 'message': message}), 400
    if not video_file:
        return jsonify({'success': False, 'message': 'No video file provided.'}), 400

    create_time_ms = int(time.time() * 1000)
    video_uuid = database.create_video_entry(desc, video_file.filename, 0, create_time_ms)
    file_storage.save_uploaded_video(video_file, video_uuid)

    threading.Thread(target=background_tasks.extract_frames_task, args=(video_uuid,),
                     name=f"Extractor-{video_uuid[:6]}").start()

    return jsonify({'success': True, 'video_uuid': video_uuid})
|
|
|
|
|
|
@app.route('/importFrames', methods=['POST'])
def import_frames():
    """Attach uploaded frame images to an existing video.

    Expects the form field ``video_uuid`` and one or more ``frame_files``.
    Responds with 400 for missing input, 404 for an unknown video and 500
    when the import itself fails.
    """
    video_uuid = request.form.get('video_uuid')
    frame_files = request.files.getlist('frame_files')

    if not (video_uuid and frame_files):
        return jsonify({'success': False, 'message': 'Missing video UUID or frame files.'}), 400

    if not database.get_video_entity(video_uuid):
        return jsonify({'success': False, 'message': 'Video not found.'}), 404

    try:
        imported_count = database.add_frames_from_upload(video_uuid, frame_files)
        return jsonify({'success': True, 'imported_count': imported_count})
    except Exception as e:
        logging.error(f"Failed to import frames for {video_uuid}: {e}")
        return jsonify({'success': False, 'message': str(e)}), 500
|
|
|
|
|
|
@app.route('/retrieveVideoEntity', methods=['POST'])
def retrieve_video_entity():
    """Return the video entity for the ``video_uuid`` given in the JSON body."""
    entity = database.get_video_entity(request.json.get('video_uuid'))
    if not entity:
        return jsonify({'success': False, 'message': 'Video not found.'})
    return jsonify({'success': True, 'video_entity': sanitize_dict(entity)})
|
|
|
|
|
|
@app.route('/deleteVideo', methods=['POST'])
def delete_video():
    """Delete a video's database entry, its video file and its frames.

    Expects ``video_uuid`` in the JSON body; responds with 400 when it is
    missing so the destructive helpers are never called without a target.
    """
    video_uuid = request.json.get('video_uuid')
    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400

    database.delete_video(video_uuid)
    file_storage.delete_video_file(video_uuid)
    file_storage.delete_frames_for_video(video_uuid)
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/retrieveVideoFrames', methods=['POST'])
def retrieve_video_frames():
    """Return all frames of a video, each extended with its ``image_url``."""
    video_uuid = request.json.get('video_uuid')
    frames = database.get_video_frames(video_uuid)

    url_template = "/media/frames/{}/frame_{:05d}.jpg"
    for frame in frames:
        frame['image_url'] = url_template.format(video_uuid, frame['frame_number'])

    return jsonify({'success': True, 'frames': list(map(sanitize_dict, frames))})
|
|
|
|
|
|
@app.route('/storeVideoFrameBboxesText', methods=['POST'])
def store_video_frame_bboxes_text():
    """Validate and store the bounding-box text of a single frame.

    Expects ``video_uuid``, ``frame_number`` and ``bboxes_text`` in the JSON
    body. Responds with 400 when the UUID is missing or the frame number is
    not an integer.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400

    # Reject a missing or non-numeric frame number instead of crashing.
    try:
        frame_number = int(data.get('frame_number'))
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'Frame number must be an integer.'}), 400

    bboxes_text = validate_bboxes_text(data.get('bboxes_text'))

    database.save_frame_bboxes(video_uuid, frame_number, bboxes_text)

    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/listTasks', methods=['GET'])
def list_tasks():
    """Return the annotation tasks of the video given by ``?video_uuid=``."""
    video_uuid = request.args.get('video_uuid')
    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400

    tasks = database.get_tasks_for_video(video_uuid)
    return jsonify({'success': True, 'tasks': list(map(sanitize_dict, tasks))})
|
|
|
|
|
|
@app.route('/createTask', methods=['POST'])
def create_task():
    """Create an annotation task covering a frame range of a video.

    Expects ``video_uuid``, ``assigned_to``, ``start_frame`` and ``end_frame``
    in the JSON body (``description`` is optional). Responds with 400 for
    missing or invalid input and 404 for an unknown video.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    assigned_to = data.get('assigned_to')
    description = data.get('description', '')
    start_frame = data.get('start_frame')
    end_frame = data.get('end_frame')

    # Frame numbers may legitimately be 0, so they are tested against None.
    has_required_fields = bool(
        video_uuid and assigned_to and start_frame is not None and end_frame is not None
    )
    if not has_required_fields:
        return jsonify({'success': False, 'message': 'Missing required fields.'}), 400

    try:
        start_frame = int(start_frame)
        end_frame = int(end_frame)
    except (ValueError, TypeError):
        return jsonify({'success': False, 'message': 'Frame numbers must be integers.'}), 400

    video = database.get_video_entity(video_uuid)
    if not video:
        return jsonify({'success': False, 'message': 'Video not found.'}), 404

    frame_count = video['frame_count']
    range_is_valid = 0 <= start_frame < end_frame < frame_count
    if not range_is_valid:
        return jsonify({'success': False,
                        'message': f'Invalid frame range. Must be within 0 and {frame_count - 1}.'}), 400

    try:
        task_uuid = database.create_annotation_task(video_uuid, assigned_to, description, start_frame, end_frame)
        return jsonify({'success': True, 'task_uuid': task_uuid})
    except ValueError as e:
        return jsonify({'success': False, 'message': str(e)}), 400
|
|
|
|
|
|
@app.route('/deleteTask', methods=['POST'])
def delete_task():
    """Delete the annotation task named by ``task_uuid`` in the JSON body."""
    database.delete_task(request.json.get('task_uuid'))
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/updateTaskStatus', methods=['POST'])
def update_task_status():
    """Set a task's status to ``PENDING``, ``IN_PROGRESS`` or ``COMPLETED``.

    Responds with 400 when ``task_uuid`` is missing or ``status`` is not one
    of the accepted values.
    """
    payload = request.json
    task_uuid = payload.get('task_uuid')
    status = payload.get('status')

    allowed_statuses = ('PENDING', 'IN_PROGRESS', 'COMPLETED')
    if not (task_uuid and status in allowed_statuses):
        return jsonify({'success': False, 'message': 'Invalid task UUID or status.'}), 400

    database.update_task_status(task_uuid, status)
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/listClasses', methods=['GET'])
def list_classes():
    """Return every class label known to the database."""
    return jsonify({'success': True, 'labels': database.get_all_class_labels()})
|
|
|
|
|
|
@app.route('/api/rebuild_prototypes', methods=['POST'])
def rebuild_prototypes():
    """Queue prototype rebuilds on the shared executor.

    With ``class_name`` in the JSON body only that class is rebuilt; without
    it one rebuild job is submitted per known class label.
    """
    class_name = request.json.get('class_name')

    if not class_name:
        logging.info("Manual prototype rebuild requested for ALL classes.")
        all_labels = database.get_all_class_labels()
        for label in all_labels:
            prototype_executor.submit(ai_models.update_prototype_for_class, label)
        message = f"Started rebuilding prototypes for all {len(all_labels)} classes in the background."
    else:
        logging.info(f"Manual prototype rebuild requested for class: {class_name}")
        prototype_executor.submit(ai_models.update_prototype_for_class, class_name)
        message = f"Started rebuilding prototype for '{class_name}'."

    return jsonify({'success': True, 'message': message})
|
|
|
|
@app.route('/api/interpolateBboxes', methods=['POST'])
def interpolate_bboxes():
    """Linearly interpolate one object's box across the frames between two keyframes.

    Expects ``video_uuid``, ``object_id``, ``start_frame`` and ``end_frame``
    in the JSON body; each frame entry holds ``frame_number`` and a ``bbox``
    dict with ``x1``, ``y1``, ``x2``, ``y2`` (the start box also provides
    ``label``). For every frame strictly between the keyframes, any stored
    line ending in ``,<object_id>`` is replaced by the interpolated box.
    Responds with 400 for missing data and 500 on any processing error.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    object_id = data.get('object_id')
    start_frame_data = data.get('start_frame')
    end_frame_data = data.get('end_frame')

    if not all([video_uuid, object_id, start_frame_data, end_frame_data]):
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400

    try:
        start_frame_num = int(start_frame_data['frame_number'])
        end_frame_num = int(end_frame_data['frame_number'])
        start_bbox = start_frame_data['bbox']
        end_bbox = end_frame_data['bbox']
        # The label is read from the box submitted as "start", before any swap.
        label = start_bbox['label']

        # Normalize so interpolation always runs from the lower frame number.
        if start_frame_num >= end_frame_num:
            start_frame_num, end_frame_num = end_frame_num, start_frame_num
            start_bbox, end_bbox = end_bbox, start_bbox

        total_steps = end_frame_num - start_frame_num
        if total_steps <= 1:
            return jsonify({'success': True, 'message': 'No frames to interpolate.'})

        # Loop invariants: suffix identifying this object's lines, coordinate keys.
        object_suffix = f',{object_id}'
        coord_keys = ('x1', 'y1', 'x2', 'y2')

        for i in range(1, total_steps):
            current_frame_num = start_frame_num + i
            t = i / float(total_steps)
            interp_x1, interp_y1, interp_x2, interp_y2 = (
                int(start_bbox[key] + (end_bbox[key] - start_bbox[key]) * t)
                for key in coord_keys
            )
            new_bbox_line = f"{interp_x1},{interp_y1},{interp_x2},{interp_y2},{label},{object_id}"

            conn = database.get_db_connection()
            try:
                frame_db = conn.execute(
                    'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                    (video_uuid, current_frame_num)).fetchone()
            finally:
                # Close before saving (as the original ordering did) and also
                # when the query raises, so the connection is never leaked.
                conn.close()

            existing_bboxes = frame_db['bboxes_text'] if frame_db else ''
            lines = existing_bboxes.split('\n') if existing_bboxes else []
            # Drop the previous line for this object, keep everything else.
            updated_lines = [line for line in lines if not line.endswith(object_suffix)]
            updated_lines.append(new_bbox_line)
            final_bboxes_text = '\n'.join(filter(None, updated_lines))

            database.save_frame_bboxes(video_uuid, current_frame_num, final_bboxes_text)

        return jsonify({'success': True, 'message': f'Interpolated {total_steps - 1} frames successfully.'})

    except Exception as e:
        logging.error(f"Interpolation failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500
|
|
|
|
|
|
@app.route('/addClass', methods=['POST'])
def add_class():
    """Add a class label; the name is stripped and must not be empty."""
    label_name = request.json.get('label_name', '').strip()
    if label_name:
        database.add_class_label(label_name)
        return jsonify({'success': True})
    return jsonify({'success': False, 'message': 'Label name cannot be empty.'}), 400
|
|
|
|
|
|
@app.route('/deleteClass', methods=['POST'])
def delete_class():
    """Delete the class label given as ``label_name`` in the JSON body."""
    label_name = request.json.get('label_name')
    if label_name:
        database.delete_class_label(label_name)
        return jsonify({'success': True})
    return jsonify({'success': False, 'message': 'Label name is required.'}), 400
|
|
|
|
|
|
@app.route('/api/settings', methods=['GET'])
def get_settings():
    """Return the settings loaded by ``settings_manager``."""
    return jsonify({'success': True, 'settings': settings_manager.load_settings()})
|
|
|
|
|
|
@app.route('/api/settings', methods=['POST'])
def save_settings():
    """Persist new settings and clear the model caches they affect.

    The SAM cache is cleared when the SAM checkpoint or device changed, the
    feature-extractor cache when the extractor model or device changed, and
    the device is updated when it changed. The response carries
    ``restart_required``, which is true if any of those three settings
    differs from the previously stored value. Responds with 400 for an empty
    body and 500 when saving fails.
    """
    new_settings = request.json
    if not new_settings:
        return jsonify({'success': False, 'message': 'No settings data provided.'}), 400

    current_settings = settings_manager.load_settings()

    def _differs(key):
        """Tell whether ``key`` has a different value in the new settings."""
        return current_settings.get(key) != new_settings.get(key)

    sam_model_changed = _differs('sam_model_checkpoint')
    mobilenet_model_changed = _differs('feature_extractor_model_name')
    device_changed = _differs('gpu_device')
    restart_required = any((sam_model_changed, mobilenet_model_changed, device_changed))

    if not settings_manager.save_settings(new_settings):
        return jsonify({'success': False, 'message': 'Failed to save settings to file.'}), 500

    if sam_model_changed or device_changed:
        logging.info("SAM model or device setting changed. Clearing SAM cache.")
        try:
            from ultralytics_sam_tasks import _sam_model_cache
            for cache_field in ("model", "path"):
                _sam_model_cache[cache_field] = None
        except (ImportError, AttributeError):
            logging.warning("Could not clear SAM model cache.")

    if mobilenet_model_changed or device_changed:
        logging.info("Feature extractor model or device setting changed. Clearing MobileNet cache.")
        ai_models.clear_feature_extractor_cache()

    if device_changed:
        settings_manager.update_device()

    response_payload = {
        'success': True,
        'message': 'Settings saved successfully!',
        'restart_required': restart_required,
    }
    return jsonify(response_payload)
|
|
|
|
|
|
@app.route('/api/clear_cache', methods=['POST'])
def clear_cache():
    """Empty ``ai_models.PREPROCESSED_DATA_CACHE`` and report the item count."""
    try:
        cache = ai_models.PREPROCESSED_DATA_CACHE
        # Take the size first; it is reported after the cache is emptied.
        count = len(cache)
        cache.clear()
        logging.info(f"Cleared {count} items from PREPROCESSED_DATA_CACHE.")
        return jsonify({'success': True, 'message': f'Successfully cleared {count} cached items.'})
    except Exception as e:
        logging.error(f"Failed to clear cache: {e}")
        return jsonify({'success': False, 'message': 'An error occurred while clearing the cache.'}), 500
|
|
|
|
|
|
@app.route('/samPredict', methods=['POST'])
def sam_predict():
    """Predict a bounding box from a single click point using SAM.

    Expects ``video_uuid``, ``frame_number`` and ``point`` (with ``x`` and
    ``y``) in the JSON body. Responds with 501 when SAM is unavailable, 400
    for missing data, 404 when the frame image is missing and 500 when the
    prediction raises.
    """
    try:
        from ultralytics_sam_tasks import predict_box_from_point_ultralytics, get_sam_model
        if not get_sam_model():
            return jsonify({'success': False, 'message': 'Ultralytics SAM model is not available.'}), 501
    except ImportError:
        return jsonify({'success': False, 'message': 'SAM features are not installed on server.'}), 501

    payload = request.json
    video_uuid = payload.get('video_uuid')
    frame_number = payload.get('frame_number')
    point_coords = payload.get('point')

    # Frame 0 is valid, hence the explicit None test for the frame number.
    if not (video_uuid and frame_number is not None and point_coords):
        return jsonify({'success': False, 'message': 'Missing required data (video_uuid, frame_number, point).'}), 400

    try:
        image_path = file_storage.get_frame_path(video_uuid, int(frame_number))
        if not os.path.exists(image_path):
            return jsonify({'success': False, 'message': 'Frame image not found on server.'}), 404

        click_x = int(point_coords['x'])
        click_y = int(point_coords['y'])
        bbox = predict_box_from_point_ultralytics(image_path, (click_x, click_y))

        if not bbox:
            return jsonify({'success': False, 'message': 'No object found at the specified point.'})
        return jsonify({'success': True, 'bbox': bbox})

    except Exception as e:
        logging.error(f"SAM prediction failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500
|
|
|
|
|
|
@app.route('/interactive_segment/preprocess', methods=['POST'])
def interactive_segment_preprocess_route():
    """Precompute mask features for one frame and return its cache key.

    Expects ``video_uuid`` and ``frame_number`` in the JSON body. Responds
    with 400 when either is missing or the frame number is not an integer,
    and with 500 when preprocessing raises.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    raw_frame_number = data.get('frame_number')

    # Check for presence before converting; int(None) would raise and make
    # this validation unreachable.
    if video_uuid is None or raw_frame_number is None:
        return jsonify({'success': False, 'message': 'Missing video_uuid or frame_number.'}), 400

    try:
        frame_number = int(raw_frame_number)
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'Frame number must be an integer.'}), 400

    try:
        ai_models.get_features_for_all_masks(video_uuid, frame_number)
        cache_key = f"{video_uuid}_{frame_number}"
        return jsonify({'success': True, 'message': 'Preprocessing successful', 'cache_key': cache_key})
    except Exception as e:
        logging.error(f"智能选择预处理失败: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500
|
|
|
|
|
|
@app.route('/interactive_segment/predict', methods=['POST'])
def interactive_segment_predict_route():
    """Run one-shot prediction on a frame from the first positive prompt box.

    Expects ``video_uuid``, ``frame_number`` and a non-empty ``prompt_boxes``
    list in the JSON body. Responds with 400 for missing or invalid input and
    with 500 when the prediction raises.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    raw_frame_number = data.get('frame_number')
    prompt_boxes = data.get('prompt_boxes', [])

    # Validate presence before converting; int(None) would raise otherwise.
    if not video_uuid or raw_frame_number is None:
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400
    # Checked separately so the specific message is actually reachable.
    if not prompt_boxes:
        return jsonify({'success': False, 'message': 'Positive prompt boxes are required.'}), 400

    try:
        frame_number = int(raw_frame_number)
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'Frame number must be an integer.'}), 400

    try:
        # Only the first prompt box is used.
        positive_prompt_box = prompt_boxes[0]
        results = ai_models.predict_from_one_shot(video_uuid, frame_number, positive_prompt_box)
        return jsonify({'success': True, 'results': results})

    except Exception as e:
        logging.error(f"智能选择预测失败: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500
|
|
|
|
|
|
@app.route('/interactive_segment/predict_from_dataset', methods=['POST'])
def predict_from_dataset_route():
    """Predict objects of a class on a frame using prototypes built for it.

    Expects ``video_uuid``, ``frame_number`` and ``class_name`` in the JSON
    body. Responds with 400 for missing or invalid input, with a
    ``success: False`` payload when no prototypes could be built for the
    class, and with 500 when the prediction raises.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    raw_frame_number = data.get('frame_number')
    class_name = data.get('class_name')

    # Validate presence before converting; int(None) would raise otherwise.
    if not all([video_uuid, raw_frame_number is not None, class_name]):
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400

    try:
        frame_number = int(raw_frame_number)
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'Frame number must be an integer.'}), 400

    try:
        positive_prototypes = ai_models.build_prototypes_for_class(class_name)
        if positive_prototypes is None or len(positive_prototypes) == 0:
            return jsonify({'success': False,
                            'message': f"No labeled examples found for class '{class_name}' in the dataset, or failed to extract features."})

        results = ai_models.predict_with_prototypes(video_uuid, frame_number, positive_prototypes)
        return jsonify({'success': True, 'results': results})

    except Exception as e:
        logging.error(f"Dataset-driven prediction failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500
|
|
|
|
@app.route('/api/background_preprocess_frame', methods=['POST'])
def background_preprocess_frame():
    """Start feature preprocessing for one frame in a background thread.

    Nothing is started when data is missing (400), when another task is
    registered as active for the video, or when the frame's key is already
    in ``ai_models.PREPROCESSED_DATA_CACHE``.
    """
    payload = request.json
    video_uuid = payload.get('video_uuid')
    frame_number = payload.get('frame_number')

    if frame_number is None or not video_uuid:
        return jsonify({'success': False, 'message': 'Missing data.'}), 400
    if background_tasks.active_tasks.get(video_uuid):
        return jsonify({'success': False, 'message': 'Another task is active.'})

    cache_key = f"{video_uuid}_{frame_number}"
    if cache_key in ai_models.PREPROCESSED_DATA_CACHE:
        return jsonify({'success': True, 'message': 'Already cached.'})

    worker = threading.Thread(
        target=ai_models.get_features_for_all_masks,
        args=(video_uuid, frame_number),
        name=f"Preprocess-{video_uuid[:6]}-{frame_number}"
    )
    worker.start()

    return jsonify({'success': True, 'message': 'Preprocessing started in background.'})
|
|
|
|
@app.route('/api/get_random_frames_for_neg_sampling', methods=['POST'])
def get_random_frames_for_neg_sampling():
    """Return up to ``count`` randomly chosen frames of a video, sorted.

    ``count`` defaults to 10. When the video has fewer frames than requested,
    all of them are returned. Responds with 400 when ``video_uuid`` is
    missing and 500 when the lookup or sampling raises.
    """
    payload = request.json
    video_uuid = payload.get('video_uuid')
    count = int(payload.get('count', 10))

    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400

    try:
        all_frame_numbers = database.get_frame_numbers_for_video(video_uuid)
        sampled_numbers = (
            all_frame_numbers
            if len(all_frame_numbers) < count
            else random.sample(all_frame_numbers, count)
        )

        frames_data = [
            {
                'video_uuid': video_uuid,
                'frame_number': fn,
                'image_url': f"/media/frames/{video_uuid}/frame_{fn:05d}.jpg"
            }
            for fn in sorted(sampled_numbers)
        ]

        return jsonify({'success': True, 'frames': frames_data})
    except Exception as e:
        logging.error(f"Error getting random frames: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500
|
|
|
|
|
|
@app.route('/apply_prototypes_to_video', methods=['POST'])
def apply_prototypes_to_video_route():
    """Start a background thread that applies class prototypes to a video.

    Expects ``video_uuid`` and ``class_name`` in the JSON body;
    ``negative_samples`` and ``confidence_threshold`` (default 0.5) are
    optional. Responds with 400 for missing input and 409 when another task
    is registered as active for the video.
    """
    payload = request.json
    video_uuid = payload.get('video_uuid')
    class_name = payload.get('class_name')
    negative_samples = payload.get('negative_samples', None)
    confidence_threshold = float(payload.get('confidence_threshold', 0.5))

    if not (video_uuid and class_name):
        return jsonify({'success': False, 'message': 'Video UUID and Class Name are required.'}), 400

    if background_tasks.active_tasks.get(video_uuid):
        return jsonify({'success': False, 'message': 'Another task is already running for this video.'}), 409

    # The task receives an application context object as its last argument.
    task_args = (video_uuid, class_name, negative_samples, confidence_threshold, app.app_context())
    worker = threading.Thread(
        target=background_tasks.apply_prototypes_to_video_task,
        args=task_args,
        name=f"ApplyPrototypes-{video_uuid[:6]}"
    )
    worker.start()

    return jsonify({'success': True, 'message': 'Task to apply suggestions has started.'})
|
|
|
|
|
|
@app.route('/startSam2Tracking', methods=['POST'])
def start_sam2_tracking():
    """Start SAM tracking over a frame range in a background thread.

    Expects ``video_uuid``, ``start_frame``, ``end_frame`` and
    ``init_bboxes_text`` in the JSON body. Responds with 501 when SAM is
    unavailable and 400 for missing or non-integer input. On success the
    response carries the ``tracker_uuid`` generated for this run.
    """
    try:
        from ultralytics_sam_tasks import get_sam_model
        if not get_sam_model():
            return jsonify({'success': False, 'message': 'SAM tracking feature is not available on the server.'}), 501
    except ImportError:
        return jsonify({'success': False, 'message': 'SAM features are not installed on server.'}), 501

    data = request.json
    video_uuid = data.get('video_uuid')
    raw_start_frame = data.get('start_frame')
    raw_end_frame = data.get('end_frame')
    init_bboxes_text = data.get('init_bboxes_text')

    # Validate presence before converting; int(None) would raise otherwise.
    if not all([video_uuid, raw_start_frame is not None, raw_end_frame is not None, init_bboxes_text]):
        return jsonify({'success': False, 'message': 'Missing required data for tracking.'}), 400

    try:
        start_frame, end_frame = int(raw_start_frame), int(raw_end_frame)
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'Frame numbers must be integers.'}), 400

    if background_tasks.active_tasks.get(video_uuid):
        return jsonify(
            {'success': False, 'message': 'Another task (extraction or tracking) is already running for this video.'})

    tracker_uuid = uuid.uuid4().hex

    threading.Thread(target=background_tasks.start_sam2_tracking_task, args=(
        video_uuid, tracker_uuid, start_frame, end_frame, init_bboxes_text
    ), name=f"SAM-Tracker-{video_uuid[:6]}").start()

    return jsonify({'success': True, 'tracker_uuid': tracker_uuid})
|
|
|
|
|
|
@app.route('/startSam2BatchTracking', methods=['POST'])
def start_sam2_batch_tracking():
    """Start SAM batch tracking over a frame range in a background thread.

    Expects ``video_uuid``, ``start_frame``, ``end_frame`` and
    ``init_bboxes_text`` in the JSON body. Responds with 501 when SAM is
    unavailable and 400 for missing or non-integer input. On success the
    response carries the ``tracker_uuid`` generated for this run.
    """
    try:
        from ultralytics_sam_tasks import get_sam_model
        if not get_sam_model():
            return jsonify({'success': False, 'message': 'SAM tracking feature is not available on the server.'}), 501
    except ImportError:
        return jsonify({'success': False, 'message': 'SAM features are not installed on server.'}), 501

    data = request.json
    video_uuid = data.get('video_uuid')
    raw_start_frame = data.get('start_frame')
    raw_end_frame = data.get('end_frame')
    init_bboxes_text = data.get('init_bboxes_text')

    # Validate presence before converting; int(None) would raise otherwise.
    if not all([video_uuid, raw_start_frame is not None, raw_end_frame is not None, init_bboxes_text]):
        return jsonify({'success': False, 'message': 'Missing required data for batch tracking.'}), 400

    try:
        start_frame, end_frame = int(raw_start_frame), int(raw_end_frame)
    except (TypeError, ValueError):
        return jsonify({'success': False, 'message': 'Frame numbers must be integers.'}), 400

    if background_tasks.active_tasks.get(video_uuid):
        return jsonify(
            {'success': False, 'message': 'Another task (extraction or tracking) is already running for this video.'})

    tracker_uuid = uuid.uuid4().hex

    threading.Thread(target=background_tasks.start_sam2_batch_tracking_task, args=(
        video_uuid, tracker_uuid, start_frame, end_frame, init_bboxes_text
    ), name=f"SAM-Batch-Tracker-{video_uuid[:6]}").start()

    return jsonify({'success': True, 'tracker_uuid': tracker_uuid})
|
|
|
|
|
|
@app.route('/streamSam2Tracking/<tracker_uuid>')
def stream_sam2_tracking(tracker_uuid):
    """Stream the results of a tracking session as server-sent events.

    Emits one ``update`` event per new result frame (in ascending frame
    order), followed by a final event named after the terminal status
    (``completed``, ``stopped`` or ``failed``). An ``error`` event is sent
    when no session for ``tracker_uuid`` shows up within the wait timeout.
    """
    # Upper bound for waiting on the session to be registered. Without it a
    # request for an unknown tracker would keep polling forever.
    session_wait_timeout_s = 30.0

    def generate_events():
        deadline = time.monotonic() + session_wait_timeout_s
        while tracker_uuid not in background_tasks.tracking_sessions:
            if time.monotonic() >= deadline:
                break
            time.sleep(0.1)

        session = background_tasks.tracking_sessions.get(tracker_uuid)
        if not session:
            error_event = {"event": "error", "message": "Tracking session not found or failed to start."}
            yield f"data: {json.dumps(error_event)}\n\n"
            return

        last_sent_frame = -1
        try:
            while True:
                # Read the status before the results, as the original loop
                # did, so the final event follows this pass's updates.
                status = session.get('status', 'STARTING')
                sorted_frames = sorted(k for k in session.get('results', {}).keys() if k > last_sent_frame)

                for frame_num in sorted_frames:
                    result_data = {
                        "event": "update",
                        "frame_number": frame_num,
                        "bboxes_text": session['results'][frame_num],
                        "progress": session.get('progress', 0),
                        "total": session.get('total', 0)
                    }
                    yield f"data: {json.dumps(result_data)}\n\n"
                    last_sent_frame = frame_num

                if status in ['COMPLETED', 'STOPPED', 'FAILED']:
                    final_event = {"event": status.lower(), "message": session.get('message', '')}
                    yield f"data: {json.dumps(final_event)}\n\n"
                    break

                time.sleep(0.2)
        except GeneratorExit:
            logging.info(f"Client disconnected from SSE stream for tracker {tracker_uuid}")
        finally:
            logging.info(f"SSE stream for tracker {tracker_uuid} is closing.")

    return Response(generate_events(), mimetype='text/event-stream')
|
|
|
|
|
|
@app.route('/stopSam2Tracking', methods=['POST'])
def stop_sam2_tracking():
    """Set the ``stop_requested`` flag on a SAM tracking session.

    Reads ``tracker_uuid`` from the JSON body. Returns ``success: True`` when
    the session exists in ``background_tasks.tracking_sessions``; otherwise
    ``success: False`` with a "Tracker not found." message.
    """
    requested_uuid = request.json.get('tracker_uuid')
    known_sessions = background_tasks.tracking_sessions

    if requested_uuid not in known_sessions:
        return jsonify({'success': False, 'message': 'Tracker not found.'})

    target_session = known_sessions[requested_uuid]
    target_session['stop_requested'] = True
    logging.info(f"Stop request received for SAM tracking session {requested_uuid}")
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/prepareToStartTracking', methods=['POST'])
def prepare_to_start_tracking():
    """Start a tracking job for a video in a background thread.

    Expects a JSON body with ``video_uuid``, ``tracker_name``, ``scale``,
    ``init_frame_number`` and ``init_bboxes_text``. The work is handed to
    ``background_tasks.start_tracking_task``.

    Responses:
        400 - ``video_uuid`` is missing, or ``scale`` / ``init_frame_number``
              cannot be converted to ``float`` / ``int``.
        200 with ``success: False`` - ``background_tasks.active_tasks`` already
              holds an entry for this video.
        200 with ``success: True`` and the generated ``tracker_uuid`` otherwise.
    """
    data = request.json or {}
    video_uuid = data.get('video_uuid')
    if not video_uuid:
        # Without this guard ``video_uuid[:6]`` below raises TypeError -> 500.
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400

    if background_tasks.active_tasks.get(video_uuid):
        return jsonify(
            {'success': False, 'message': 'Another task (extraction or tracking) is already running for this video.'})

    try:
        scale = float(data.get('scale'))
        init_frame_number = int(data.get('init_frame_number'))
    except (TypeError, ValueError):
        # float(None) / int("abc") etc. are client errors, not server faults.
        return jsonify({'success': False, 'message': 'scale and init_frame_number must be numeric.'}), 400

    tracker_uuid = str(uuid.uuid4().hex)
    threading.Thread(
        target=background_tasks.start_tracking_task,
        args=(
            video_uuid, tracker_uuid, data.get('tracker_name'),
            scale, init_frame_number,
            data.get('init_bboxes_text'),
        ),
        name=f"Tracker-{video_uuid[:6]}",
    ).start()
    return jsonify({'success': True, 'tracker_uuid': tracker_uuid})
|
|
|
|
|
|
@app.route('/retrieveTrackedBboxes', methods=['POST'])
def retrieve_tracked_bboxes():
    """Return the current frame number and boxes of a tracking session.

    Looks up ``tracker_uuid`` (from the JSON body) in
    ``background_tasks.tracking_sessions``. For a known session the
    ``last_client_update`` timestamp is refreshed and ``tracker_failed`` is
    true when the status is ``FAILED`` or ``TIMED OUT``. An unknown session
    yields ``success: False`` with ``tracker_failed: True``.
    """
    requested_uuid = request.json.get('tracker_uuid')
    tracking_session = background_tasks.tracking_sessions.get(requested_uuid)

    if not tracking_session:
        return jsonify({'success': False, 'tracker_failed': True})

    tracking_session['last_client_update'] = time.time()
    has_failed = tracking_session['status'] in ['FAILED', 'TIMED OUT']
    payload = {
        'success': True,
        'tracker_failed': has_failed,
        'frame_number': tracking_session.get('current_frame'),
        'bboxes_text': tracking_session.get('bboxes_text'),
    }
    return jsonify(payload)
|
|
|
|
|
|
@app.route('/continueTracking', methods=['POST'])
def continue_tracking():
    """Feed client-confirmed boxes back into a running tracking session.

    For a session whose status is ``RUNNING`` this refreshes
    ``last_client_update``, stores the submitted ``bboxes_text``, sets
    ``current_frame`` to the submitted frame number plus one, and saves the
    boxes for the submitted frame through ``database.save_frame_bboxes``.
    Any other case returns ``success: False``.
    """
    payload = request.json
    tracking_session = background_tasks.tracking_sessions.get(payload.get('tracker_uuid'))

    if not tracking_session or tracking_session['status'] != 'RUNNING':
        return jsonify({'success': False})

    tracking_session['last_client_update'] = time.time()
    submitted_bboxes = payload.get('bboxes_text')
    tracking_session['bboxes_text'] = submitted_bboxes

    submitted_frame = int(payload.get('frame_number'))
    tracking_session['current_frame'] = submitted_frame + 1
    database.save_frame_bboxes(payload.get('video_uuid'), submitted_frame, submitted_bboxes)
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/stopTracking', methods=['POST'])
def stop_tracking():
    """Set ``stop_requested`` on a tracking session if it exists.

    The response is ``success: True`` whether or not the ``tracker_uuid``
    from the JSON body matches a known session.
    """
    requested_uuid = request.json.get('tracker_uuid')
    known_sessions = background_tasks.tracking_sessions
    if requested_uuid in known_sessions:
        known_sessions[requested_uuid]['stop_requested'] = True
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/listDatasets', methods=['GET'])
def list_datasets():
    """Return every dataset from ``database.get_dataset_list`` as JSON.

    Each entry is passed through ``sanitize_dict`` before serialisation.
    """
    sanitized_entries = []
    for entry in database.get_dataset_list():
        sanitized_entries.append(sanitize_dict(entry))
    return jsonify({'datasets': sanitized_entries})
|
|
|
|
|
|
@app.route('/createDataset', methods=['POST'])
def create_dataset():
    """Register a new dataset and build it in a background thread.

    JSON body fields: ``description``, ``video_uuids``, ``eval_percent``
    (default 20.0), ``test_percent`` (default 10.0) and
    ``augmentation_options`` (default ``{}``).

    Responses:
        400 - ``validate_description`` rejects the description (checked
              against the existing dataset descriptions), or no videos
              were supplied.
        200 with ``success: True`` and ``dataset_uuid`` otherwise.
    """
    payload = request.json
    description = payload.get('description')
    selected_videos = payload.get('video_uuids')
    eval_pct = float(payload.get('eval_percent', 20.0))
    test_pct = float(payload.get('test_percent', 10.0))
    aug_options = payload.get('augmentation_options', {})

    taken_descriptions = [d['description'] for d in database.get_dataset_list()]
    description_ok, rejection_reason = validate_description(description, taken_descriptions)
    if not description_ok:
        return jsonify({'success': False, 'message': rejection_reason}), 400
    if not selected_videos:
        return jsonify({'success': False, 'message': 'Please select at least one video.'}), 400

    # Creation time is stored as milliseconds since the epoch.
    created_at_ms = int(time.time() * 1000)
    dataset_uuid = database.create_dataset_entry(description, selected_videos, created_at_ms, eval_pct, test_pct)

    builder = threading.Thread(
        target=background_tasks.create_dataset_task,
        args=(dataset_uuid, selected_videos, eval_pct, test_pct, aug_options),
        name=f"Dataset-{dataset_uuid[:6]}",
    )
    builder.start()

    return jsonify({'success': True, 'dataset_uuid': dataset_uuid})
|
|
|
|
|
|
@app.route('/regenerateDataset', methods=['POST'])
def regenerate_dataset():
    """Delete a dataset's files and rebuild it in a background thread.

    The stored ``video_uuids``, ``eval_percent`` and ``test_percent`` of the
    dataset are reused; augmentation is passed as disabled.

    Responses:
        400 - ``dataset_uuid`` missing from the JSON body.
        404 - no dataset with that UUID.
        200 with ``success: True`` once the rebuild thread has been started.
    """
    dataset_uuid = request.json.get('dataset_uuid')
    if not dataset_uuid:
        return jsonify({'success': False, 'message': 'Dataset UUID is required.'}), 400

    stored = database.get_dataset_entity(dataset_uuid)
    if not stored:
        return jsonify({'success': False, 'message': 'Dataset not found.'}), 404

    file_storage.delete_dataset_files(dataset_uuid)
    database.update_dataset_status(dataset_uuid, 'PENDING')

    rebuild_args = (
        dataset_uuid,
        json.loads(stored['video_uuids']),
        stored.get('eval_percent'),
        stored.get('test_percent'),
        {'enabled': False},
    )
    rebuilder = threading.Thread(
        target=background_tasks.create_dataset_task,
        args=rebuild_args,
        name=f"Dataset-Regen-{dataset_uuid[:6]}",
    )
    rebuilder.start()

    return jsonify({'success': True, 'message': 'Dataset regeneration started.'})
|
|
|
|
|
|
@app.route('/downloadDataset/<dataset_uuid>')
def download_dataset(dataset_uuid):
    """Send the zip archive of a dataset as an attachment.

    Returns 404 unless the dataset exists, has status ``READY`` and has a
    non-empty ``zip_path``. Returns 500 if ``send_file`` raises.
    """
    record = database.get_dataset_entity(dataset_uuid)
    is_downloadable = bool(record) and record['status'] == 'READY' and bool(record['zip_path'])
    if not is_downloadable:
        return "Dataset not found or not ready.", 404

    try:
        attachment = send_file(record['zip_path'], as_attachment=True)
    except Exception as e:
        logging.error(f"Could not send file: {e}")
        return "Error downloading file.", 500
    return attachment
|
|
|
|
|
|
@app.route('/deleteDataset', methods=['POST'])
def delete_dataset():
    """Delete a dataset's database entry, then its files.

    ``dataset_uuid`` is read from the JSON body; the response is always
    ``success: True`` when no exception is raised.
    """
    target_uuid = request.json.get('dataset_uuid')
    for remove in (database.delete_dataset, file_storage.delete_dataset_files):
        remove(target_uuid)
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/listModels', methods=['GET'])
def list_models():
    """Return every model from ``database.get_model_list`` as JSON.

    Each entry is passed through ``sanitize_dict`` before serialisation.
    """
    sanitized_entries = []
    for entry in database.get_model_list():
        sanitized_entries.append(sanitize_dict(entry))
    return jsonify({'models': sanitized_entries})
|
|
|
|
|
|
@app.route('/importModel', methods=['POST'])
def import_model():
    """Import an uploaded ``.tflite`` model together with its label file.

    Form fields: ``description`` and ``model_type`` (``float32`` or
    ``uint8``). Uploaded files: ``model_file`` (``.tflite``) and
    ``label_file`` (``.txt`` or ``.labels``).

    Responses:
        400 - the description is rejected by ``validate_description``, a file
              is missing or has the wrong extension, or the model type is
              not one of the two accepted values.
        200 with ``success: True`` and ``model_uuid`` otherwise.
    """
    description = request.form.get('description')
    uploaded_model = request.files.get('model_file')
    uploaded_labels = request.files.get('label_file')
    model_type = request.form.get('model_type')

    def reject(reason):
        # All validation failures share the same 400 response shape.
        return jsonify({'success': False, 'message': reason}), 400

    taken_descriptions = [m['description'] for m in database.get_model_list()]
    description_ok, rejection_reason = validate_description(description, taken_descriptions)
    if not description_ok:
        return reject(rejection_reason)

    if not (uploaded_model and uploaded_model.filename.endswith('.tflite')):
        return reject('Please provide a .tflite model file.')
    if not (uploaded_labels and uploaded_labels.filename.endswith(('.txt', '.labels'))):
        return reject('Please provide a .txt or .labels file.')
    if model_type not in ['float32', 'uint8']:
        return reject('Invalid model type selected.')

    # Creation time is stored as milliseconds since the epoch.
    created_at_ms = int(time.time() * 1000)
    model_uuid = database.import_model_metadata(description, uploaded_labels.filename, model_type, created_at_ms)

    file_storage.save_imported_model(uploaded_model, model_uuid)
    file_storage.save_imported_label_file(uploaded_labels, model_uuid)

    return jsonify({'success': True, 'model_uuid': model_uuid})
|
|
|
|
|
|
@app.route('/deleteModel', methods=['POST'])
def delete_model():
    """Delete a model's database entry, model file and label file, in that order.

    ``model_uuid`` is read from the JSON body; the response is always
    ``success: True`` when no exception is raised.
    """
    target_uuid = request.json.get('model_uuid')
    cleanup_steps = (
        database.delete_model,
        file_storage.delete_model_file,
        file_storage.delete_label_file,
    )
    for step in cleanup_steps:
        step(target_uuid)
    return jsonify({'success': True})
|
|
|
|
|
|
@app.route('/startPreAnnotation', methods=['POST'])
def start_pre_annotation():
    """Start model-based pre-annotation of a video in a background thread.

    JSON body fields: ``video_uuid``, ``model_uuid`` and an optional
    ``options`` mapping. Options are normalised in place with these defaults:
    ``start_frame`` 0, ``end_frame`` ``frame_count - 1``, ``confidence`` 0.5,
    ``merge_strategy`` ``'overwrite'``.

    Responses:
        400 - a UUID is missing, the video is not ``READY``, or an option
              cannot be converted to its numeric type.
        404 - the video does not exist.
        409 - ``background_tasks.active_tasks`` already holds an entry for
              this video.
        200 with ``success: True`` once the worker thread has been started.
    """
    payload = request.json
    video_uuid = payload.get('video_uuid')
    model_uuid = payload.get('model_uuid')
    options = payload.get('options', {})

    if not (video_uuid and model_uuid):
        return jsonify({'success': False, 'message': 'Video UUID and Model UUID are required.'}), 400

    video = database.get_video_entity(video_uuid)
    if not video:
        return jsonify({'success': False, 'message': 'Video not found.'}), 404
    if video['status'] != 'READY':
        return jsonify({'success': False, 'message': f"Video must be in READY state, but is {video['status']}."}), 400

    if background_tasks.active_tasks.get(video_uuid):
        return jsonify({'success': False, 'message': 'Another task is already running for this video.'}), 409

    try:
        first_frame = int(options.get('start_frame', 0))
        options['start_frame'] = first_frame
        last_frame = int(options.get('end_frame', video['frame_count'] - 1))
        options['end_frame'] = last_frame
        min_confidence = float(options.get('confidence', 0.5))
        options['confidence'] = min_confidence
        options['merge_strategy'] = options.get('merge_strategy', 'overwrite')
    except (ValueError, TypeError) as e:
        return jsonify({'success': False, 'message': f'Invalid options provided: {e}'}), 400

    annotator = threading.Thread(
        target=background_tasks.pre_annotate_video_task,
        args=(video_uuid, model_uuid, options),
        name=f"PreAnnotator-{video_uuid[:6]}"
    )
    annotator.start()

    return jsonify({'success': True, 'message': 'Pre-annotation task started.'})
|
|
|
|
|
|
@app.route('/cancelTask', methods=['POST'])
def cancel_task():
    """Request cancellation of the task currently attached to a video.

    Only videos whose status is ``PRE_ANNOTATING`` or ``APPLYING_PROTOTYPES``
    can be cancelled; their status is switched to ``CANCELLING``.

    Responses:
        400 - ``video_uuid`` missing, or the video is in any other status.
        404 - the video does not exist.
        200 with ``success: True`` when the status was updated.
    """
    video_uuid = request.json.get('video_uuid')
    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400

    video = database.get_video_entity(video_uuid)
    if not video:
        return jsonify({'success': False, 'message': 'Video not found.'}), 404

    cancellable_states = ['PRE_ANNOTATING', 'APPLYING_PROTOTYPES']
    if video['status'] not in cancellable_states:
        return jsonify({'success': False, 'message': f'Cannot cancel task, video status is {video["status"]}.'}), 400

    database.update_video_status(video_uuid, 'CANCELLING', 'Cancellation requested by user.')
    return jsonify({'success': True, 'message': 'Cancellation request sent.'})
|
|
|
|
|
|
@app.route('/datasetAnalysis/<dataset_uuid>')
def dataset_analysis(dataset_uuid):
    """Render the ``dataset_analysis.html`` page for one dataset.

    Returns a plain-text 404 when the dataset does not exist.
    """
    record = database.get_dataset_entity(dataset_uuid)
    if not record:
        return "Dataset not found", 404

    template_context = {
        'dataset': sanitize_dict(record),
        'limit_data': config.get_limit_data_for_render_template(),
    }
    return render_template('dataset_analysis.html', **template_context)
|
|
|
|
|
|
@app.route('/api/datasetAnalysis/<dataset_uuid>', methods=['GET'])
def get_dataset_analysis_data(dataset_uuid):
    """Compute labelling statistics for every labelled frame of a dataset.

    Only frames with non-blank ``bboxes_text`` are considered. The JSON
    response contains:

    * ``class_counts`` - number of boxes per label.
    * ``aspect_ratios`` - width / height of each box with positive size.
    * ``objects_per_image`` - number of labels per frame.
    * ``center_points`` - box centres divided by the video width / height
      (skipped when the video record is missing or has no positive size).
    * ``brightness_levels`` - mean grayscale value of each readable frame
      image, multiplied by 255.
    * ``annotator_stats`` - per ``assigned_to`` user: distinct frames covered
      by their tasks and the label counts of those frames.
    * ``all_bboxes`` - id, image index, area and aspect ratio per box.
    * ``suspicious_pairs`` - box pairs within one frame whose IoU exceeds 0.95.
    * ``image_class_map`` - distinct labels per frame index.
    * ``gallery_images`` - frame URL and the UUID of the first task whose
      frame range contains the frame.
    * ``warnings`` / ``summary_text`` - HTML snippets derived from the above.

    Returns 404 when the dataset does not exist.
    """
    dataset = database.get_dataset_entity(dataset_uuid)
    if not dataset:
        return jsonify({'success': False, 'message': 'Dataset not found.'}), 404

    video_uuids = json.loads(dataset.get('video_uuids', '[]'))
    tasks_by_video = {vu: database.get_tasks_for_video(vu) for vu in video_uuids}
    video_info_cache = {vu: database.get_video_entity(vu) for vu in video_uuids}

    def get_task_for_frame(video_uuid, frame_number):
        # First task of the video whose inclusive frame range holds the frame.
        for task in tasks_by_video.get(video_uuid, []):
            if task['start_frame'] <= frame_number <= task['end_frame']:
                return task['task_uuid']
        return None

    def get_video_description(video_uuid):
        # The video record can be missing (the centre-point code below already
        # guards for that); do not crash the whole analysis because of it.
        info = video_info_cache.get(video_uuid)
        return info.get('description') if info else None

    all_frames = [
        {**dict(frame), 'video_uuid': vu, 'video_description': get_video_description(vu)}
        for vu in video_uuids
        for frame in database.get_video_frames(vu)
        # ``or ''`` covers a stored None as well as an absent key.
        if (frame.get('bboxes_text') or '').strip()
    ]

    class_counts = Counter()
    aspect_ratios, objects_per_image, center_points, brightness_levels = [], [], [], []
    all_bboxes_for_outliers = []
    suspicious_pairs = []
    image_class_map = {}

    for i, frame in enumerate(all_frames):
        video_uuid, frame_number, bboxes_text = frame['video_uuid'], frame['frame_number'], frame['bboxes_text']
        rects, labels, _ = convert_text_to_rects_and_labels(bboxes_text)

        image_class_map[i] = list(set(labels))
        objects_per_image.append(len(labels))

        # Near-identical boxes inside one frame are reported as likely duplicates.
        if len(rects) > 1:
            for (idx1, rect1), (idx2, rect2) in itertools.combinations(enumerate(rects), 2):
                iou = calculate_iou(rect1, rect2)
                if iou > 0.95:
                    suspicious_pairs.append({
                        'image_index': i, 'iou': iou,
                        'box1_label': labels[idx1], 'box2_label': labels[idx2]
                    })

        # Brightness is best-effort: an unreadable image only loses one sample.
        try:
            image_gray = skio.imread(file_storage.get_frame_path(video_uuid, frame_number), as_gray=True)
            brightness_levels.append(np.mean(image_gray) * 255)
        except Exception as e:
            logging.debug(f"Skipping brightness for {video_uuid} frame {frame_number}: {e}")

        video_info = video_info_cache.get(video_uuid)
        has_video_size = bool(video_info) and video_info['width'] > 0 and video_info['height'] > 0

        for j, rect in enumerate(rects):
            class_counts[labels[j]] += 1
            width, height = int(rect[2] - rect[0]), int(rect[3] - rect[1])

            # Degenerate boxes are counted per class but excluded from the
            # geometry statistics (they would divide by zero).
            if width <= 0 or height <= 0:
                continue

            aspect_ratio = width / height
            aspect_ratios.append(aspect_ratio)
            if has_video_size:
                center_x = (float(rect[0]) + float(rect[2])) / 2.0 / float(video_info['width'])
                center_y = (float(rect[1]) + float(rect[3])) / 2.0 / float(video_info['height'])
                center_points.append({'x': center_x, 'y': center_y})
            all_bboxes_for_outliers.append(
                {'id': f'{video_uuid}_{frame_number}_{j}', 'image_index': i, 'area': width * height,
                 'aspect_ratio': aspect_ratio})

    annotator_stats = {}
    all_tasks = [task for vid_tasks in tasks_by_video.values() for task in vid_tasks]
    for task in all_tasks:
        annotator_stats.setdefault(task['assigned_to'], {'image_count': 0, 'class_counts': Counter()})

    user_frame_sets = {user: set() for user in annotator_stats}
    for frame in all_frames:
        for task in all_tasks:
            if (task['video_uuid'] == frame['video_uuid']
                    and task['start_frame'] <= frame['frame_number'] <= task['end_frame']):
                user = task['assigned_to']
                user_frame_sets[user].add(f"{frame['video_uuid']}_{frame['frame_number']}")
                annotator_stats[user]['class_counts'].update(extract_labels(frame['bboxes_text']))
    for user, frame_set in user_frame_sets.items():
        annotator_stats[user]['image_count'] = len(frame_set)

    gallery_images = [{
        'original_url': f"/media/frames/{f['video_uuid']}/frame_{f['frame_number']:05d}.jpg",
        'video': f['video_description'], 'frame': f['frame_number'], 'video_uuid': f['video_uuid'],
        'task_uuid': get_task_for_frame(f['video_uuid'], f['frame_number'])
    } for f in all_frames]

    warnings = []
    total_instances = sum(class_counts.values())
    if class_counts:
        avg_instances = total_instances / len(class_counts)
        for class_name, count in class_counts.items():
            # Rare in absolute terms, or below 10% of the per-class average.
            if count < 10 or count < avg_instances * 0.1:
                warnings.append(
                    f"<b>Class Imbalance:</b> Class '{class_name}' has very few instances ({count}), which may affect model performance.")

    small_object_threshold = 100
    small_object_count = sum(1 for bbox in all_bboxes_for_outliers if bbox['area'] < small_object_threshold)
    if small_object_count > 0:
        warnings.append(
            f"<b>Small Object Warning:</b> Found {small_object_count} objects with an area smaller than {small_object_threshold} pixels. Please check if they are labeling errors.")

    if suspicious_pairs:
        warnings.append(
            f"<b>Potential Duplicate Warning:</b> Found {len(suspicious_pairs)} pairs of bounding boxes with high overlap (IoU > 0.95), which might be duplicate labels.")

    summary_text = f"This dataset contains <strong>{len(class_counts)}</strong> classes, with a total of <strong>{total_instances}</strong> instances across <strong>{len(all_frames)}</strong> labeled images."

    return jsonify({
        'success': True,
        'summary_text': summary_text,
        'warnings': warnings,
        'class_counts': dict(class_counts),
        'aspect_ratios': aspect_ratios,
        'objects_per_image': objects_per_image,
        'center_points': center_points,
        'brightness_levels': brightness_levels,
        # Counter objects are converted to plain dicts for serialisation.
        'annotator_stats': {u: {'image_count': d['image_count'], 'class_counts': dict(d['class_counts'])} for u, d in
                            annotator_stats.items()},
        'all_bboxes': all_bboxes_for_outliers,
        'suspicious_pairs': suspicious_pairs,
        'image_class_map': image_class_map,
        'gallery_images': gallery_images
    })
|
|
|
|
|
|
def calculate_color_histogram(image_path, rect):
    """Return a normalised hue/saturation histogram of a box inside an image.

    The image is read with OpenCV, cropped to ``rect``, converted from BGR to
    HSV, and a 16x16 histogram over the first two channels (ranges 0-180 and
    0-256) is computed, min-max normalised to [0, 1] and flattened.

    Args:
        image_path: Path of the image file to read.
        rect: ``(x1, y1, x2, y2)`` box; values are converted with ``int``.

    Returns:
        The flattened histogram, or ``None`` when the image cannot be read,
        the crop is empty, or any exception occurs (which is logged as a
        warning).
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            return None

        x1, y1, x2, y2 = map(int, rect)
        # A negative index in a NumPy slice counts from the far edge, so a box
        # that sticks out past the top/left border would silently sample the
        # wrong region. Clamp to the image origin instead.
        x1, y1, x2, y2 = max(x1, 0), max(y1, 0), max(x2, 0), max(y2, 0)
        roi = image[y1:y2, x1:x2]

        if roi.size == 0:
            return None

        hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv_roi], [0, 1], None, [16, 16], [0, 180, 0, 256])
        cv2.normalize(hist, hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)

        return hist.flatten()
    except Exception as e:
        logging.warning(f"计算颜色直方图失败: {e} for path {image_path}")
        return None
|
|
|
|
|
|
@app.route('/api/datasetAnalysis/<dataset_uuid>/consistency_check', methods=['POST'])
def run_consistency_check(dataset_uuid):
    """Flag images whose boxes look more like another class than their own.

    For every labelled box an embedding is requested from
    ``ai_models.get_features_for_specific_bboxes`` and, when the color check
    is enabled, a color histogram from ``calculate_color_histogram``. A mean
    "prototype" is built per class. A box in a class with at least five
    usable boxes is an outlier when:

    * semantic: its cosine similarity to another class prototype exceeds the
      similarity to its own prototype by more than 0.2; or
    * color (optional): its L1 histogram distance to the closest other class
      prototype, times ``COLOR_CONFUSION_FACTOR``, is still smaller than the
      distance to its own prototype.

    JSON body (optional): ``enable_color_check`` (default ``True``).

    Returns:
        JSON with ``message`` and ``outlier_image_indices`` (indices into the
        list of labelled frames); 404 for an unknown dataset; 500 with the
        error text when anything raises.
    """
    try:
        # silent=True: a missing or non-JSON body selects the defaults instead
        # of failing the whole request.
        request_data = request.get_json(silent=True) or {}
        is_color_check_enabled = request_data.get('enable_color_check', True)

        if is_color_check_enabled:
            logging.info("Starting AI Quality Control with SEMANTIC and COLOR checks.")
        else:
            logging.info("Starting AI Quality Control with SEMANTIC check ONLY.")

        dataset = database.get_dataset_entity(dataset_uuid)
        if not dataset:
            return jsonify({'success': False, 'message': 'Dataset not found.'}), 404

        video_uuids = json.loads(dataset.get('video_uuids', '[]'))
        all_frames = [
            dict(frame) for vu in video_uuids for frame in database.get_video_frames(vu)
            # ``or ''`` covers a stored None as well as an absent key.
            if (frame.get('bboxes_text') or '').strip()
        ]

        # Flat list of boxes, plus a per-frame grouping so that embeddings can
        # be requested once per frame.
        all_bboxes_info = []
        frames_to_process = defaultdict(list)
        for i, frame in enumerate(all_frames):
            rects, labels, _ = convert_text_to_rects_and_labels(frame['bboxes_text'])
            frame_key = (frame['video_uuid'], int(frame['frame_number']))
            for j, rect in enumerate(rects):
                frames_to_process[frame_key].append((len(all_bboxes_info), rect))
                all_bboxes_info.append({
                    'image_index': i, 'rect': rect, 'label': labels[j], 'video_uuid': frame['video_uuid'],
                    'frame_number': frame['frame_number'], 'embedding': None, 'color_hist': None})

        for (video_uuid, frame_number), rect_data in frames_to_process.items():
            image_path = file_storage.get_frame_path(video_uuid, frame_number)
            rects_in_frame = [r for _, r in rect_data]
            embeddings = ai_models.get_features_for_specific_bboxes(video_uuid, frame_number, rects_in_frame)
            for i, (global_idx, rect) in enumerate(rect_data):
                if embeddings is not None and i < len(embeddings):
                    all_bboxes_info[global_idx]['embedding'] = embeddings[i]
                # Histograms are only consumed by the color check; each one
                # re-reads the image from disk, so skip them when disabled.
                if is_color_check_enabled:
                    color_hist = calculate_color_histogram(image_path, rect)
                    if color_hist is not None:
                        all_bboxes_info[global_idx]['color_hist'] = color_hist

        # Keep only boxes that have every feature the enabled checks need.
        all_features_by_class = defaultdict(list)
        for info in all_bboxes_info:
            if info['embedding'] is not None and (not is_color_check_enabled or info['color_hist'] is not None):
                all_features_by_class[info['label']].append(info)

        # One mean prototype per class; every list here is non-empty because
        # entries are only created by the append above.
        prototype_library = {}
        for class_name, infos in all_features_by_class.items():
            embeddings_tensor = torch.stack([info['embedding'] for info in infos])
            prototypes = {'semantic': torch.mean(embeddings_tensor, dim=0)}
            if is_color_check_enabled:
                color_hists = [info['color_hist'] for info in infos if info['color_hist'] is not None]
                if color_hists:
                    prototypes['color'] = np.mean(np.array(color_hists), axis=0)
            prototype_library[class_name] = prototypes

        outlier_image_indices = set()

        COLOR_CONFUSION_FACTOR = 2.0

        for class_name, infos in all_features_by_class.items():
            # Classes with fewer than five usable boxes are not judged.
            if len(infos) < 5 or class_name not in prototype_library:
                continue

            own_prototypes = prototype_library[class_name]

            for info in infos:
                is_outlier = False

                candidate_embedding = info['embedding']
                own_semantic_sim = F.cosine_similarity(candidate_embedding, own_prototypes['semantic'],
                                                       dim=0).item()
                for other_class, other_prototypes in prototype_library.items():
                    if other_class == class_name:
                        continue
                    other_semantic_sim = F.cosine_similarity(candidate_embedding, other_prototypes['semantic'],
                                                             dim=0).item()
                    if other_semantic_sim > own_semantic_sim + 0.2:
                        logging.info(
                            f"SEMANTIC outlier: A '{class_name}' (sim: {own_semantic_sim:.2f}) looks more like a '{other_class}' (sim: {other_semantic_sim:.2f}). Image index: {info['image_index']}")
                        is_outlier = True
                        break
                if is_outlier:
                    outlier_image_indices.add(info['image_index'])
                    continue

                if is_color_check_enabled:
                    if 'color' not in own_prototypes or info['color_hist'] is None:
                        continue

                    # With a single class there is nothing to be confused with.
                    if len(prototype_library) < 2:
                        continue

                    candidate_color_hist = info['color_hist']
                    dist_to_own_color = np.sum(np.abs(candidate_color_hist - own_prototypes['color']))

                    min_dist_to_other_color = float('inf')
                    closest_other_class = None

                    for other_class, other_prototypes in prototype_library.items():
                        if other_class == class_name or 'color' not in other_prototypes:
                            continue
                        dist = np.sum(np.abs(candidate_color_hist - other_prototypes['color']))
                        if dist < min_dist_to_other_color:
                            min_dist_to_other_color = dist
                            closest_other_class = other_class

                    if closest_other_class and min_dist_to_other_color * COLOR_CONFUSION_FACTOR < dist_to_own_color:
                        logging.info(
                            f"COLOR outlier: A '{class_name}' (color dist: {dist_to_own_color:.2f}) has a color profile much closer to '{closest_other_class}' (color dist: {min_dist_to_other_color:.2f}). Image index: {info['image_index']}")
                        is_outlier = True

                if is_outlier:
                    outlier_image_indices.add(info['image_index'])

        message_keyword = "**Category or color**" if is_color_check_enabled else "**category**"
        outlier_count = len(outlier_image_indices)
        if outlier_count == 0:
            message = "AI review completed. No obvious labeling confusion issues were found."
        elif outlier_count == 1:
            message = f"AI review complete. Found {outlier_count} image with potential instances of {message_keyword} confusion."
        else:
            message = f"AI review complete. Found {outlier_count} images with potential instances of {message_keyword} confusion."

        return jsonify({
            'success': True,
            'message': message,
            'outlier_image_indices': list(outlier_image_indices)
        })

    except Exception as e:
        logging.error(f"On-demand consistency check failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'审查失败: {e}'}), 500
|
|
|
|
|
|
@app.route('/lam_predict', methods=['POST'])
def lam_predict_route():
    """Run ``ai_models.lam_predict`` for a clicked point on a video frame.

    JSON body fields: ``video_uuid``, ``frame_number`` and ``point`` (a
    mapping with ``x`` and ``y``).

    Responses:
        400 - a required field is missing.
        200 with ``success: False`` - the model returned an error message.
        200 with ``success: True`` merged with the model result otherwise.
        500 - any exception while parsing the point or predicting.
    """
    payload = request.json
    video_uuid = payload.get('video_uuid')
    frame_number = payload.get('frame_number')
    point = payload.get('point')

    # frame_number may legitimately be 0, so only None counts as missing.
    if not video_uuid or frame_number is None or not point:
        return jsonify({'success': False, 'message': 'Missing required request parameters.'}), 400

    try:
        click_xy = (int(point['x']), int(point['y']))
        prediction, failure_reason = ai_models.lam_predict(video_uuid, int(frame_number), click_xy)
        if failure_reason:
            return jsonify({'success': False, 'message': failure_reason})
        return jsonify({'success': True, **prediction})
    except Exception as e:
        logging.error(f"LAM 预测失败: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500
|
|
|
|
@app.route('/api/previewAugmentations', methods=['POST'])
def preview_augmentations():
    """Return six augmented previews of one labelled frame as data URLs.

    JSON body fields: ``video_uuid``, ``frame_number``,
    ``augmentation_options`` and (for Mosaic) ``sample_pool``.

    When Mosaic is enabled and a random draw falls below its ``p`` (default
    1.0), the previews come from ``generate_mosaic_previews``. Otherwise the
    frame is run six times through the pipeline built by
    ``background_tasks.build_augmentation_pipeline`` (with Mosaic switched
    off) and the transformed boxes are drawn onto each result.

    Responses:
        501 - ``background_tasks.A`` is falsy (Albumentations unavailable).
        400 - required data missing, Mosaic requested without a sample pool,
              or no pipeline could be built.
        404 - frame image file or frame labels not found.
        500 - image unreadable, labels not convertible, or any exception.
        200 with ``success: True`` and ``previews`` otherwise.
    """
    if not background_tasks.A:
        return jsonify({'success': False, 'message': 'Albumentations library not installed on server.'}), 501

    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = data.get('frame_number')
    augmentation_options = data.get('augmentation_options')
    sample_pool = data.get('sample_pool')

    if not all([video_uuid, frame_number is not None, augmentation_options]):
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400

    try:
        if augmentation_options.get('mosaic', {}).get('enabled'):
            if not sample_pool:
                return jsonify({'success': False, 'message': 'Sample pool is required for Mosaic preview.'}), 400
            # Mosaic is applied with probability p; otherwise fall through to
            # the regular pipeline preview below.
            if random.random() < augmentation_options['mosaic'].get('p', 1.0):
                previews = generate_mosaic_previews(sample_pool, video_uuid, frame_number)
                return jsonify({'success': True, 'previews': previews})

        frame_path = file_storage.get_frame_path(video_uuid, frame_number)
        if not os.path.exists(frame_path):
            return jsonify({'success': False, 'message': 'Frame image not found.'}), 404

        image = cv2.imread(frame_path)
        if image is None:
            # cv2.imread reports an unreadable file by returning None.
            return jsonify({'success': False, 'message': 'Frame image could not be read.'}), 500
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        video_info = database.get_video_entity(video_uuid)
        conn = database.get_db_connection()
        try:
            frame_db_info = conn.execute(
                'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                (video_uuid, frame_number)).fetchone()
        finally:
            # Release the connection even when the query raises.
            conn.close()

        if not frame_db_info or not frame_db_info['bboxes_text']:
            return jsonify({'success': False, 'message': 'No labels found for this frame.'}), 404

        # Mosaic was handled (or skipped) above; keep it out of the pipeline.
        augmentation_options['mosaic'] = {'enabled': False}
        pipeline = background_tasks.build_augmentation_pipeline(augmentation_options)
        if not pipeline:
            return jsonify({'success': False, 'message': 'No valid augmentations selected.'}), 400

        all_labels = database.get_all_class_labels()
        class_map = {name: i for i, name in enumerate(all_labels)}
        yolo_bboxes, class_indices = file_storage.get_yolo_bboxes(frame_db_info['bboxes_text'], video_info['width'],
                                                                  video_info['height'], class_map)
        if not yolo_bboxes:
            return jsonify({'success': False, 'message': 'Could not parse labels into YOLO format.'}), 500

        previews = []
        for _ in range(6):
            transformed = pipeline(image=image_rgb, bboxes=yolo_bboxes, class_labels=class_indices)
            aug_image_rgb = transformed['image']
            img_h, img_w = aug_image_rgb.shape[:2]
            vis_image = aug_image_rgb.copy()

            for bbox, label_index in zip(transformed['bboxes'], transformed['class_labels']):
                class_name = all_labels[int(label_index)]
                color = string_to_color_bgr(class_name)
                # Boxes are (x_center, y_center, width, height), scaled here
                # by the image size to get pixel corner coordinates.
                x_center, y_center, width_norm, height_norm = bbox
                x1 = int((x_center - width_norm / 2) * img_w)
                y1 = int((y_center - height_norm / 2) * img_h)
                x2 = int((x_center + width_norm / 2) * img_w)
                y2 = int((y_center + height_norm / 2) * img_h)
                cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, 2)

            vis_image_bgr = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
            _, buffer = cv2.imencode('.jpg', vis_image_bgr)
            img_base64 = base64.b64encode(buffer).decode('utf-8')
            previews.append(f"data:image/jpeg;base64,{img_base64}")

        return jsonify({'success': True, 'previews': previews})

    except Exception as e:
        logging.error(f"Augmentation preview failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: configure colored console logging, start the AI
    # models, register the exit hooks, print the banner and serve the app.
    from waitress import serve
    init(autoreset=True)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Drop any handlers that were installed before this point so the colored
    # console handler is the only one attached to the root logger.
    if root_logger.hasHandlers():
        root_logger.handlers.clear()

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        ColoredFormatter('%(asctime)s - %(levelname)s - %(threadName)s - %(message)s')
    )
    root_logger.addHandler(stream_handler)

    logging.info("正在初始化AI模型,请稍候...")
    ai_models.startup_ai_models()
    for exit_hook in (ai_models.save_preprocessed_cache_to_disk, ai_models.save_prototypes_to_disk):
        atexit.register(exit_hook)
    time.sleep(0.01)

    banner_rule = "=" * 121
    banner_lines = (
        "███████╗ ███████╗ ██████╗ ██████╗ ██████╗ ██╗ ██╗ ██████╗ ██╗ ██████╗ ██╗ ██╗ █████╗ ██████╗ ██████╗",
        "╚══███╔╝ ██╔════╝ ██╔══██╗ ██╔═══██╗ ╚════██╗ ╚██╗ ██╔╝ ██╔═══██╗ ██║ ██╔═══██╗ ╚██╗ ██╔╝ ██╔══██╗ ██╔══██╗ ██╔══██╗",
        "███╔╝ █████╗ ██████╔╝ ██║ ██║ █████╔╝ ╚████╔╝ ██║ ██║ ██║ ██║ ██║ ╚████╔╝ ███████║ ██████╔╝ ██║ ██║",
        "███╔╝ ██╔══╝ ██╔══██╗ ██║ ██║ ██╔═══╝ ╚██╔╝ ██║ ██║ ██║ ██║ ██║ ╚██╔╝ ██╔══██║ ██╔══██╗ ██║ ██║",
        "███████╗ ███████╗ ██║ ██║ ╚██████╔╝ ███████╗ ██║ ╚██████╔╝ ███████╗ ╚██████╔╝ ██║ ██║ ██║ ██║ ██║ ██████╔╝",
        "╚══════╝ ╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝ ╚═╝ ╚═════╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ",
        "Developed by BlueDarkUP from FIRST Tech Challenge team 27570 Be based on -- FIRST Machine Learning Toolchain --",
        "Open your web browser and go to http://127.0.0.1:5000",
    )

    logging.info(banner_rule)
    for banner_line in banner_lines:
        logging.info(banner_line)
    logging.info(banner_rule)

    serve(app, host='0.0.0.0', port=5000)