signed weighted mask

This commit is contained in:
SaiD 2025-12-13 11:57:42 +05:30
parent a392807855
commit 36c122d9a7
3 changed files with 176 additions and 113 deletions

View File

@ -2,6 +2,7 @@ package com.example.livingai.data.camera
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.Color
import android.graphics.RectF import android.graphics.RectF
import android.util.Log import android.util.Log
import com.example.livingai.R import com.example.livingai.R
@ -58,7 +59,7 @@ class DefaultTiltChecker : TiltChecker {
override suspend fun analyze(input: PipelineInput): Instruction { override suspend fun analyze(input: PipelineInput): Instruction {
Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}") Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}")
val tolerance = 10.0f val tolerance = 25.0f
val isLevel: Boolean val isLevel: Boolean
if (input.requiredOrientation == CameraOrientation.PORTRAIT) { if (input.requiredOrientation == CameraOrientation.PORTRAIT) {
@ -87,6 +88,7 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
private var labels: List<String> = emptyList() private var labels: List<String> = emptyList()
private var modelInputWidth: Int = 0 private var modelInputWidth: Int = 0
private var modelInputHeight: Int = 0 private var modelInputHeight: Int = 0
private var maxDetections: Int = 25
init { init {
try { try {
@ -102,6 +104,14 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
} else { } else {
Log.e("TFLiteObjectDetector", "Invalid input tensor shape.") Log.e("TFLiteObjectDetector", "Invalid input tensor shape.")
} }
val outputTensor = interpreter?.getOutputTensor(0)
val outputShape = outputTensor?.shape()
if (outputShape != null && outputShape.size >= 2) {
maxDetections = outputShape[1]
Log.d("TFLiteObjectDetector", "Max detections from model: $maxDetections")
}
Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.") Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.")
} catch (e: IOException) { } catch (e: IOException) {
Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e) Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e)
@ -121,7 +131,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
val byteBuffer = convertBitmapToByteBuffer(resizedBitmap) val byteBuffer = convertBitmapToByteBuffer(resizedBitmap)
// Define model outputs with the correct size // Define model outputs with the correct size
val maxDetections = 25
val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } } val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } }
val outputClasses = Array(1) { FloatArray(maxDetections) } val outputClasses = Array(1) { FloatArray(maxDetections) }
val outputScores = Array(1) { FloatArray(maxDetections) } val outputScores = Array(1) { FloatArray(maxDetections) }
@ -223,120 +232,172 @@ class MockPoseAnalyzer : PoseAnalyzer {
} }
override suspend fun analyze(input: PipelineInput): Instruction { override suspend fun analyze(input: PipelineInput): Instruction {
val detectionResult = input.previousDetectionResult ?: return Instruction("No detection result", isValid = false)
val detectionResult =
input.previousDetectionResult
?: return Instruction("No detection result", isValid = false)
if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) { if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) {
return Instruction("Animal not detected", isValid = false) return Instruction("Animal not detected", isValid = false)
} }
val image = input.image ?: return Instruction("No image", isValid = false) val image =
val animalMask = getAnimalSegmentation(image) // BooleanArray input.image
?: return Instruction("No image", isValid = false)
// --- Get weighted silhouette mask --- // --------------------------------------------------------------------
var weightedSilhouette = SilhouetteManager.getWeightedMask(input.orientation) // 1. Get reference signed-weight silhouette (Array<FloatArray>)
?: return Instruction("Silhouette not found for ${input.orientation}", isValid = false) // --------------------------------------------------------------------
val weightedSilhouette: Array<FloatArray> =
SilhouetteManager.getWeightedMask(input.orientation)
?: return Instruction(
"Silhouette not found for ${input.orientation}",
isValid = false
)
// Ensure silhouette mask matches camera frame size val refHeight = weightedSilhouette.size
if (weightedSilhouette.width != image.width || weightedSilhouette.height != image.height) { val refWidth = weightedSilhouette[0].size
weightedSilhouette =
Bitmap.createScaledBitmap(weightedSilhouette, image.width, image.height, true) // --------------------------------------------------------------------
// 2. Scale input image to match reference resolution
// --------------------------------------------------------------------
val scaledImage = Bitmap.createScaledBitmap(
image,
refWidth,
refHeight,
true
)
// --------------------------------------------------------------------
// 3. Get animal segmentation → BINARY mask (0 / 1)
// --------------------------------------------------------------------
val animalMask: Array<IntArray> =
getAnimalSegmentationMask(scaledImage)
?: return Instruction("Segmentation failed", isValid = false)
// --------------------------------------------------------------------
// 4. Compute signed-weight alignment score
// --------------------------------------------------------------------
val similarity = calculateSignedSimilarity(
animalMask,
weightedSilhouette
)
Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity")
val isValid = similarity >= 0.40f
val msg = if (isValid) {
"Pose Correct"
} else {
"Pose Incorrect (Score: %.2f)".format(similarity)
} }
val weightMap = convertWeightedBitmapToFloatArray(weightedSilhouette) return Instruction(
message = msg,
// --- Compute weighted Jaccard --- isValid = isValid,
val jaccard = calculateWeightedJaccard(animalMask, weightMap) result = detectionResult
)
Log.d("MockPoseAnalyzer", "Weighted Jaccard Similarity = $jaccard")
val isValid = jaccard >= 0.40f
val msg = if (isValid) "Pose Correct" else "Pose Incorrect (Jaccard: %.2f)".format(jaccard)
return Instruction(message = msg, isValid = isValid, result = detectionResult)
} }
/**
* Converts segmentation output to a binary mask:
* 1 = animal present
* 0 = background
*/
private suspend fun getAnimalSegmentationMask(
bitmap: Bitmap
): Array<IntArray>? {
val maskBitmap = getAnimalSegmentation(bitmap) ?: return null
val w = maskBitmap.width
val h = maskBitmap.height
val pixels = IntArray(w * h)
maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h)
val result = Array(h) { IntArray(w) }
for (y in 0 until h) {
for (x in 0 until w) {
val alpha = pixels[y * w + x] ushr 24
result[y][x] = if (alpha > 0) 1 else 0
}
}
return result
}
// ---------------------------------------------------------------------- // ----------------------------------------------------------------------
// REAL segmentation using ML Kit // REAL segmentation using ML Kit
// ---------------------------------------------------------------------- // ----------------------------------------------------------------------
private suspend fun getAnimalSegmentation(bitmap: Bitmap): BooleanArray = suspendCancellableCoroutine { continuation -> private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap = suspendCancellableCoroutine { continuation ->
val inputImage = InputImage.fromBitmap(bitmap, 0) val inputImage = InputImage.fromBitmap(bitmap, 0)
segmenter.process(inputImage) segmenter.process(inputImage)
.addOnSuccessListener { result -> .addOnSuccessListener { result ->
val mask = result.foregroundConfidenceMask val mask = result.foregroundConfidenceMask
if (mask != null) { if (mask != null) {
val floatArray = FloatArray(mask.capacity()) val maskBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)
mask.get(floatArray) val buffer = mask
val booleanArray = BooleanArray(floatArray.size) { i -> buffer.rewind()
floatArray[i] > 0.5f for (y in 0 until bitmap.height) {
for (x in 0 until bitmap.width) {
val confidence = buffer.get()
val alpha = (confidence * 255).toInt()
maskBitmap.setPixel(x, y, Color.argb(alpha, 255, 255, 255))
}
} }
continuation.resume(booleanArray) continuation.resume(maskBitmap)
} else { } else {
Log.e("MockPoseAnalyzer", "Segmentation result null") Log.e("MockPoseAnalyzer", "Segmentation result null")
continuation.resume(BooleanArray(bitmap.width * bitmap.height) { false }) continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
} }
} }
.addOnFailureListener { e -> .addOnFailureListener { e ->
Log.e("MockPoseAnalyzer", "Segmentation failed", e) Log.e("MockPoseAnalyzer", "Segmentation failed", e)
continuation.resume(BooleanArray(bitmap.width * bitmap.height) { false }) continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
} }
} }
// ----------------------------------------------------------------------
// Convert weighted mask bitmap → float[] values in range -1..1
// ----------------------------------------------------------------------
private fun convertWeightedBitmapToFloatArray(bitmap: Bitmap): FloatArray {
val w = bitmap.width
val h = bitmap.height
val pixels = IntArray(w * h)
bitmap.getPixels(pixels, 0, w, 0, 0, w, h)
val out = FloatArray(w * h)
for (i in pixels.indices) {
val color = pixels[i] and 0xFF
val norm = (color / 255f) * 2f - 1f // Converts 0..255 → -1..1
out[i] = norm
}
return out
}
// ---------------------------------------------------------------------- // ----------------------------------------------------------------------
// Weighted Jaccard Similarity // Weighted Jaccard Similarity
// mask = predicted (BooleanArray) // mask = predicted (Bitmap)
// weightMap = ground truth silhouette weights (-1..1) // weightMap = ground truth silhouette weights (Bitmap)
// ---------------------------------------------------------------------- // ----------------------------------------------------------------------
private fun calculateWeightedJaccard(predMask: BooleanArray, weight: FloatArray): Float { /**
* Signed alignment similarity
* Range: [-1, 1]
*/
private fun calculateSignedSimilarity(
mask: Array<IntArray>, // 0 / 1
reference: Array<FloatArray> // [-1, +1]
): Float {
if (predMask.size != weight.size) return 0f var score = 0f
var maxScore = 0f
var weightedIntersection = 0f val h = reference.size
var weightedUnion = 0f val w = reference[0].size
for (i in predMask.indices) { for (y in 0 until h) {
for (x in 0 until w) {
val r = reference[y][x]
val m = mask[y][x].toFloat()
val w = weight[i] // -1.0 .. 1.0 // accumulate positive reference mass only
if (r > 0f) {
val pred = predMask[i] // true/false maxScore += r
val silhouetteInside = w > 0f
val intersection = pred && silhouetteInside
val union = pred || silhouetteInside
if (intersection) weightedIntersection += w.coerceAtLeast(0f)
if (union) {
// Penalize far outside with negative weight also
weightedUnion += if (silhouetteInside) {
w.coerceAtLeast(0f)
} else {
(-w) // penalty
} }
score += m * r
} }
} }
if (weightedUnion == 0f) return 0f return if (maxScore == 0f) 0f else score / maxScore
return weightedIntersection / weightedUnion
} }
} }
class DefaultCaptureHandler : CaptureHandler { class DefaultCaptureHandler : CaptureHandler {

View File

@ -13,11 +13,11 @@ object SilhouetteManager {
private val originals = ConcurrentHashMap<String, Bitmap>() private val originals = ConcurrentHashMap<String, Bitmap>()
private val invertedPurple = ConcurrentHashMap<String, Bitmap>() private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
private val weightedMasks = ConcurrentHashMap<String, Bitmap>() private val weightedMasks = ConcurrentHashMap<String, Array<FloatArray>>()
fun getOriginal(name: String): Bitmap? = originals[name] fun getOriginal(name: String): Bitmap? = originals[name]
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name] fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
fun getWeightedMask(name: String): Bitmap? = weightedMasks[name] fun getWeightedMask(name: String): Array<FloatArray>? = weightedMasks[name]
fun initialize(context: Context, width: Int, height: Int) { fun initialize(context: Context, width: Int, height: Int) {
val resources = context.resources val resources = context.resources
@ -77,15 +77,18 @@ object SilhouetteManager {
return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true) return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true)
} }
// ------------------------------------------------------------------------ /**
// STEP 2: Create signed weighted mask (-1 to 1) * Creates a signed weighted mask in range [-1, +1]
// ------------------------------------------------------------------------ *
* +1 : deep inside object
private fun createSignedWeightedMask( * 0 : object boundary
* -1 : far outside object
*/
fun createSignedWeightedMask(
bitmap: Bitmap, bitmap: Bitmap,
fadeInside: Int = 10, fadeInside: Int = 10,
fadeOutside: Int = 20 fadeOutside: Int = 20
): Bitmap { ): Array<FloatArray> {
val w = bitmap.width val w = bitmap.width
val h = bitmap.height val h = bitmap.height
@ -95,24 +98,25 @@ object SilhouetteManager {
fun idx(x: Int, y: Int) = y * w + x fun idx(x: Int, y: Int) = y * w + x
// inside = 1 → silhouette purple // --------------------------------------------------------------------
// inside = 0 → outside // 1. Binary mask: inside = 1, outside = 0
// Assumption: NON-transparent pixels are inside the object
// --------------------------------------------------------------------
val inside = IntArray(w * h) val inside = IntArray(w * h)
for (i in pixels.indices) { for (i in pixels.indices) {
val alpha = pixels[i] ushr 24 val alpha = pixels[i] ushr 24
inside[i] = if (alpha == 0) 1 else 0 inside[i] = if (alpha > 0) 1 else 0
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// DISTANCES FOR INSIDE PIXELS (to nearest OUTSIDE pixel) // 2. Distance to nearest OUTSIDE pixel (for inside pixels)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val distInside = IntArray(w * h) { Int.MAX_VALUE } val distInside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) { for (i in inside.indices) {
if (inside[i] == 0) distInside[i] = 0 if (inside[i] == 0) distInside[i] = 0
} }
// forward // forward pass
for (y in 0 until h) { for (y in 0 until h) {
for (x in 0 until w) { for (x in 0 until w) {
val i = idx(x, y) val i = idx(x, y)
@ -123,7 +127,7 @@ object SilhouetteManager {
} }
} }
// backward // backward pass
for (y in h - 1 downTo 0) { for (y in h - 1 downTo 0) {
for (x in w - 1 downTo 0) { for (x in w - 1 downTo 0) {
val i = idx(x, y) val i = idx(x, y)
@ -135,15 +139,14 @@ object SilhouetteManager {
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// DISTANCES FOR OUTSIDE PIXELS (to nearest INSIDE pixel) // 3. Distance to nearest INSIDE pixel (for outside pixels)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val distOutside = IntArray(w * h) { Int.MAX_VALUE } val distOutside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) { for (i in inside.indices) {
if (inside[i] == 1) distOutside[i] = 0 if (inside[i] == 1) distOutside[i] = 0
} }
// forward // forward pass
for (y in 0 until h) { for (y in 0 until h) {
for (x in 0 until w) { for (x in 0 until w) {
val i = idx(x, y) val i = idx(x, y)
@ -154,7 +157,7 @@ object SilhouetteManager {
} }
} }
// backward // backward pass
for (y in h - 1 downTo 0) { for (y in h - 1 downTo 0) {
for (x in w - 1 downTo 0) { for (x in w - 1 downTo 0) {
val i = idx(x, y) val i = idx(x, y)
@ -166,31 +169,30 @@ object SilhouetteManager {
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// BUILD FINAL SIGNED MASK (-1 to +1) // 4. Build signed weight map [-1, +1]
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val result = Array(h) { FloatArray(w) }
val out = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888) for (y in 0 until h) {
val outPixels = IntArray(w * h) for (x in 0 until w) {
val i = idx(x, y)
for (i in outPixels.indices) { val weight = if (inside[i] == 1) {
// Inside: +1 → 0
val d = distInside[i]
if (d >= fadeInside) 1f
else d.toFloat() / fadeInside
} else {
// Outside: 0 → -1
val d = distOutside[i]
val v = -(d.toFloat() / fadeOutside)
v.coerceAtLeast(-1f)
}
val weight: Float = if (inside[i] == 1) { result[y][x] = weight
// Inside silhouette: +1 to 0
val d = distInside[i]
if (d >= fadeInside) 1f else d.toFloat() / fadeInside
} else {
// Outside: 0 to -1
val d = distOutside[i]
val neg = -(d.toFloat() / fadeOutside)
neg.coerceAtLeast(-1f)
} }
// Convert -1..1 → grayscale for debugging
val gray = (((weight + 1f) / 2f) * 255).toInt().coerceIn(0, 255)
outPixels[i] = Color.argb(255, gray, gray, gray)
} }
out.setPixels(outPixels, 0, w, 0, 0, w, h) return result
return out
} }
} }