import flask
from flask import Response, request, jsonify, render_template, send_from_directory, send_file
import os
import time
import json
import threading
import uuid
import logging
import cv2
import numpy as np
import base64
from collections import Counter, defaultdict
from skimage import io as skio
import itertools
import random
import torch
import torch.nn.functional as F
import atexit
import settings_manager
import config
import database
import file_storage
import background_tasks
import ai_models
from bbox_writer import validate_bboxes_text, convert_text_to_rects_and_labels, extract_labels
from concurrent.futures import ThreadPoolExecutor
from colorama import Fore, Style, init


class ColoredFormatter(logging.Formatter):
    """Logging formatter that wraps each part of a record in colorama colour codes.

    The output layout is ``<time> - <level> - [<thread>] - <message>``; when the
    record carries exception info, the formatted traceback is appended on a new
    line using the message colour of the record's level.
    """

    # String keys hold colours for fixed parts of the line; integer keys
    # (logging levels) hold the per-level colours for the level name and message.
    COLORS = {
        'TIMESTAMP': Fore.WHITE + Style.DIM,
        'THREAD': Fore.CYAN,
        'LEVEL_DEFAULT': Fore.WHITE,
        'MESSAGE_DEFAULT': Fore.WHITE + Style.NORMAL,
        logging.DEBUG: {'level': Fore.MAGENTA, 'message': Fore.MAGENTA},
        logging.INFO: {'level': Fore.GREEN + Style.BRIGHT, 'message': Fore.WHITE},
        logging.WARNING: {'level': Fore.YELLOW + Style.BRIGHT, 'message': Fore.YELLOW},
        logging.ERROR: {'level': Fore.RED + Style.BRIGHT, 'message': Fore.RED},
        logging.CRITICAL: {'level': Fore.RED + Style.BRIGHT, 'message': Fore.RED + Style.BRIGHT},
    }

    def __init__(self, fmt=None, datefmt=None, style='%'):
        super().__init__(fmt, datefmt, style)

    def format(self, record):
        """Return the coloured single-string representation of *record*."""
        # Levels without an explicit entry fall back to the default colours.
        level_colors = self.COLORS.get(record.levelno, {
            'level': self.COLORS['LEVEL_DEFAULT'],
            'message': self.COLORS['MESSAGE_DEFAULT']
        })
        asctime = self.formatTime(record, self.datefmt)
        colored_asctime = f"{self.COLORS['TIMESTAMP']}{asctime}{Style.RESET_ALL}"
        colored_levelname = f"{level_colors['level']}{record.levelname:<8}{Style.RESET_ALL}"
        colored_threadname = f"{self.COLORS['THREAD']}[{record.threadName}]{Style.RESET_ALL}"
        message = record.getMessage()
        colored_message = f"{level_colors['message']}{message}{Style.RESET_ALL}"
        if record.exc_info:
            exc_text = self.formatException(record.exc_info)
            colored_message += f"\n{level_colors['message']}{exc_text}{Style.RESET_ALL}"
        return f"{colored_asctime} - {colored_levelname} - {colored_threadname} - {colored_message}"


# PyYAML is optional at import time; `yaml` is left as None when it is missing.
try:
    import yaml
except ImportError:
    logging.error("PyYAML is not installed! Dataset export will fail. Please run 'pip install pyyaml'.")
    yaml = None

app = flask.Flask(__name__)
app.secret_key = os.urandom(24)
prototype_executor = ThreadPoolExecutor(max_workers=8)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(threadName)s - %(message)s')

with app.app_context():
    database.init_db()
    database.migrate_db()
    file_storage.init_storage()

# Upper bound on how long the SSE endpoint waits for a tracking session to be
# registered before reporting an error to the client.
_SSE_SESSION_WAIT_TIMEOUT_S = 30.0


def validate_description(desc, existing_descriptions):
    """Check that *desc* is a non-empty, length-limited, unique description.

    Returns:
        A ``(is_valid, message)`` tuple; *message* is empty when valid.
    """
    max_length = config.MAX_DESCRIPTION_LENGTH
    # A missing form/JSON field arrives as None; reject it instead of crashing on len().
    if not isinstance(desc, str) or not (1 <= len(desc) <= max_length):
        return False, f"Description must be between 1 and {max_length} characters."
    if desc in existing_descriptions:
        return False, "Description is a duplicate."
    return True, ""


def sanitize_dict(d):
    """Return *d* unchanged (identity hook applied to entities before they are rendered)."""
    return d


def _to_int_or_none(value):
    """Convert *value* to ``int``; return None when it is missing or not convertible."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def string_to_color_bgr(s):
    """Derive a deterministic BGR colour tuple from the characters of *s*."""
    hash_val = 0
    for char in s:
        hash_val = ord(char) + ((hash_val << 5) - hash_val)
    # The hue is reduced modulo 180 before building the HSV pixel.
    hue = hash_val % 180
    color_hsv = np.uint8([[[hue, 200, 200]]])
    color_bgr = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2BGR)[0][0]
    return tuple(map(int, color_bgr))


def calculate_iou(boxA, boxB):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes."""
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interArea = max(0, xB - xA) * max(0, yB - yA)
    if interArea == 0:
        return 0.0
    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    return interArea / float(boxAArea + boxBArea - interArea)


def generate_mosaic_previews(sample_pool, selected_video_uuid, selected_frame_number):
    """Build six base64 JPEG mosaic previews from labelled frames in *sample_pool*.

    Args:
        sample_pool: list of dicts with ``video_uuid`` and ``frame_number`` keys.
            The list is not modified.
        selected_video_uuid: video of the frame that should appear in every mosaic.
        selected_frame_number: frame number of that frame.

    Returns:
        A list of ``data:image/jpeg;base64,...`` strings, or a
        ``(json_response, 400)`` tuple when fewer than four labelled frames
        are available.
    """
    # Work on a copy so padding the pool never mutates the caller's list.
    pool = list(sample_pool)
    if len(pool) < 4:
        pool.extend(pool * (4 - len(pool)))

    all_labels = database.get_all_class_labels()
    class_map = {name: i for i, name in enumerate(all_labels)}

    image_infos = []
    conn = database.get_db_connection()
    try:
        for sample in pool:
            video_info = database.get_video_entity(sample['video_uuid'])
            frame_info = conn.execute(
                'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                (sample['video_uuid'], sample['frame_number'])
            ).fetchone()
            if video_info and frame_info and frame_info['bboxes_text']:
                image_infos.append({
                    "video_uuid": sample['video_uuid'],
                    "frame_number": sample['frame_number'],
                    "bboxes_text": frame_info['bboxes_text'],
                    "width": video_info['width'],
                    "height": video_info['height']
                })
    finally:
        # Release the connection even when a lookup raises.
        conn.close()

    if len(image_infos) < 4:
        return jsonify({'success': False,
                        'message': 'Not enough labeled images in the sample pool to generate a mosaic preview.'}), 400

    previews = []
    selected_image_info = next(
        (info for info in image_infos
         if info['video_uuid'] == selected_video_uuid and info['frame_number'] == selected_frame_number),
        None)
    for _ in range(6):
        other_images = [info for info in image_infos if info != selected_image_info]
        random.shuffle(other_images)
        mosaic_set = [selected_image_info] + other_images[:3] if selected_image_info else other_images[:4]
        random.shuffle(mosaic_set)
        mosaic_img, final_bboxes = file_storage.create_mosaic_image(mosaic_set, class_map)
        h, w, _ = mosaic_img.shape
        vis_image = mosaic_img.copy()
        for bbox_data in final_bboxes:
            # Boxes are scaled by the mosaic width/height, i.e. treated as
            # centre/size fractions of the image.
            class_index, x_center, y_center, width_norm, height_norm = bbox_data
            class_name = all_labels[class_index]
            color = string_to_color_bgr(class_name)
            x1 = int((x_center - width_norm / 2) * w)
            y1 = int((y_center - height_norm / 2) * h)
            x2 = int((x_center + width_norm / 2) * w)
            y2 = int((y_center + height_norm / 2) * h)
            cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, 2)
        _, buffer = cv2.imencode('.jpg', vis_image)
        img_base64 = base64.b64encode(buffer).decode('utf-8')
        previews.append(f"data:image/jpeg;base64,{img_base64}")
    return previews


@app.route('/')
def index():
    """Render the root page."""
    return render_template('root.html', limit_data=config.get_limit_data_for_render_template(),
                           tracker_fns=config.TRACKER_FNS)


@app.route('/labelVideo')
def label_video():
    """Render the labelling page for the task given by the ``task_uuid`` query argument."""
    task_uuid = request.args.get('task_uuid')
    if not task_uuid:
        return "Task UUID is required.", 400
    task_entity = database.get_task_entity(task_uuid)
    if not task_entity:
        return "Annotation task not found.", 404
    # Opening a pending task moves it to IN_PROGRESS; re-read to get the new status.
    if task_entity['status'] == 'PENDING':
        database.update_task_status(task_uuid, 'IN_PROGRESS')
        task_entity = database.get_task_entity(task_uuid)
    video_entity = database.get_video_entity(task_entity['video_uuid'])
    if not video_entity:
        return "Associated video not found.", 404
    first_frame_url = f"/media/frames/{video_entity['video_uuid']}/frame_{task_entity['start_frame']:05d}.jpg"
    settings = settings_manager.load_settings()
    return render_template('labelVideo.html', task_entity=sanitize_dict(task_entity),
                           video_entity=sanitize_dict(video_entity), first_frame_url=first_frame_url,
                           settings=settings, limit_data=config.get_limit_data_for_render_template())


# The URL variable is required: the view takes `path`, and the frame URLs built
# in this module ("/media/frames/<uuid>/frame_00000.jpg") contain slashes.
@app.route('/media/<path:path>')
def send_media(path):
    """Serve a file located under ``config.STORAGE_DIR``."""
    return send_from_directory(config.STORAGE_DIR, path)


@app.route('/media/annotated_frame/<video_uuid>/<int:frame_number>.jpg')
def serve_annotated_frame(video_uuid, frame_number):
    """Return the frame as JPEG with its stored bounding boxes and labels drawn on it."""
    try:
        frame_path = file_storage.get_frame_path(video_uuid, frame_number)
        if not os.path.exists(frame_path):
            return "Frame not found", 404
        image = cv2.imread(frame_path)
        if image is None:
            return "Could not read frame image", 500

        conn = database.get_db_connection()
        try:
            frame_data = conn.execute(
                'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                (video_uuid, frame_number)
            ).fetchone()
        finally:
            conn.close()

        bboxes_text = frame_data['bboxes_text'] if frame_data else None
        if bboxes_text and bboxes_text.strip():
            rects, labels, _ = convert_text_to_rects_and_labels(bboxes_text)
            for rect, label in zip(rects, labels):
                color = string_to_color_bgr(label)
                x1, y1, x2, y2 = rect
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
                # Filled strip above the box serves as the label background.
                (text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                cv2.rectangle(image, (x1, y1 - text_height - 5), (x1 + text_width, y1), color, -1)
                cv2.putText(image, label, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        success, buffer = cv2.imencode('.jpg', image)
        if not success:
            return "Failed to encode image", 500
        return Response(buffer.tobytes(), mimetype='image/jpeg')
    except Exception as e:
        logging.error(f"Error generating annotated frame for {video_uuid}/{frame_number}: {e}")
        return "Internal server error", 500


@app.route('/listVideos', methods=['GET'])
def list_videos():
    """Return all videos plus the subset reported as ready with labels."""
    all_videos = database.get_all_video_list()
    ready_videos = database.get_ready_videos_with_labels()
    return jsonify({
        'all_videos': [sanitize_dict(v) for v in all_videos],
        'ready_videos_for_dataset': [sanitize_dict(v) for v in ready_videos]
    })


@app.route('/uploadVideo', methods=['POST'])
def upload_video():
    """Store an uploaded video and start frame extraction on a background thread."""
    desc = request.form.get('description')
    video_file = request.files.get('video_file')
    is_valid, message = validate_description(desc, [v['description'] for v in database.get_all_video_list()])
    if not is_valid:
        return jsonify({'success': False, 'message': message}), 400
    if not video_file:
        return jsonify({'success': False, 'message': 'No video file provided.'}), 400
    create_time_ms = int(time.time() * 1000)
    video_uuid = database.create_video_entry(desc, video_file.filename, 0, create_time_ms)
    file_storage.save_uploaded_video(video_file, video_uuid)
    threading.Thread(target=background_tasks.extract_frames_task, args=(video_uuid,),
                     name=f"Extractor-{video_uuid[:6]}").start()
    return jsonify({'success': True, 'video_uuid': video_uuid})


@app.route('/importFrames', methods=['POST'])
def import_frames():
    """Attach uploaded frame image files to an existing video."""
    video_uuid = request.form.get('video_uuid')
    frame_files = request.files.getlist('frame_files')
    if not video_uuid or not frame_files:
        return jsonify({'success': False, 'message': 'Missing video UUID or frame files.'}), 400
    video = database.get_video_entity(video_uuid)
    if not video:
        return jsonify({'success': False, 'message': 'Video not found.'}), 404
    try:
        imported_count = database.add_frames_from_upload(video_uuid, frame_files)
        return jsonify({'success': True, 'imported_count': imported_count})
    except Exception as e:
        logging.error(f"Failed to import frames for {video_uuid}: {e}")
        return jsonify({'success': False, 'message': str(e)}), 500


@app.route('/retrieveVideoEntity', methods=['POST'])
def retrieve_video_entity():
    """Return the video entity for the posted ``video_uuid``."""
    video_uuid = request.json.get('video_uuid')
    entity = database.get_video_entity(video_uuid)
    if entity:
        return jsonify({'success': True, 'video_entity': sanitize_dict(entity)})
    return jsonify({'success': False, 'message': 'Video not found.'})


@app.route('/deleteVideo', methods=['POST'])
def delete_video():
    """Delete a video's database entry, its video file and its frames."""
    video_uuid = request.json.get('video_uuid')
    database.delete_video(video_uuid)
    file_storage.delete_video_file(video_uuid)
    file_storage.delete_frames_for_video(video_uuid)
    return jsonify({'success': True})


@app.route('/retrieveVideoFrames', methods=['POST'])
def retrieve_video_frames():
    """Return the frames of a video, each extended with its ``image_url``."""
    video_uuid = request.json.get('video_uuid')
    frames = database.get_video_frames(video_uuid)
    for frame in frames:
        frame['image_url'] = f"/media/frames/{video_uuid}/frame_{frame['frame_number']:05d}.jpg"
    return jsonify({'success': True, 'frames': [sanitize_dict(f) for f in frames]})


@app.route('/storeVideoFrameBboxesText', methods=['POST'])
def store_video_frame_bboxes_text():
    """Validate and persist the bounding-box text of one frame."""
    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = _to_int_or_none(data.get('frame_number'))
    # Reject incomplete requests with 400 rather than failing inside int().
    if not video_uuid or frame_number is None:
        return jsonify({'success': False, 'message': 'Missing video_uuid or frame_number.'}), 400
    bboxes_text = validate_bboxes_text(data.get('bboxes_text'))
    database.save_frame_bboxes(video_uuid, frame_number, bboxes_text)
    return jsonify({'success': True})


@app.route('/listTasks', methods=['GET'])
def list_tasks():
    """Return the annotation tasks of the video given by the ``video_uuid`` query argument."""
    video_uuid = request.args.get('video_uuid')
    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400
    tasks = database.get_tasks_for_video(video_uuid)
    return jsonify({'success': True, 'tasks': [sanitize_dict(t) for t in tasks]})


@app.route('/createTask', methods=['POST'])
def create_task():
    """Create an annotation task covering a frame range of a video."""
    data = request.json
    video_uuid = data.get('video_uuid')
    assigned_to = data.get('assigned_to')
    description = data.get('description', '')
    start_frame = data.get('start_frame')
    end_frame = data.get('end_frame')
    # Frame numbers are tested against None so that frame 0 is accepted.
    if not all([video_uuid, assigned_to, start_frame is not None, end_frame is not None]):
        return jsonify({'success': False, 'message': 'Missing required fields.'}), 400
    try:
        start_frame, end_frame = int(start_frame), int(end_frame)
    except (ValueError, TypeError):
        return jsonify({'success': False, 'message': 'Frame numbers must be integers.'}), 400
    video = database.get_video_entity(video_uuid)
    if not video:
        return jsonify({'success': False, 'message': 'Video not found.'}), 404
    if not (0 <= start_frame < end_frame < video['frame_count']):
        return jsonify({'success': False,
                        'message': f'Invalid frame range. Must be within 0 and {video["frame_count"] - 1}.'}), 400
    try:
        task_uuid = database.create_annotation_task(video_uuid, assigned_to, description, start_frame, end_frame)
        return jsonify({'success': True, 'task_uuid': task_uuid})
    except ValueError as e:
        return jsonify({'success': False, 'message': str(e)}), 400


@app.route('/deleteTask', methods=['POST'])
def delete_task():
    """Delete the annotation task given by ``task_uuid``."""
    task_uuid = request.json.get('task_uuid')
    database.delete_task(task_uuid)
    return jsonify({'success': True})


@app.route('/updateTaskStatus', methods=['POST'])
def update_task_status():
    """Set a task's status to PENDING, IN_PROGRESS or COMPLETED."""
    data = request.json
    task_uuid = data.get('task_uuid')
    status = data.get('status')
    if not task_uuid or status not in ['PENDING', 'IN_PROGRESS', 'COMPLETED']:
        return jsonify({'success': False, 'message': 'Invalid task UUID or status.'}), 400
    database.update_task_status(task_uuid, status)
    return jsonify({'success': True})


@app.route('/listClasses', methods=['GET'])
def list_classes():
    """Return all class labels."""
    labels = database.get_all_class_labels()
    return jsonify({'success': True, 'labels': labels})


@app.route('/api/rebuild_prototypes', methods=['POST'])
def rebuild_prototypes():
    """Queue a prototype rebuild for one class, or for every class when none is given."""
    data = request.json
    class_name = data.get('class_name')
    if class_name:
        logging.info(f"Manual prototype rebuild requested for class: {class_name}")
        prototype_executor.submit(ai_models.update_prototype_for_class, class_name)
        message = f"Started rebuilding prototype for '{class_name}'."
    else:
        logging.info("Manual prototype rebuild requested for ALL classes.")
        all_labels = database.get_all_class_labels()
        for label in all_labels:
            prototype_executor.submit(ai_models.update_prototype_for_class, label)
        message = f"Started rebuilding prototypes for all {len(all_labels)} classes in the background."
    return jsonify({'success': True, 'message': message})


@app.route('/api/interpolateBboxes', methods=['POST'])
def interpolate_bboxes():
    """Linearly interpolate one object's box across the frames between two keyframes.

    For every intermediate frame, any stored line ending in ``,<object_id>`` is
    replaced by the interpolated box; other lines are kept.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    object_id = data.get('object_id')
    start_frame_data = data.get('start_frame')
    end_frame_data = data.get('end_frame')
    if not all([video_uuid, object_id, start_frame_data, end_frame_data]):
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400
    try:
        start_frame_num = int(start_frame_data['frame_number'])
        end_frame_num = int(end_frame_data['frame_number'])
        start_bbox = start_frame_data['bbox']
        end_bbox = end_frame_data['bbox']
        label = start_bbox['label']
        # Normalise so interpolation always runs from the lower to the higher frame.
        if start_frame_num >= end_frame_num:
            start_frame_num, end_frame_num = end_frame_num, start_frame_num
            start_bbox, end_bbox = end_bbox, start_bbox
        total_steps = end_frame_num - start_frame_num
        if total_steps <= 1:
            return jsonify({'success': True, 'message': 'No frames to interpolate.'})

        object_suffix = f',{object_id}'
        for i in range(1, total_steps):
            current_frame_num = start_frame_num + i
            t = i / float(total_steps)
            interp_x1 = int(start_bbox['x1'] + (end_bbox['x1'] - start_bbox['x1']) * t)
            interp_y1 = int(start_bbox['y1'] + (end_bbox['y1'] - start_bbox['y1']) * t)
            interp_x2 = int(start_bbox['x2'] + (end_bbox['x2'] - start_bbox['x2']) * t)
            interp_y2 = int(start_bbox['y2'] + (end_bbox['y2'] - start_bbox['y2']) * t)
            new_bbox_line = f"{interp_x1},{interp_y1},{interp_x2},{interp_y2},{label},{object_id}"

            # The read connection is closed before saving, and also when the query fails.
            conn = database.get_db_connection()
            try:
                frame_db = conn.execute(
                    'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                    (video_uuid, current_frame_num)).fetchone()
            finally:
                conn.close()

            existing_bboxes = frame_db['bboxes_text'] if frame_db else ''
            lines = existing_bboxes.split('\n') if existing_bboxes else []
            updated_lines = [line for line in lines if not line.endswith(object_suffix)]
            updated_lines.append(new_bbox_line)
            final_bboxes_text = '\n'.join(filter(None, updated_lines))
            database.save_frame_bboxes(video_uuid, current_frame_num, final_bboxes_text)
        return jsonify({'success': True, 'message': f'Interpolated {total_steps - 1} frames successfully.'})
    except Exception as e:
        logging.error(f"Interpolation failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500


@app.route('/addClass', methods=['POST'])
def add_class():
    """Add a class label (whitespace-trimmed, must be non-empty)."""
    data = request.json
    label_name = data.get('label_name', '').strip()
    if not label_name:
        return jsonify({'success': False, 'message': 'Label name cannot be empty.'}), 400
    database.add_class_label(label_name)
    return jsonify({'success': True})


@app.route('/deleteClass', methods=['POST'])
def delete_class():
    """Delete a class label."""
    data = request.json
    label_name = data.get('label_name')
    if not label_name:
        return jsonify({'success': False, 'message': 'Label name is required.'}), 400
    database.delete_class_label(label_name)
    return jsonify({'success': True})


@app.route('/api/settings', methods=['GET'])
def get_settings():
    """Return the current settings."""
    settings = settings_manager.load_settings()
    return jsonify({'success': True, 'settings': settings})


@app.route('/api/settings', methods=['POST'])
def save_settings():
    """Persist new settings and clear model caches affected by the change."""
    new_settings = request.json
    if not new_settings:
        return jsonify({'success': False, 'message': 'No settings data provided.'}), 400

    current_settings = settings_manager.load_settings()
    sam_model_changed = (current_settings.get('sam_model_checkpoint')
                         != new_settings.get('sam_model_checkpoint'))
    mobilenet_model_changed = (current_settings.get('feature_extractor_model_name')
                               != new_settings.get('feature_extractor_model_name'))
    device_changed = current_settings.get('gpu_device') != new_settings.get('gpu_device')
    restart_required = sam_model_changed or mobilenet_model_changed or device_changed

    if not settings_manager.save_settings(new_settings):
        return jsonify({'success': False, 'message': 'Failed to save settings to file.'}), 500

    if sam_model_changed or device_changed:
        logging.info("SAM model or device setting changed. Clearing SAM cache.")
        try:
            from ultralytics_sam_tasks import _sam_model_cache
            _sam_model_cache["model"] = None
            _sam_model_cache["path"] = None
        except (ImportError, AttributeError):
            logging.warning("Could not clear SAM model cache.")
    if mobilenet_model_changed or device_changed:
        logging.info("Feature extractor model or device setting changed. Clearing MobileNet cache.")
        ai_models.clear_feature_extractor_cache()
    if device_changed:
        settings_manager.update_device()
    return jsonify({
        'success': True,
        'message': 'Settings saved successfully!',
        'restart_required': restart_required
    })


@app.route('/api/clear_cache', methods=['POST'])
def clear_cache():
    """Empty ``ai_models.PREPROCESSED_DATA_CACHE`` and report how many items were dropped."""
    try:
        count = len(ai_models.PREPROCESSED_DATA_CACHE)
        ai_models.PREPROCESSED_DATA_CACHE.clear()
        logging.info(f"Cleared {count} items from PREPROCESSED_DATA_CACHE.")
        return jsonify({'success': True, 'message': f'Successfully cleared {count} cached items.'})
    except Exception as e:
        logging.error(f"Failed to clear cache: {e}")
        return jsonify({'success': False, 'message': 'An error occurred while clearing the cache.'}), 500


@app.route('/samPredict', methods=['POST'])
def sam_predict():
    """Predict a bounding box for the object at a clicked point of a frame."""
    try:
        from ultralytics_sam_tasks import predict_box_from_point_ultralytics, get_sam_model
        if not get_sam_model():
            return jsonify({'success': False, 'message': 'Ultralytics SAM model is not available.'}), 501
    except ImportError:
        return jsonify({'success': False, 'message': 'SAM features are not installed on server.'}), 501

    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = data.get('frame_number')
    point_coords = data.get('point')
    if not all([video_uuid, frame_number is not None, point_coords]):
        return jsonify({'success': False,
                        'message': 'Missing required data (video_uuid, frame_number, point).'}), 400
    try:
        image_path = file_storage.get_frame_path(video_uuid, int(frame_number))
        if not os.path.exists(image_path):
            return jsonify({'success': False, 'message': 'Frame image not found on server.'}), 404
        coords_tuple = (int(point_coords['x']), int(point_coords['y']))
        bbox = predict_box_from_point_ultralytics(image_path, coords_tuple)
        if bbox:
            return jsonify({'success': True, 'bbox': bbox})
        return jsonify({'success': False, 'message': 'No object found at the specified point.'})
    except Exception as e:
        logging.error(f"SAM prediction failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500


@app.route('/interactive_segment/preprocess', methods=['POST'])
def interactive_segment_preprocess_route():
    """Run mask-feature preprocessing for a frame and return its cache key."""
    data = request.json
    video_uuid = data.get('video_uuid')
    # Convert leniently so a missing/invalid value yields the 400 below, not a crash.
    frame_number = _to_int_or_none(data.get('frame_number'))
    if video_uuid is None or frame_number is None:
        return jsonify({'success': False, 'message': 'Missing video_uuid or frame_number.'}), 400
    try:
        ai_models.get_features_for_all_masks(video_uuid, frame_number)
        cache_key = f"{video_uuid}_{frame_number}"
        return jsonify({'success': True, 'message': 'Preprocessing successful', 'cache_key': cache_key})
    except Exception as e:
        logging.error(f"智能选择预处理失败: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500


@app.route('/interactive_segment/predict', methods=['POST'])
def interactive_segment_predict_route():
    """Run one-shot prediction on a frame using the first posted prompt box."""
    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = _to_int_or_none(data.get('frame_number'))
    prompt_boxes = data.get('prompt_boxes', [])
    # An empty prompt_boxes list is rejected here as well.
    if not all([video_uuid, frame_number is not None, prompt_boxes]):
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400
    try:
        positive_prompt_box = prompt_boxes[0]
        results = ai_models.predict_from_one_shot(video_uuid, frame_number, positive_prompt_box)
        return jsonify({'success': True, 'results': results})
    except Exception as e:
        logging.error(f"智能选择预测失败: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500


@app.route('/interactive_segment/predict_from_dataset', methods=['POST'])
def predict_from_dataset_route():
    """Predict on a frame using prototypes built for the given class."""
    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = _to_int_or_none(data.get('frame_number'))
    class_name = data.get('class_name')
    if not all([video_uuid, frame_number is not None, class_name]):
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400
    try:
        positive_prototypes = ai_models.build_prototypes_for_class(class_name)
        if positive_prototypes is None or len(positive_prototypes) == 0:
            return jsonify({'success': False,
                            'message': f"No labeled examples found for class '{class_name}' in the dataset, or failed to extract features."})
        results = ai_models.predict_with_prototypes(video_uuid, frame_number, positive_prototypes)
        return jsonify({'success': True, 'results': results})
    except Exception as e:
        logging.error(f"Dataset-driven prediction failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500


@app.route('/api/background_preprocess_frame', methods=['POST'])
def background_preprocess_frame():
    """Start mask-feature preprocessing for a frame on a background thread unless already cached."""
    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = data.get('frame_number')
    if not video_uuid or frame_number is None:
        return jsonify({'success': False, 'message': 'Missing data.'}), 400
    if background_tasks.active_tasks.get(video_uuid):
        return jsonify({'success': False, 'message': 'Another task is active.'})
    cache_key = f"{video_uuid}_{frame_number}"
    if cache_key in ai_models.PREPROCESSED_DATA_CACHE:
        return jsonify({'success': True, 'message': 'Already cached.'})
    threading.Thread(
        target=ai_models.get_features_for_all_masks,
        args=(video_uuid, frame_number),
        name=f"Preprocess-{video_uuid[:6]}-{frame_number}"
    ).start()
    return jsonify({'success': True, 'message': 'Preprocessing started in background.'})


@app.route('/api/get_random_frames_for_neg_sampling', methods=['POST'])
def get_random_frames_for_neg_sampling():
    """Return up to ``count`` randomly chosen frames of a video, sorted by frame number."""
    data = request.json
    video_uuid = data.get('video_uuid')
    count = int(data.get('count', 10))
    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400
    try:
        all_frame_numbers = database.get_frame_numbers_for_video(video_uuid)
        # With fewer frames than requested, return them all instead of sampling.
        if len(all_frame_numbers) < count:
            sampled_numbers = all_frame_numbers
        else:
            sampled_numbers = random.sample(all_frame_numbers, count)
        frames_data = [
            {
                'video_uuid': video_uuid,
                'frame_number': fn,
                'image_url': f"/media/frames/{video_uuid}/frame_{fn:05d}.jpg"
            }
            for fn in sorted(sampled_numbers)
        ]
        return jsonify({'success': True, 'frames': frames_data})
    except Exception as e:
        logging.error(f"Error getting random frames: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500


@app.route('/apply_prototypes_to_video', methods=['POST'])
def apply_prototypes_to_video_route():
    """Start the background task that applies class prototypes to a whole video."""
    data = request.json
    video_uuid = data.get('video_uuid')
    class_name = data.get('class_name')
    negative_samples = data.get('negative_samples', None)
    confidence_threshold = float(data.get('confidence_threshold', 0.5))
    if not video_uuid or not class_name:
        return jsonify({'success': False, 'message': 'Video UUID and Class Name are required.'}), 400
    if background_tasks.active_tasks.get(video_uuid):
        return jsonify({'success': False, 'message': 'Another task is already running for this video.'}), 409
    threading.Thread(
        target=background_tasks.apply_prototypes_to_video_task,
        args=(video_uuid, class_name, negative_samples, confidence_threshold, app.app_context()),
        name=f"ApplyPrototypes-{video_uuid[:6]}"
    ).start()
    return jsonify({'success': True, 'message': 'Task to apply suggestions has started.'})


def _sam_unavailable_response(unavailable_message):
    """Return a ``(response, 501)`` tuple when SAM cannot be used, otherwise None."""
    try:
        from ultralytics_sam_tasks import get_sam_model
        if not get_sam_model():
            return jsonify({'success': False, 'message': unavailable_message}), 501
    except ImportError:
        return jsonify({'success': False, 'message': 'SAM features are not installed on server.'}), 501
    return None


def _launch_sam2_tracking(task_target, thread_prefix, missing_data_message):
    """Validate a SAM2 tracking request and start *task_target* on a new thread.

    Shared by the single and batch tracking routes, which differ only in the
    background task, the thread name prefix and the validation message.
    """
    unavailable = _sam_unavailable_response('SAM tracking feature is not available on the server.')
    if unavailable is not None:
        return unavailable

    data = request.json
    video_uuid = data.get('video_uuid')
    # Lenient conversion: missing/invalid frame numbers produce the 400 below.
    start_frame = _to_int_or_none(data.get('start_frame'))
    end_frame = _to_int_or_none(data.get('end_frame'))
    init_bboxes_text = data.get('init_bboxes_text')
    if not all([video_uuid, start_frame is not None, end_frame is not None, init_bboxes_text]):
        return jsonify({'success': False, 'message': missing_data_message}), 400
    if background_tasks.active_tasks.get(video_uuid):
        return jsonify(
            {'success': False,
             'message': 'Another task (extraction or tracking) is already running for this video.'})

    tracker_uuid = str(uuid.uuid4().hex)
    threading.Thread(
        target=task_target,
        args=(video_uuid, tracker_uuid, start_frame, end_frame, init_bboxes_text),
        name=f"{thread_prefix}-{video_uuid[:6]}"
    ).start()
    return jsonify({'success': True, 'tracker_uuid': tracker_uuid})


@app.route('/startSam2Tracking', methods=['POST'])
def start_sam2_tracking():
    """Start SAM2 tracking for a frame range and return the new ``tracker_uuid``."""
    return _launch_sam2_tracking(background_tasks.start_sam2_tracking_task, "SAM-Tracker",
                                 'Missing required data for tracking.')


@app.route('/startSam2BatchTracking', methods=['POST'])
def start_sam2_batch_tracking():
    """Start SAM2 batch tracking for a frame range and return the new ``tracker_uuid``."""
    return _launch_sam2_tracking(background_tasks.start_sam2_batch_tracking_task, "SAM-Batch-Tracker",
                                 'Missing required data for batch tracking.')


@app.route('/streamSam2Tracking/<tracker_uuid>')
def stream_sam2_tracking(tracker_uuid):
    """Stream tracking results for *tracker_uuid* as server-sent events."""

    def generate_events():
        # Wait for the background task to register its session, but give up
        # after a bounded time so an unknown id gets an error event instead of
        # blocking this request forever.
        deadline = time.monotonic() + _SSE_SESSION_WAIT_TIMEOUT_S
        while tracker_uuid not in background_tasks.tracking_sessions and time.monotonic() < deadline:
            time.sleep(0.1)
        session = background_tasks.tracking_sessions.get(tracker_uuid)
        if not session:
            error_event = {"event": "error", "message": "Tracking session not found or failed to start."}
            yield f"data: {json.dumps(error_event)}\n\n"
            return

        last_sent_frame = -1
        try:
            while True:
                # Status is read before the results so that frames stored before
                # a terminal status are flushed ahead of the final event.
                status = session.get('status', 'STARTING')
                sorted_frames = sorted([k for k in session.get('results', {}).keys() if k > last_sent_frame])
                for frame_num in sorted_frames:
                    result_data = {
                        "event": "update",
                        "frame_number": frame_num,
                        "bboxes_text": session['results'][frame_num],
                        "progress": session.get('progress', 0),
                        "total": session.get('total', 0)
                    }
                    yield f"data: {json.dumps(result_data)}\n\n"
                    last_sent_frame = frame_num
                if status in ['COMPLETED', 'STOPPED', 'FAILED']:
                    final_event = {"event": status.lower(), "message": session.get('message', '')}
                    yield f"data: {json.dumps(final_event)}\n\n"
                    break
                time.sleep(0.2)
        except GeneratorExit:
            logging.info(f"Client disconnected from SSE stream for tracker {tracker_uuid}")
        finally:
            logging.info(f"SSE stream for tracker {tracker_uuid} is closing.")

    return Response(generate_events(), mimetype='text/event-stream')


@app.route('/stopSam2Tracking', methods=['POST'])
def stop_sam2_tracking():
    """Flag a SAM tracking session so that it stops."""
    tracker_uuid = request.json.get('tracker_uuid')
    if tracker_uuid in background_tasks.tracking_sessions:
        session = background_tasks.tracking_sessions[tracker_uuid]
        session['stop_requested'] = True
        logging.info(f"Stop request received for SAM tracking session {tracker_uuid}")
        return jsonify({'success': True})
    return jsonify({'success': False, 'message': 'Tracker not found.'})


@app.route('/prepareToStartTracking', methods=['POST'])
def prepare_to_start_tracking():
    """Start a tracker task on a background thread and return its ``tracker_uuid``."""
    data = request.json
    video_uuid = data.get('video_uuid')
    if background_tasks.active_tasks.get(video_uuid):
        return jsonify(
            {'success': False,
             'message': 'Another task (extraction or tracking) is already running for this video.'})
    tracker_uuid = str(uuid.uuid4().hex)
    threading.Thread(target=background_tasks.start_tracking_task, args=(
        video_uuid, tracker_uuid, data.get('tracker_name'), float(data.get('scale')),
        int(data.get('init_frame_number')), data.get('init_bboxes_text'),
    ), name=f"Tracker-{video_uuid[:6]}").start()
    return jsonify({'success': True, 'tracker_uuid': tracker_uuid})


@app.route('/retrieveTrackedBboxes', methods=['POST'])
def retrieve_tracked_bboxes():
    """Return the current frame and boxes of a tracking session and refresh its client timestamp."""
    tracker_uuid = request.json.get('tracker_uuid')
    session = background_tasks.tracking_sessions.get(tracker_uuid)
    if session:
        session['last_client_update'] = time.time()
        return jsonify({
            'success': True,
            'tracker_failed': session['status'] in ['FAILED', 'TIMED OUT'],
            'frame_number': session.get('current_frame'),
            'bboxes_text': session.get('bboxes_text'),
        })
    return jsonify({'success': False, 'tracker_failed': True})


@app.route('/continueTracking', methods=['POST'])
def continue_tracking():
    """Accept the client's boxes for a frame, save them and advance the session by one frame."""
    data = request.json
    tracker_uuid = data.get('tracker_uuid')
    session = background_tasks.tracking_sessions.get(tracker_uuid)
    if session and session['status'] == 'RUNNING':
        session['last_client_update'] = time.time()
        session['bboxes_text'] = data.get('bboxes_text')
        frame_number = int(data.get('frame_number'))
        session['current_frame'] = frame_number + 1
        database.save_frame_bboxes(data.get('video_uuid'), frame_number, data.get('bboxes_text'))
        return jsonify({'success': True})
    return jsonify({'success': False})


@app.route('/stopTracking', methods=['POST'])
def stop_tracking():
    """Flag a tracking session so that it stops; succeeds even if the id is unknown."""
    tracker_uuid = request.json.get('tracker_uuid')
    if tracker_uuid in background_tasks.tracking_sessions:
        background_tasks.tracking_sessions[tracker_uuid]['stop_requested'] = True
    return jsonify({'success': True})


@app.route('/listDatasets', methods=['GET'])
def list_datasets():
    """Return all datasets."""
    datasets = database.get_dataset_list()
    return jsonify({'datasets': [sanitize_dict(d) for d in datasets]})


@app.route('/createDataset', methods=['POST'])
def create_dataset():
    """Create a dataset entry and start its creation task on a background thread."""
    data = request.json
    desc = data.get('description')
    video_uuids = data.get('video_uuids')
    eval_percent = float(data.get('eval_percent', 20.0))
    test_percent = float(data.get('test_percent', 10.0))
    augmentation_options = data.get('augmentation_options', {})
    is_valid, message = validate_description(desc, [d['description'] for d in database.get_dataset_list()])
    if not is_valid:
        return jsonify({'success': False, 'message': message}), 400
    if not video_uuids:
        return jsonify({'success': False, 'message': 'Please select at least one video.'}), 400
    create_time = int(time.time() * 1000)
    dataset_uuid = database.create_dataset_entry(desc, video_uuids, create_time, eval_percent, test_percent)
    threading.Thread(target=background_tasks.create_dataset_task, args=(
        dataset_uuid, video_uuids, eval_percent, test_percent, augmentation_options
    ), name=f"Dataset-{dataset_uuid[:6]}").start()
    return jsonify({'success': True, 'dataset_uuid': dataset_uuid})


@app.route('/regenerateDataset', methods=['POST'])
def regenerate_dataset():
    """Delete a dataset's files and rebuild it with augmentation disabled."""
    dataset_uuid = request.json.get('dataset_uuid')
    if not dataset_uuid:
        return jsonify({'success': False, 'message': 'Dataset UUID is required.'}), 400
    dataset = database.get_dataset_entity(dataset_uuid)
    if not dataset:
        return jsonify({'success': False, 'message': 'Dataset not found.'}), 404
    file_storage.delete_dataset_files(dataset_uuid)
    database.update_dataset_status(dataset_uuid, 'PENDING')
    # The video list is stored as a JSON string on the dataset entity.
    video_uuids = json.loads(dataset['video_uuids'])
    eval_percent = dataset.get('eval_percent')
    test_percent = dataset.get('test_percent')
    augmentation_options = {'enabled': False}
    threading.Thread(target=background_tasks.create_dataset_task, args=(
        dataset_uuid, video_uuids, eval_percent, test_percent, augmentation_options
    ), name=f"Dataset-Regen-{dataset_uuid[:6]}").start()
    return jsonify({'success': True, 'message': 'Dataset regeneration started.'})


@app.route('/downloadDataset/<dataset_uuid>')
def download_dataset(dataset_uuid):
    """Send the zip archive of a READY dataset as an attachment."""
    dataset = database.get_dataset_entity(dataset_uuid)
    if not dataset or dataset['status'] != 'READY' or not dataset['zip_path']:
        return "Dataset not found or not ready.", 404
    try:
        return send_file(dataset['zip_path'], as_attachment=True)
    except Exception as e:
        logging.error(f"Could not send file: {e}")
        return "Error downloading file.", 500


@app.route('/deleteDataset', methods=['POST'])
def delete_dataset():
    """Delete a dataset's database entry and its files."""
    dataset_uuid = request.json.get('dataset_uuid')
    database.delete_dataset(dataset_uuid)
    file_storage.delete_dataset_files(dataset_uuid)
    return jsonify({'success': True})


@app.route('/listModels', methods=['GET'])
def list_models():
    """Return all imported models."""
    models = database.get_model_list()
    return jsonify({'models': [sanitize_dict(m) for m in models]})


@app.route('/importModel', methods=['POST'])
def import_model():
    """Import a ``.tflite`` model together with its ``.txt``/``.labels`` label file."""
    desc = request.form.get('description')
    model_file = request.files.get('model_file')
    label_file = request.files.get('label_file')
    model_type = request.form.get('model_type')
    is_valid, message = validate_description(desc, [m['description'] for m in database.get_model_list()])
    if not is_valid:
        return jsonify({'success': False, 'message': message}), 400
    if not model_file or not model_file.filename.endswith('.tflite'):
        return jsonify({'success': False, 'message': 'Please provide a .tflite model file.'}), 400
    if not label_file or not label_file.filename.endswith(('.txt', '.labels')):
        return jsonify({'success': False, 'message': 'Please provide a .txt or .labels file.'}), 400
    if model_type not in ['float32', 'uint8']:
        return jsonify({'success': False, 'message': 'Invalid model type selected.'}), 400
    create_time = int(time.time() * 1000)
    model_uuid = database.import_model_metadata(desc, label_file.filename, model_type, create_time)
    file_storage.save_imported_model(model_file, model_uuid)
    file_storage.save_imported_label_file(label_file, model_uuid)
    return jsonify({'success': True, 'model_uuid': model_uuid})


@app.route('/deleteModel', methods=['POST'])
def delete_model():
    """Delete a model's database entry, model file and label file."""
    model_uuid = request.json.get('model_uuid')
    database.delete_model(model_uuid)
    file_storage.delete_model_file(model_uuid)
    file_storage.delete_label_file(model_uuid)
    return jsonify({'success': True})


@app.route('/startPreAnnotation', methods=['POST'])
def start_pre_annotation():
    """Start model-based pre-annotation of a READY video on a background thread."""
    data = request.json
    video_uuid = data.get('video_uuid')
    model_uuid = data.get('model_uuid')
    options = data.get('options', {})
    if not video_uuid or not model_uuid:
        return jsonify({'success': False, 'message': 'Video UUID and Model UUID are required.'}), 400
    video = database.get_video_entity(video_uuid)
    if not video:
        return jsonify({'success': False, 'message': 'Video not found.'}), 404
    if video['status'] != 'READY':
        return jsonify({'success': False,
                        'message': f"Video must be in READY state, but is {video['status']}."}), 400
    if background_tasks.active_tasks.get(video_uuid):
        return jsonify({'success': False, 'message': 'Another task is already running for this video.'}), 409
    try:
        # Fill in defaults and coerce option types before handing them to the task.
        options['start_frame'] = int(options.get('start_frame', 0))
        options['end_frame'] = int(options.get('end_frame', video['frame_count'] - 1))
        options['confidence'] = float(options.get('confidence', 0.5))
        options['merge_strategy'] = options.get('merge_strategy', 'overwrite')
    except (ValueError, TypeError) as e:
        return jsonify({'success': False, 'message': f'Invalid options provided: {e}'}), 400
    threading.Thread(
        target=background_tasks.pre_annotate_video_task,
        args=(video_uuid, model_uuid, options),
        name=f"PreAnnotator-{video_uuid[:6]}"
    ).start()
    return jsonify({'success': True, 'message': 'Pre-annotation task started.'})
@app.route('/cancelTask', methods=['POST'])
def cancel_task():
    """Request cancellation of a video's running task.

    Cancellation is only accepted while the video status is
    ``PRE_ANNOTATING`` or ``APPLYING_PROTOTYPES``; the status is then set to
    ``CANCELLING``. Returns 400 for a missing UUID or a non-cancellable
    status and 404 for an unknown video.
    """
    video_uuid = request.json.get('video_uuid')
    if not video_uuid:
        return jsonify({'success': False, 'message': 'Video UUID is required.'}), 400
    video = database.get_video_entity(video_uuid)
    if not video:
        return jsonify({'success': False, 'message': 'Video not found.'}), 404
    if video['status'] not in ['PRE_ANNOTATING', 'APPLYING_PROTOTYPES']:
        return jsonify(
            {'success': False, 'message': f'Cannot cancel task, video status is {video["status"]}.'}), 400

    database.update_video_status(video_uuid, 'CANCELLING', 'Cancellation requested by user.')
    return jsonify({'success': True, 'message': 'Cancellation request sent.'})


# The URL must declare <dataset_uuid>: the view takes it as a required
# argument, and Flask only passes arguments that appear in the rule.
@app.route('/datasetAnalysis/<dataset_uuid>')
def dataset_analysis(dataset_uuid):
    """Render the dataset analysis page, or return 404 for an unknown dataset."""
    dataset = database.get_dataset_entity(dataset_uuid)
    if not dataset:
        return "Dataset not found", 404
    return render_template('dataset_analysis.html', dataset=sanitize_dict(dataset),
                           limit_data=config.get_limit_data_for_render_template())


@app.route('/api/datasetAnalysis/<dataset_uuid>', methods=['GET'])
def get_dataset_analysis_data(dataset_uuid):
    """Compute label statistics and quality warnings for a dataset.

    Walks every labeled frame of the dataset's videos and gathers class
    counts, box aspect ratios, objects per image, normalized box centers,
    per-image brightness, per-annotator statistics, near-duplicate box pairs
    (IoU > 0.95) and a gallery listing.

    Returns:
        A JSON payload with those statistics plus ``summary_text`` and
        ``warnings``; HTTP 404 when the dataset does not exist.
    """
    dataset = database.get_dataset_entity(dataset_uuid)
    if not dataset:
        return jsonify({'success': False, 'message': 'Dataset not found.'}), 404

    video_uuids = json.loads(dataset.get('video_uuids', '[]'))
    tasks_by_video = {vu: database.get_tasks_for_video(vu) for vu in video_uuids}
    video_info_cache = {vu: database.get_video_entity(vu) for vu in video_uuids}

    def get_task_for_frame(video_uuid, frame_number):
        """Return the UUID of the first task whose range contains the frame."""
        for task in tasks_by_video.get(video_uuid, []):
            if task['start_frame'] <= frame_number <= task['end_frame']:
                return task['task_uuid']
        return None

    # Only frames with non-blank label text take part in the analysis.
    # `or ''` / `or {}` tolerate a None label text and a missing video entity.
    all_frames = [
        {**dict(frame), 'video_uuid': vu,
         'video_description': (video_info_cache[vu] or {}).get('description')}
        for vu in video_uuids
        for frame in database.get_video_frames(vu)
        if (frame.get('bboxes_text') or '').strip()
    ]

    class_counts = Counter()
    aspect_ratios, objects_per_image, center_points, brightness_levels = [], [], [], []
    all_bboxes_for_outliers = []
    suspicious_pairs = []
    image_class_map = {}

    for i, frame in enumerate(all_frames):
        video_uuid = frame['video_uuid']
        frame_number = frame['frame_number']
        rects, labels, _ = convert_text_to_rects_and_labels(frame['bboxes_text'])
        image_class_map[i] = list(set(labels))
        objects_per_image.append(len(labels))

        # Boxes that overlap almost completely are likely duplicate labels.
        for (idx1, rect1), (idx2, rect2) in itertools.combinations(enumerate(rects), 2):
            iou = calculate_iou(rect1, rect2)
            if iou > 0.95:
                suspicious_pairs.append({
                    'image_index': i, 'iou': iou,
                    'box1_label': labels[idx1], 'box2_label': labels[idx2]
                })

        # Brightness is best-effort: a frame that cannot be read is skipped.
        try:
            image_gray = skio.imread(file_storage.get_frame_path(video_uuid, frame_number), as_gray=True)
            brightness_levels.append(np.mean(image_gray) * 255)
        except Exception as e:
            logging.debug(f"Skipping brightness for {video_uuid} frame {frame_number}: {e}")

        video_info = video_info_cache[video_uuid]
        for j, rect in enumerate(rects):
            class_counts[labels[j]] += 1
            width, height = int(rect[2] - rect[0]), int(rect[3] - rect[1])
            if width <= 0 or height <= 0:
                # Degenerate box: counted per class, but kept out of the
                # geometry statistics so width / height cannot divide by zero.
                continue
            aspect_ratio = width / height
            aspect_ratios.append(aspect_ratio)
            if video_info and video_info['width'] > 0 and video_info['height'] > 0:
                center_x = (float(rect[0]) + float(rect[2])) / 2.0 / float(video_info['width'])
                center_y = (float(rect[1]) + float(rect[3])) / 2.0 / float(video_info['height'])
                center_points.append({'x': center_x, 'y': center_y})
            all_bboxes_for_outliers.append(
                {'id': f'{video_uuid}_{frame_number}_{j}', 'image_index': i, 'area': width * height,
                 'aspect_ratio': aspect_ratio})

    # Per-annotator statistics: each frame is attributed to the assignee of
    # every task whose frame range contains it.
    all_tasks = [task for vid_tasks in tasks_by_video.values() for task in vid_tasks]
    annotator_stats = {}
    for task in all_tasks:
        annotator_stats.setdefault(task['assigned_to'], {'image_count': 0, 'class_counts': Counter()})
    user_frame_sets = {user: set() for user in annotator_stats}
    for frame in all_frames:
        frame_key = f"{frame['video_uuid']}_{frame['frame_number']}"
        for task in all_tasks:
            if (task['video_uuid'] != frame['video_uuid']
                    or not task['start_frame'] <= frame['frame_number'] <= task['end_frame']):
                continue
            user = task['assigned_to']
            if frame_key in user_frame_sets[user]:
                # Overlapping tasks of the same annotator: count the frame's
                # labels once, matching the de-duplicated image count.
                continue
            user_frame_sets[user].add(frame_key)
            annotator_stats[user]['class_counts'].update(extract_labels(frame['bboxes_text']))
    for user, frame_set in user_frame_sets.items():
        annotator_stats[user]['image_count'] = len(frame_set)

    gallery_images = [{
        'original_url': f"/media/frames/{f['video_uuid']}/frame_{f['frame_number']:05d}.jpg",
        'video': f['video_description'],
        'frame': f['frame_number'],
        'video_uuid': f['video_uuid'],
        'task_uuid': get_task_for_frame(f['video_uuid'], f['frame_number'])
    } for f in all_frames]

    warnings = []
    total_instances = sum(class_counts.values())
    if class_counts:
        avg_instances = total_instances / len(class_counts)
        for class_name, count in class_counts.items():
            # Rare in absolute terms, or below 10% of the per-class average.
            if count < 10 or count < avg_instances * 0.1:
                warnings.append(
                    f"Class Imbalance: Class '{class_name}' has very few instances ({count}), which may affect model performance.")

    small_object_threshold = 100
    small_object_count = sum(1 for bbox in all_bboxes_for_outliers if bbox['area'] < small_object_threshold)
    if small_object_count > 0:
        warnings.append(
            f"Small Object Warning: Found {small_object_count} objects with an area smaller than {small_object_threshold} pixels. Please check if they are labeling errors.")
    if suspicious_pairs:
        warnings.append(
            f"Potential Duplicate Warning: Found {len(suspicious_pairs)} pairs of bounding boxes with high overlap (IoU > 0.95), which might be duplicate labels.")

    summary_text = f"This dataset contains {len(class_counts)} classes, with a total of {total_instances} instances across {len(all_frames)} labeled images."

    return jsonify({
        'success': True,
        'summary_text': summary_text,
        'warnings': warnings,
        'class_counts': dict(class_counts),
        'aspect_ratios': aspect_ratios,
        'objects_per_image': objects_per_image,
        'center_points': center_points,
        'brightness_levels': brightness_levels,
        'annotator_stats': {u: {'image_count': d['image_count'], 'class_counts': dict(d['class_counts'])}
                            for u, d in annotator_stats.items()},
        'all_bboxes': all_bboxes_for_outliers,
        'suspicious_pairs': suspicious_pairs,
        'image_class_map': image_class_map,
        'gallery_images': gallery_images
    })


def calculate_color_histogram(image_path, rect):
    """Return a flattened, min-max normalized 16x16 hue/saturation histogram.

    Args:
        image_path: Path of the image to read with OpenCV.
        rect: ``(x1, y1, x2, y2)`` box; values are converted with ``int``.

    Returns:
        The flattened histogram of the box region, or ``None`` when the image
        cannot be read, the region is empty, or any step raises (the error is
        logged as a warning).
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            return None
        x1, y1, x2, y2 = map(int, rect)
        roi = image[y1:y2, x1:x2]
        if roi.size == 0:
            return None
        hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv_roi], [0, 1], None, [16, 16], [0, 180, 0, 256])
        cv2.normalize(hist, hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
        return hist.flatten()
    except Exception as e:
        logging.warning(f"计算颜色直方图失败: {e} for path {image_path}")
        return None


@app.route('/api/datasetAnalysis/<dataset_uuid>/consistency_check', methods=['POST'])
def run_consistency_check(dataset_uuid):
    """Flag images whose labeled boxes look more like another class.

    For every labeled box an embedding is requested from ``ai_models`` and,
    when the optional JSON flag ``enable_color_check`` (default ``True``) is
    set, a color histogram is computed. Per-class mean prototypes are built,
    and a box is an outlier when

    * its cosine similarity to another class prototype exceeds the similarity
      to its own prototype by more than 0.2, or
    * (color check only) its L1 histogram distance to its own class prototype
      is more than ``COLOR_CONFUSION_FACTOR`` times the distance to the
      closest other class.

    Only classes with at least 5 usable boxes are examined.

    Returns:
        JSON with ``message`` and ``outlier_image_indices`` (indices into the
        dataset's labeled-frame list); 404 for an unknown dataset; 500 on any
        unexpected error.
    """
    try:
        # silent=True: a missing or non-JSON body means "use the defaults".
        request_data = request.get_json(silent=True) or {}
        is_color_check_enabled = request_data.get('enable_color_check', True)
        if is_color_check_enabled:
            logging.info("Starting AI Quality Control with SEMANTIC and COLOR checks.")
        else:
            logging.info("Starting AI Quality Control with SEMANTIC check ONLY.")

        dataset = database.get_dataset_entity(dataset_uuid)
        if not dataset:
            return jsonify({'success': False, 'message': 'Dataset not found.'}), 404

        video_uuids = json.loads(dataset.get('video_uuids', '[]'))
        all_frames = [
            dict(frame)
            for vu in video_uuids
            for frame in database.get_video_frames(vu)
            if (frame.get('bboxes_text') or '').strip()
        ]

        # One record per box, plus a per-frame index so embeddings can be
        # requested once per frame.
        all_bboxes_info = []
        frames_to_process = defaultdict(list)
        for i, frame in enumerate(all_frames):
            rects, labels, _ = convert_text_to_rects_and_labels(frame['bboxes_text'])
            for j, rect in enumerate(rects):
                frames_to_process[(frame['video_uuid'], frame['frame_number'])].append(
                    (len(all_bboxes_info), rect))
                all_bboxes_info.append({
                    'image_index': i, 'rect': rect, 'label': labels[j],
                    'video_uuid': frame['video_uuid'], 'frame_number': frame['frame_number'],
                    'embedding': None, 'color_hist': None
                })

        for (video_uuid, frame_number), rect_data in frames_to_process.items():
            image_path = file_storage.get_frame_path(video_uuid, frame_number)
            rects_in_frame = [r for _, r in rect_data]
            embeddings = ai_models.get_features_for_specific_bboxes(video_uuid, frame_number, rects_in_frame)
            for i, (global_idx, rect) in enumerate(rect_data):
                if embeddings is not None and i < len(embeddings):
                    all_bboxes_info[global_idx]['embedding'] = embeddings[i]
                # Histograms are only consulted when the color check is on,
                # so skip the per-box image read otherwise.
                if is_color_check_enabled:
                    color_hist = calculate_color_histogram(image_path, rect)
                    if color_hist is not None:
                        all_bboxes_info[global_idx]['color_hist'] = color_hist

        # Keep boxes that have every feature the enabled checks need.
        all_features_by_class = defaultdict(list)
        for info in all_bboxes_info:
            if info['embedding'] is not None and (not is_color_check_enabled or info['color_hist'] is not None):
                all_features_by_class[info['label']].append(info)

        # Class prototype = mean embedding (and mean histogram).
        prototype_library = {}
        for class_name, infos in all_features_by_class.items():
            if not infos:
                continue
            prototypes = {
                'semantic': torch.mean(torch.stack([info['embedding'] for info in infos]), dim=0)
            }
            if is_color_check_enabled:
                color_hists = [info['color_hist'] for info in infos if info['color_hist'] is not None]
                if color_hists:
                    prototypes['color'] = np.mean(np.array(color_hists), axis=0)
            prototype_library[class_name] = prototypes

        outlier_image_indices = set()
        COLOR_CONFUSION_FACTOR = 2.0
        for class_name, infos in all_features_by_class.items():
            if len(infos) < 5 or class_name not in prototype_library:
                continue
            own_prototypes = prototype_library[class_name]
            for info in infos:
                candidate_embedding = info['embedding']
                own_semantic_sim = F.cosine_similarity(candidate_embedding, own_prototypes['semantic'],
                                                       dim=0).item()

                is_outlier = False
                for other_class, other_prototypes in prototype_library.items():
                    if other_class == class_name:
                        continue
                    other_semantic_sim = F.cosine_similarity(candidate_embedding, other_prototypes['semantic'],
                                                             dim=0).item()
                    if other_semantic_sim > own_semantic_sim + 0.2:
                        logging.info(
                            f"SEMANTIC outlier: A '{class_name}' (sim: {own_semantic_sim:.2f}) looks more like a '{other_class}' (sim: {other_semantic_sim:.2f}). Image index: {info['image_index']}")
                        is_outlier = True
                        break
                if is_outlier:
                    outlier_image_indices.add(info['image_index'])
                    continue

                if not is_color_check_enabled:
                    continue
                if 'color' not in own_prototypes or info['color_hist'] is None:
                    continue
                if len(prototype_library) < 2:
                    continue

                candidate_color_hist = info['color_hist']
                dist_to_own_color = np.sum(np.abs(candidate_color_hist - own_prototypes['color']))
                min_dist_to_other_color = float('inf')
                closest_other_class = None
                for other_class, other_prototypes in prototype_library.items():
                    if other_class == class_name or 'color' not in other_prototypes:
                        continue
                    dist = np.sum(np.abs(candidate_color_hist - other_prototypes['color']))
                    if dist < min_dist_to_other_color:
                        min_dist_to_other_color = dist
                        closest_other_class = other_class
                if closest_other_class and min_dist_to_other_color * COLOR_CONFUSION_FACTOR < dist_to_own_color:
                    logging.info(
                        f"COLOR outlier: A '{class_name}' (color dist: {dist_to_own_color:.2f}) has a color profile much closer to '{closest_other_class}' (color dist: {min_dist_to_other_color:.2f}). Image index: {info['image_index']}")
                    outlier_image_indices.add(info['image_index'])

        message_keyword = "**Category or color**" if is_color_check_enabled else "**category**"
        outlier_count = len(outlier_image_indices)
        if outlier_count == 0:
            message = "AI review completed. No obvious labeling confusion issues were found."
        else:
            noun = "image" if outlier_count == 1 else "images"
            message = f"AI review complete. Found {outlier_count} {noun} with potential instances of {message_keyword} confusion."

        return jsonify({
            'success': True,
            'message': message,
            'outlier_image_indices': sorted(outlier_image_indices)
        })
    except Exception as e:
        logging.error(f"On-demand consistency check failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'审查失败: {e}'}), 500


@app.route('/lam_predict', methods=['POST'])
def lam_predict_route():
    """Run ``ai_models.lam_predict`` for a clicked point on a frame.

    JSON body: ``video_uuid``, ``frame_number`` and ``point`` (``x`` / ``y``).
    Returns 400 when a parameter is missing, ``success: False`` with the
    model's error message when prediction reports one, and 500 on exceptions.
    """
    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = data.get('frame_number')
    point = data.get('point')
    # frame_number may legitimately be 0, so test it against None.
    if not all([video_uuid, frame_number is not None, point]):
        return jsonify({'success': False, 'message': 'Missing required request parameters.'}), 400
    try:
        point_coords = (int(point['x']), int(point['y']))
        result, error_msg = ai_models.lam_predict(video_uuid, int(frame_number), point_coords)
        if error_msg:
            return jsonify({'success': False, 'message': error_msg})
        return jsonify({'success': True, **result})
    except Exception as e:
        logging.error(f"LAM 预测失败: {e}", exc_info=True)
        return jsonify({'success': False, 'message': f'Internal Server Error: {str(e)}'}), 500


@app.route('/api/previewAugmentations', methods=['POST'])
def preview_augmentations():
    """Return base64 JPEG previews of the requested augmentations for a frame.

    JSON body: ``video_uuid``, ``frame_number``, ``augmentation_options`` and,
    for mosaic, ``sample_pool``. When mosaic is enabled it is chosen with
    probability ``p`` (default 1.0) and produces the previews on its own;
    otherwise six samples of the regular pipeline are drawn with mosaic
    disabled, and the transformed boxes are painted onto each sample.

    Returns 501 when the augmentation library is unavailable, 400 for missing
    data or an empty pipeline, 404 for a missing frame, video or labels, and
    500 when labels cannot be converted or anything raises.
    """
    if not background_tasks.A:
        return jsonify({'success': False, 'message': 'Albumentations library not installed on server.'}), 501

    data = request.json
    video_uuid = data.get('video_uuid')
    frame_number = data.get('frame_number')
    augmentation_options = data.get('augmentation_options')
    sample_pool = data.get('sample_pool')
    if not all([video_uuid, frame_number is not None, augmentation_options]):
        return jsonify({'success': False, 'message': 'Missing required data.'}), 400

    try:
        # `or {}` tolerates an explicit null for the mosaic options.
        mosaic_options = augmentation_options.get('mosaic') or {}
        if mosaic_options.get('enabled'):
            if not sample_pool:
                return jsonify({'success': False, 'message': 'Sample pool is required for Mosaic preview.'}), 400
            if random.random() < mosaic_options.get('p', 1.0):
                previews = generate_mosaic_previews(sample_pool, video_uuid, frame_number)
                return jsonify({'success': True, 'previews': previews})

        frame_path = file_storage.get_frame_path(video_uuid, frame_number)
        if not os.path.exists(frame_path):
            return jsonify({'success': False, 'message': 'Frame image not found.'}), 404
        image = cv2.imread(frame_path)
        if image is None:
            # The file exists but could not be decoded.
            return jsonify({'success': False, 'message': 'Frame image not found.'}), 404
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        video_info = database.get_video_entity(video_uuid)
        if not video_info:
            return jsonify({'success': False, 'message': 'Video not found.'}), 404

        conn = database.get_db_connection()
        try:
            frame_db_info = conn.execute(
                'SELECT bboxes_text FROM video_frames WHERE video_uuid = ? AND frame_number = ?',
                (video_uuid, frame_number)).fetchone()
        finally:
            # Close the connection even when the query raises.
            conn.close()
        if not frame_db_info or not frame_db_info['bboxes_text']:
            return jsonify({'success': False, 'message': 'No labels found for this frame.'}), 404

        # Mosaic was handled above; the regular pipeline must not include it.
        augmentation_options['mosaic'] = {'enabled': False}
        pipeline = background_tasks.build_augmentation_pipeline(augmentation_options)
        if not pipeline:
            return jsonify({'success': False, 'message': 'No valid augmentations selected.'}), 400

        all_labels = database.get_all_class_labels()
        class_map = {name: i for i, name in enumerate(all_labels)}
        yolo_bboxes, class_indices = file_storage.get_yolo_bboxes(
            frame_db_info['bboxes_text'], video_info['width'], video_info['height'], class_map)
        if not yolo_bboxes:
            return jsonify({'success': False, 'message': 'Could not parse labels into YOLO format.'}), 500

        previews = []
        for _ in range(6):
            transformed = pipeline(image=image_rgb, bboxes=yolo_bboxes, class_labels=class_indices)
            aug_image_rgb = transformed['image']
            aug_bboxes_yolo = transformed['bboxes']
            aug_labels_indices = transformed['class_labels']
            h, w, _ = aug_image_rgb.shape

            vis_image = aug_image_rgb.copy()
            for i, bbox in enumerate(aug_bboxes_yolo):
                class_name = all_labels[int(aug_labels_indices[i])]
                color = string_to_color_bgr(class_name)
                # Normalized center/size -> pixel corner coordinates.
                x_center, y_center, width_norm, height_norm = bbox
                x1 = int((x_center - width_norm / 2) * w)
                y1 = int((y_center - height_norm / 2) * h)
                x2 = int((x_center + width_norm / 2) * w)
                y2 = int((y_center + height_norm / 2) * h)
                cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, 2)

            vis_image_bgr = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
            _, buffer = cv2.imencode('.jpg', vis_image_bgr)
            img_base64 = base64.b64encode(buffer).decode('utf-8')
            previews.append(f"data:image/jpeg;base64,{img_base64}")

        return jsonify({'success': True, 'previews': previews})
    except Exception as e:
        logging.error(f"Augmentation preview failed: {e}", exc_info=True)
        return jsonify({'success': False, 'message': str(e)}), 500


if __name__ == '__main__':
    from waitress import serve

    init(autoreset=True)

    # Replace any existing root handlers with a single colored console handler.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    if root_logger.hasHandlers():
        root_logger.handlers.clear()
    console_handler = logging.StreamHandler()
    formatter = ColoredFormatter('%(asctime)s - %(levelname)s - %(threadName)s - %(message)s')
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    logging.info("正在初始化AI模型,请稍候...")
    ai_models.startup_ai_models()
    # Persist the model caches when the interpreter exits.
    atexit.register(ai_models.save_preprocessed_cache_to_disk)
    atexit.register(ai_models.save_prototypes_to_disk)
    time.sleep(0.01)

    logging.info("=" * 121)
    logging.info("███████╗ ███████╗ ██████╗ ██████╗ ██████╗ ██╗ ██╗ ██████╗ ██╗ ██████╗ ██╗ ██╗ █████╗ ██████╗ ██████╗")
    logging.info("╚══███╔╝ ██╔════╝ ██╔══██╗ ██╔═══██╗ ╚════██╗ ╚██╗ ██╔╝ ██╔═══██╗ ██║ ██╔═══██╗ ╚██╗ ██╔╝ ██╔══██╗ ██╔══██╗ ██╔══██╗")
    logging.info("███╔╝ █████╗ ██████╔╝ ██║ ██║ █████╔╝ ╚████╔╝ ██║ ██║ ██║ ██║ ██║ ╚████╔╝ ███████║ ██████╔╝ ██║ ██║")
    logging.info("███╔╝ ██╔══╝ ██╔══██╗ ██║ ██║ ██╔═══╝ ╚██╔╝ ██║ ██║ ██║ ██║ ██║ ╚██╔╝ ██╔══██║ ██╔══██╗ ██║ ██║")
    logging.info("███████╗ ███████╗ ██║ ██║ ╚██████╔╝ ███████╗ ██║ ╚██████╔╝ ███████╗ ╚██████╔╝ ██║ ██║ ██║ ██║ ██║ ██████╔╝")
    logging.info("╚══════╝ ╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝ ╚═╝ ╚═════╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ")
    logging.info("Developed by BlueDarkUP from FIRST Tech Challenge team 27570 Be based on -- FIRST Machine Learning Toolchain --")
    logging.info("Open your web browser and go to http://127.0.0.1:5000")
    logging.info("=" * 121)

    serve(app, host='0.0.0.0', port=5000)