From 36c122d9a76e4a76079c4c9230890635c388963c Mon Sep 17 00:00:00 2001 From: SaiD Date: Sat, 13 Dec 2025 11:57:42 +0530 Subject: [PATCH] signed weighted mask --- ...kotlin-compiler-1475630269848992605.salive | 0 .../data/camera/PipelineImplementations.kt | 211 +++++++++++------- .../livingai/utils/SilhouetteManager.kt | 78 +++---- 3 files changed, 176 insertions(+), 113 deletions(-) create mode 100644 .kotlin/sessions/kotlin-compiler-1475630269848992605.salive diff --git a/.kotlin/sessions/kotlin-compiler-1475630269848992605.salive b/.kotlin/sessions/kotlin-compiler-1475630269848992605.salive new file mode 100644 index 0000000..e69de29 diff --git a/app/src/main/java/com/example/livingai/data/camera/PipelineImplementations.kt b/app/src/main/java/com/example/livingai/data/camera/PipelineImplementations.kt index 6e5b09a..f45cfa2 100644 --- a/app/src/main/java/com/example/livingai/data/camera/PipelineImplementations.kt +++ b/app/src/main/java/com/example/livingai/data/camera/PipelineImplementations.kt @@ -2,6 +2,7 @@ package com.example.livingai.data.camera import android.content.Context import android.graphics.Bitmap +import android.graphics.Color import android.graphics.RectF import android.util.Log import com.example.livingai.R @@ -58,7 +59,7 @@ class DefaultTiltChecker : TiltChecker { override suspend fun analyze(input: PipelineInput): Instruction { Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}") - val tolerance = 10.0f + val tolerance = 25.0f val isLevel: Boolean if (input.requiredOrientation == CameraOrientation.PORTRAIT) { @@ -87,6 +88,7 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector { private var labels: List = emptyList() private var modelInputWidth: Int = 0 private var modelInputHeight: Int = 0 + private var maxDetections: Int = 25 init { try { @@ -102,6 +104,14 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector { } else { Log.e("TFLiteObjectDetector", "Invalid input tensor shape.") } + + val outputTensor = interpreter?.getOutputTensor(0) + val outputShape = outputTensor?.shape() + if (outputShape != null && outputShape.size >= 2) { + maxDetections = outputShape[1] + Log.d("TFLiteObjectDetector", "Max detections from model: $maxDetections") + } + Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.") } catch (e: IOException) { Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e) @@ -121,7 +131,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector { val byteBuffer = convertBitmapToByteBuffer(resizedBitmap) // Define model outputs with the correct size - val maxDetections = 25 val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } } val outputClasses = Array(1) { FloatArray(maxDetections) } val outputScores = Array(1) { FloatArray(maxDetections) } @@ -223,120 +232,172 @@ class MockPoseAnalyzer : PoseAnalyzer { } override suspend fun analyze(input: PipelineInput): Instruction { - val detectionResult = input.previousDetectionResult ?: return Instruction("No detection result", isValid = false) + + val detectionResult = + input.previousDetectionResult + ?: return Instruction("No detection result", isValid = false) if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) { return Instruction("Animal not detected", isValid = false) } - val image = input.image ?: return Instruction("No image", isValid = false) - val animalMask = getAnimalSegmentation(image) // BooleanArray + val image = + input.image + ?: return Instruction("No image", isValid = false) - // --- Get weighted silhouette mask --- - var weightedSilhouette = SilhouetteManager.getWeightedMask(input.orientation) - ?: return Instruction("Silhouette not found for ${input.orientation}", isValid = false) + // -------------------------------------------------------------------- + // 1. Get reference signed-weight silhouette (Array) + // -------------------------------------------------------------------- + val weightedSilhouette: Array = + SilhouetteManager.getWeightedMask(input.orientation) + ?: return Instruction( + "Silhouette not found for ${input.orientation}", + isValid = false + ) - // Ensure silhouette mask matches camera frame size - if (weightedSilhouette.width != image.width || weightedSilhouette.height != image.height) { - weightedSilhouette = - Bitmap.createScaledBitmap(weightedSilhouette, image.width, image.height, true) + val refHeight = weightedSilhouette.size + val refWidth = weightedSilhouette[0].size + + // -------------------------------------------------------------------- + // 2. Scale input image to match reference resolution + // -------------------------------------------------------------------- + val scaledImage = Bitmap.createScaledBitmap( + image, + refWidth, + refHeight, + true + ) + + // -------------------------------------------------------------------- + // 3. Get animal segmentation → BINARY mask (0 / 1) + // -------------------------------------------------------------------- + val animalMask: Array = + getAnimalSegmentationMask(scaledImage) + ?: return Instruction("Segmentation failed", isValid = false) + + // -------------------------------------------------------------------- + // 4. Compute signed-weight alignment score + // -------------------------------------------------------------------- + val similarity = calculateSignedSimilarity( + animalMask, + weightedSilhouette + ) + + Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity") + + val isValid = similarity >= 0.40f + val msg = if (isValid) { + "Pose Correct" + } else { + "Pose Incorrect (Score: %.2f)".format(similarity) } - val weightMap = convertWeightedBitmapToFloatArray(weightedSilhouette) - - // --- Compute weighted Jaccard --- - val jaccard = calculateWeightedJaccard(animalMask, weightMap) - - Log.d("MockPoseAnalyzer", "Weighted Jaccard Similarity = $jaccard") - - val isValid = jaccard >= 0.40f - val msg = if (isValid) "Pose Correct" else "Pose Incorrect (Jaccard: %.2f)".format(jaccard) - - return Instruction(message = msg, isValid = isValid, result = detectionResult) + return Instruction( + message = msg, + isValid = isValid, + result = detectionResult + ) } + + /** + * Converts segmentation output to a binary mask: + * 1 = animal present + * 0 = background + */ + private suspend fun getAnimalSegmentationMask( + bitmap: Bitmap + ): Array? { + + val maskBitmap = getAnimalSegmentation(bitmap) ?: return null + + val w = maskBitmap.width + val h = maskBitmap.height + + val pixels = IntArray(w * h) + maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h) + + val result = Array(h) { IntArray(w) } + + for (y in 0 until h) { + for (x in 0 until w) { + val alpha = pixels[y * w + x] ushr 24 + result[y][x] = if (alpha > 0) 1 else 0 + } + } + + return result + } + + // ---------------------------------------------------------------------- // REAL segmentation using ML Kit // ---------------------------------------------------------------------- - private suspend fun getAnimalSegmentation(bitmap: Bitmap): BooleanArray = suspendCancellableCoroutine { continuation -> + private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap = suspendCancellableCoroutine { continuation -> val inputImage = InputImage.fromBitmap(bitmap, 0) segmenter.process(inputImage) .addOnSuccessListener { result -> val mask = result.foregroundConfidenceMask if (mask != null) { - val floatArray = FloatArray(mask.capacity()) - mask.get(floatArray) - val booleanArray = BooleanArray(floatArray.size) { i -> - floatArray[i] > 0.5f + val maskBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888) + val buffer = mask + buffer.rewind() + for (y in 0 until bitmap.height) { + for (x in 0 until bitmap.width) { + val confidence = buffer.get() + val alpha = (confidence * 255).toInt() + maskBitmap.setPixel(x, y, Color.argb(alpha, 255, 255, 255)) + } } - continuation.resume(booleanArray) + continuation.resume(maskBitmap) } else { Log.e("MockPoseAnalyzer", "Segmentation result null") - continuation.resume(BooleanArray(bitmap.width * bitmap.height) { false }) + continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)) } } .addOnFailureListener { e -> Log.e("MockPoseAnalyzer", "Segmentation failed", e) - continuation.resume(BooleanArray(bitmap.width * bitmap.height) { false }) + continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)) } } - // ---------------------------------------------------------------------- - // Convert weighted mask bitmap → float[] values in range -1..1 - // ---------------------------------------------------------------------- - private fun convertWeightedBitmapToFloatArray(bitmap: Bitmap): FloatArray { - val w = bitmap.width - val h = bitmap.height - - val pixels = IntArray(w * h) - bitmap.getPixels(pixels, 0, w, 0, 0, w, h) - - val out = FloatArray(w * h) - - for (i in pixels.indices) { - val color = pixels[i] and 0xFF - val norm = (color / 255f) * 2f - 1f // Converts 0..255 → -1..1 - out[i] = norm - } - return out - } - // ---------------------------------------------------------------------- // Weighted Jaccard Similarity - // mask = predicted (BooleanArray) - // weightMap = ground truth silhouette weights (-1..1) + // mask = predicted (Bitmap) + // weightMap = ground truth silhouette weights (Bitmap) // ---------------------------------------------------------------------- - private fun calculateWeightedJaccard(predMask: BooleanArray, weight: FloatArray): Float { + /** + * Signed alignment similarity + * Range: [-1, 1] + */ + private fun calculateSignedSimilarity( + mask: Array, // 0 / 1 + reference: Array // [-1, +1] + ): Float { - if (predMask.size != weight.size) return 0f + var score = 0f + var maxScore = 0f - var weightedIntersection = 0f - var weightedUnion = 0f + val h = reference.size + val w = reference[0].size - for (i in predMask.indices) { + for (y in 0 until h) { + for (x in 0 until w) { + val r = reference[y][x] + val m = mask[y][x].toFloat() - val w = weight[i] // -1.0 .. 1.0 - - val pred = predMask[i] // true/false - val silhouetteInside = w > 0f - - val intersection = pred && silhouetteInside - val union = pred || silhouetteInside - - if (intersection) weightedIntersection += w.coerceAtLeast(0f) - if (union) { - // Penalize far outside with negative weight also - weightedUnion += if (silhouetteInside) { - w.coerceAtLeast(0f) - } else { - (-w) // penalty + // accumulate positive reference mass only + if (r > 0f) { + maxScore += r } + + score += m * r } } - if (weightedUnion == 0f) return 0f - return weightedIntersection / weightedUnion + return if (maxScore == 0f) 0f else score / maxScore } + } class DefaultCaptureHandler : CaptureHandler { diff --git a/app/src/main/java/com/example/livingai/utils/SilhouetteManager.kt b/app/src/main/java/com/example/livingai/utils/SilhouetteManager.kt index fe8e3df..9844729 100644 --- a/app/src/main/java/com/example/livingai/utils/SilhouetteManager.kt +++ b/app/src/main/java/com/example/livingai/utils/SilhouetteManager.kt @@ -13,11 +13,11 @@ object SilhouetteManager { private val originals = ConcurrentHashMap() private val invertedPurple = ConcurrentHashMap() - private val weightedMasks = ConcurrentHashMap() + private val weightedMasks = ConcurrentHashMap>() fun getOriginal(name: String): Bitmap? = originals[name] fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name] - fun getWeightedMask(name: String): Bitmap? = weightedMasks[name] + fun getWeightedMask(name: String): Array? = weightedMasks[name] fun initialize(context: Context, width: Int, height: Int) { val resources = context.resources @@ -77,15 +77,18 @@ object SilhouetteManager { return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true) } - // ------------------------------------------------------------------------ - // STEP 2: Create signed weighted mask (-1 to 1) - // ------------------------------------------------------------------------ - - private fun createSignedWeightedMask( + /** + * Creates a signed weighted mask in range [-1, +1] + * + * +1 : deep inside object + * 0 : object boundary + * -1 : far outside object + */ + fun createSignedWeightedMask( bitmap: Bitmap, fadeInside: Int = 10, fadeOutside: Int = 20 - ): Bitmap { + ): Array { val w = bitmap.width val h = bitmap.height @@ -95,24 +98,25 @@ object SilhouetteManager { fun idx(x: Int, y: Int) = y * w + x - // inside = 1 → silhouette purple - // inside = 0 → outside + // -------------------------------------------------------------------- + // 1. Binary mask: inside = 1, outside = 0 + // Assumption: NON-transparent pixels are inside the object + // -------------------------------------------------------------------- val inside = IntArray(w * h) for (i in pixels.indices) { val alpha = pixels[i] ushr 24 - inside[i] = if (alpha == 0) 1 else 0 + inside[i] = if (alpha > 0) 1 else 0 } // -------------------------------------------------------------------- - // DISTANCES FOR INSIDE PIXELS (to nearest OUTSIDE pixel) + // 2. Distance to nearest OUTSIDE pixel (for inside pixels) // -------------------------------------------------------------------- val distInside = IntArray(w * h) { Int.MAX_VALUE } - for (i in inside.indices) { if (inside[i] == 0) distInside[i] = 0 } - // forward + // forward pass for (y in 0 until h) { for (x in 0 until w) { val i = idx(x, y) @@ -123,7 +127,7 @@ object SilhouetteManager { } } - // backward + // backward pass for (y in h - 1 downTo 0) { for (x in w - 1 downTo 0) { val i = idx(x, y) @@ -135,15 +139,14 @@ object SilhouetteManager { } // -------------------------------------------------------------------- - // DISTANCES FOR OUTSIDE PIXELS (to nearest INSIDE pixel) + // 3. Distance to nearest INSIDE pixel (for outside pixels) // -------------------------------------------------------------------- val distOutside = IntArray(w * h) { Int.MAX_VALUE } - for (i in inside.indices) { if (inside[i] == 1) distOutside[i] = 0 } - // forward + // forward pass for (y in 0 until h) { for (x in 0 until w) { val i = idx(x, y) @@ -154,7 +157,7 @@ object SilhouetteManager { } } - // backward + // backward pass for (y in h - 1 downTo 0) { for (x in w - 1 downTo 0) { val i = idx(x, y) @@ -166,31 +169,30 @@ object SilhouetteManager { } // -------------------------------------------------------------------- - // BUILD FINAL SIGNED MASK (-1 to +1) + // 4. Build signed weight map [-1, +1] // -------------------------------------------------------------------- + val result = Array(h) { FloatArray(w) } - val out = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888) - val outPixels = IntArray(w * h) + for (y in 0 until h) { + for (x in 0 until w) { + val i = idx(x, y) - for (i in outPixels.indices) { + val weight = if (inside[i] == 1) { + // Inside: +1 → 0 + val d = distInside[i] + if (d >= fadeInside) 1f + else d.toFloat() / fadeInside + } else { + // Outside: 0 → -1 + val d = distOutside[i] + val v = -(d.toFloat() / fadeOutside) + v.coerceAtLeast(-1f) + } - val weight: Float = if (inside[i] == 1) { - // Inside silhouette: +1 to 0 - val d = distInside[i] - if (d >= fadeInside) 1f else d.toFloat() / fadeInside - } else { - // Outside: 0 to -1 - val d = distOutside[i] - val neg = -(d.toFloat() / fadeOutside) - neg.coerceAtLeast(-1f) + result[y][x] = weight } - - // Convert -1..1 → grayscale for debugging - val gray = (((weight + 1f) / 2f) * 255).toInt().coerceIn(0, 255) - outPixels[i] = Color.argb(255, gray, gray, gray) } - out.setPixels(outPixels, 0, w, 0, 0, w, h) - return out + return result } }