signed weighted mask

This commit is contained in:
SaiD 2025-12-13 11:57:42 +05:30
parent a392807855
commit 36c122d9a7
3 changed files with 176 additions and 113 deletions

View File

@ -2,6 +2,7 @@ package com.example.livingai.data.camera
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Color
import android.graphics.RectF
import android.util.Log
import com.example.livingai.R
@ -58,7 +59,7 @@ class DefaultTiltChecker : TiltChecker {
override suspend fun analyze(input: PipelineInput): Instruction {
Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}")
val tolerance = 10.0f
val tolerance = 25.0f
val isLevel: Boolean
if (input.requiredOrientation == CameraOrientation.PORTRAIT) {
@ -87,6 +88,7 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
private var labels: List<String> = emptyList()
private var modelInputWidth: Int = 0
private var modelInputHeight: Int = 0
private var maxDetections: Int = 25
init {
try {
@ -102,6 +104,14 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
} else {
Log.e("TFLiteObjectDetector", "Invalid input tensor shape.")
}
val outputTensor = interpreter?.getOutputTensor(0)
val outputShape = outputTensor?.shape()
if (outputShape != null && outputShape.size >= 2) {
maxDetections = outputShape[1]
Log.d("TFLiteObjectDetector", "Max detections from model: $maxDetections")
}
Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.")
} catch (e: IOException) {
Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e)
@ -121,7 +131,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
val byteBuffer = convertBitmapToByteBuffer(resizedBitmap)
// Define model outputs with the correct size
val maxDetections = 25
val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } }
val outputClasses = Array(1) { FloatArray(maxDetections) }
val outputScores = Array(1) { FloatArray(maxDetections) }
@ -223,120 +232,172 @@ class MockPoseAnalyzer : PoseAnalyzer {
}
/**
 * Checks whether the detected animal's pose lines up with the reference silhouette.
 *
 * Steps:
 * 1. Bail out early when there is no prior detection, no animal, or no camera frame.
 * 2. Fetch the signed weight map for the requested orientation from [SilhouetteManager].
 * 3. Scale the camera frame to the weight map's resolution and segment it into a 0/1 mask.
 * 4. Score the mask against the weight map and compare the score with a fixed threshold.
 *
 * @param input pipeline state; reads `previousDetectionResult`, `image` and `orientation`.
 * @return an [Instruction] whose `isValid` is true only when the similarity score is at
 *   least 0.40; every early-exit path returns `isValid = false` with an explanatory message.
 */
override suspend fun analyze(input: PipelineInput): Instruction {
    val detectionResult = input.previousDetectionResult
        ?: return Instruction("No detection result", isValid = false)

    if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) {
        return Instruction("Animal not detected", isValid = false)
    }

    val image = input.image
        ?: return Instruction("No image", isValid = false)

    // 1. Reference signed-weight silhouette, indexed as [row][column].
    val weightedSilhouette: Array<FloatArray> =
        SilhouetteManager.getWeightedMask(input.orientation)
            ?: return Instruction(
                "Silhouette not found for ${input.orientation}",
                isValid = false
            )

    // Guard before indexing row 0: an empty map would otherwise throw.
    val refHeight = weightedSilhouette.size
    val refWidth = weightedSilhouette.firstOrNull()?.size ?: 0
    if (refHeight == 0 || refWidth == 0) {
        return Instruction(
            "Silhouette mask is empty for ${input.orientation}",
            isValid = false
        )
    }

    // 2. Bring the camera frame to the reference resolution so both grids share indices.
    val scaledImage = Bitmap.createScaledBitmap(image, refWidth, refHeight, true)

    // 3. Binary segmentation mask (0 = background, 1 = animal).
    val animalMask: Array<IntArray> =
        getAnimalSegmentationMask(scaledImage)
            ?: return Instruction("Segmentation failed", isValid = false)

    // 4. Signed-weight alignment score.
    val similarity = calculateSignedSimilarity(animalMask, weightedSilhouette)
    Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity")

    val minSimilarity = 0.40f
    val isValid = similarity >= minSimilarity
    val msg = if (isValid) {
        "Pose Correct"
    } else {
        "Pose Incorrect (Score: %.2f)".format(similarity)
    }

    return Instruction(
        message = msg,
        isValid = isValid,
        result = detectionResult
    )
}
/**
 * Converts the segmentation output into a binary mask indexed as `[row][column]`:
 * 1 = animal present, 0 = background.
 *
 * The segmentation bitmap stores confidence in the alpha channel (confidence * 255).
 * A pixel counts as animal only when its alpha reaches [minAlpha]; the default of 128
 * corresponds to roughly 0.5 confidence. Accepting any non-zero alpha would mark
 * pixels the model is almost certain are background as animal.
 *
 * @param bitmap frame to segment; the returned mask has the segmentation bitmap's size.
 * @param minAlpha minimum alpha (0..255) for a pixel to be treated as animal.
 * @return the binary mask. The nullable return type is kept for existing callers; this
 *   implementation itself never produces null.
 */
private suspend fun getAnimalSegmentationMask(
    bitmap: Bitmap,
    minAlpha: Int = 128
): Array<IntArray>? {
    val maskBitmap = getAnimalSegmentation(bitmap)

    val w = maskBitmap.width
    val h = maskBitmap.height

    // One bulk read instead of per-pixel getPixel calls.
    val pixels = IntArray(w * h)
    maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h)

    return Array(h) { y ->
        val rowStart = y * w
        IntArray(w) { x ->
            val alpha = pixels[rowStart + x] ushr 24
            if (alpha >= minAlpha) 1 else 0
        }
    }
}
// ----------------------------------------------------------------------
// REAL segmentation using ML Kit
// ----------------------------------------------------------------------
/**
 * Runs the segmenter on [bitmap] and returns the foreground confidence as a bitmap.
 *
 * Each output pixel is white with alpha = confidence * 255, so alpha 0 means
 * "background" and alpha 255 means "certainly foreground".
 *
 * A fully transparent bitmap of the input's size is returned when the segmenter
 * fails, yields no mask, or yields a mask with fewer values than the input has pixels.
 *
 * @param bitmap frame to segment.
 * @return a mutable ARGB_8888 bitmap with the same dimensions as [bitmap].
 */
private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap =
    suspendCancellableCoroutine { continuation ->
        val width = bitmap.width
        val height = bitmap.height
        val pixelCount = width * height

        // A freshly created ARGB_8888 bitmap is fully transparent, i.e. "no animal anywhere".
        fun emptyMask(): Bitmap =
            Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)

        val inputImage = InputImage.fromBitmap(bitmap, 0)
        segmenter.process(inputImage)
            .addOnSuccessListener { result ->
                // NOTE(review): assumes the mask is a FloatBuffer in row-major order with
                // one confidence per input pixel, as the earlier get(FloatArray) read implied.
                val mask = result.foregroundConfidenceMask
                if (mask == null) {
                    Log.e("MockPoseAnalyzer", "Segmentation result null")
                    continuation.resume(emptyMask())
                } else {
                    mask.rewind()
                    if (mask.remaining() < pixelCount) {
                        // Reading past the end would throw inside the listener and leave
                        // the coroutine suspended forever.
                        Log.e(
                            "MockPoseAnalyzer",
                            "Segmentation mask too small: ${mask.remaining()} < $pixelCount"
                        )
                        continuation.resume(emptyMask())
                    } else {
                        val confidences = FloatArray(pixelCount)
                        mask.get(confidences)

                        val argb = IntArray(pixelCount) { i ->
                            val alpha = (confidences[i] * 255).toInt().coerceIn(0, 255)
                            Color.argb(alpha, 255, 255, 255)
                        }

                        // Single bulk write instead of width * height setPixel calls.
                        val maskBitmap = emptyMask()
                        maskBitmap.setPixels(argb, 0, width, 0, 0, width, height)
                        continuation.resume(maskBitmap)
                    }
                }
            }
            .addOnFailureListener { e ->
                Log.e("MockPoseAnalyzer", "Segmentation failed", e)
                continuation.resume(emptyMask())
            }
    }
/**
 * Flattens a weighted-mask bitmap into a row-major float array.
 *
 * Only the lowest byte of each packed pixel is read; its 0..255 value is mapped
 * linearly onto -1..1 (0 -> -1, 255 -> +1).
 *
 * @param bitmap bitmap holding the encoded weights.
 * @return one float per pixel, `width * height` entries in total.
 */
private fun convertWeightedBitmapToFloatArray(bitmap: Bitmap): FloatArray {
    val width = bitmap.width
    val height = bitmap.height

    val packed = IntArray(width * height)
    bitmap.getPixels(packed, 0, width, 0, 0, width, height)

    return FloatArray(packed.size) { index ->
        val level = packed[index] and 0xFF
        (level / 255f) * 2f - 1f
    }
}
// ----------------------------------------------------------------------
// Signed alignment similarity
// mask      = predicted binary mask (Array<IntArray>, 0 / 1)
// reference = silhouette weight map (Array<FloatArray>, -1..1)
// ----------------------------------------------------------------------
/**
 * Signed alignment score between a binary mask and a signed reference weight map.
 *
 * Every mask cell is multiplied by the reference weight at the same position and the
 * products are summed, so cells set over positive weights raise the score and cells
 * set over negative weights lower it. The sum is divided by the total positive weight
 * in [reference] and clamped, so the result always lies in [-1, 1].
 *
 * @param mask binary mask indexed as `[row][column]` (expected values 0 / 1).
 * @param reference weight map indexed as `[row][column]` (expected values in -1..1).
 * @return the clamped score, or 0 when [reference] is empty, when the two grids do not
 *   have identical rectangular dimensions, or when [reference] has no positive weight.
 */
private fun calculateSignedSimilarity(
    mask: Array<IntArray>,
    reference: Array<FloatArray>
): Float {
    val h = reference.size
    val w = reference.firstOrNull()?.size ?: 0
    if (h == 0 || w == 0) return 0f

    // Both grids must be h x w rectangles; otherwise indexing below would go out of bounds.
    if (mask.size != h) return 0f
    if (reference.any { it.size != w } || mask.any { it.size != w }) return 0f

    // Double accumulators: summing many small Float weights in a Float loses precision.
    var score = 0.0
    var maxScore = 0.0

    for (y in 0 until h) {
        val refRow = reference[y]
        val maskRow = mask[y]
        for (x in 0 until w) {
            val r = refRow[x]
            // Only positive reference mass counts toward the best achievable score.
            if (r > 0f) maxScore += r
            score += maskRow[x] * r
        }
    }

    if (maxScore == 0.0) return 0f

    // A mask covering mostly negative area can push the ratio below -1; clamp to the contract.
    return (score / maxScore).toFloat().coerceIn(-1f, 1f)
}
}
class DefaultCaptureHandler : CaptureHandler {

View File

@ -13,11 +13,11 @@ object SilhouetteManager {
private val originals = ConcurrentHashMap<String, Bitmap>()
private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
private val weightedMasks = ConcurrentHashMap<String, Bitmap>()
private val weightedMasks = ConcurrentHashMap<String, Array<FloatArray>>()
fun getOriginal(name: String): Bitmap? = originals[name]
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
fun getWeightedMask(name: String): Bitmap? = weightedMasks[name]
fun getWeightedMask(name: String): Array<FloatArray>? = weightedMasks[name]
fun initialize(context: Context, width: Int, height: Int) {
val resources = context.resources
@ -77,15 +77,18 @@ object SilhouetteManager {
return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true)
}
// ------------------------------------------------------------------------
// STEP 2: Create signed weighted mask (-1 to 1)
// ------------------------------------------------------------------------
private fun createSignedWeightedMask(
/**
* Creates a signed weighted mask in range [-1, +1]
*
* +1 : deep inside object
* 0 : object boundary
* -1 : far outside object
*/
fun createSignedWeightedMask(
bitmap: Bitmap,
fadeInside: Int = 10,
fadeOutside: Int = 20
): Bitmap {
): Array<FloatArray> {
val w = bitmap.width
val h = bitmap.height
@ -95,24 +98,25 @@ object SilhouetteManager {
fun idx(x: Int, y: Int) = y * w + x
// inside = 1 → silhouette purple
// inside = 0 → outside
// --------------------------------------------------------------------
// 1. Binary mask: inside = 1, outside = 0
// Assumption: NON-transparent pixels are inside the object
// --------------------------------------------------------------------
val inside = IntArray(w * h)
for (i in pixels.indices) {
val alpha = pixels[i] ushr 24
inside[i] = if (alpha == 0) 1 else 0
inside[i] = if (alpha > 0) 1 else 0
}
// --------------------------------------------------------------------
// DISTANCES FOR INSIDE PIXELS (to nearest OUTSIDE pixel)
// 2. Distance to nearest OUTSIDE pixel (for inside pixels)
// --------------------------------------------------------------------
val distInside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) {
if (inside[i] == 0) distInside[i] = 0
}
// forward
// forward pass
for (y in 0 until h) {
for (x in 0 until w) {
val i = idx(x, y)
@ -123,7 +127,7 @@ object SilhouetteManager {
}
}
// backward
// backward pass
for (y in h - 1 downTo 0) {
for (x in w - 1 downTo 0) {
val i = idx(x, y)
@ -135,15 +139,14 @@ object SilhouetteManager {
}
// --------------------------------------------------------------------
// DISTANCES FOR OUTSIDE PIXELS (to nearest INSIDE pixel)
// 3. Distance to nearest INSIDE pixel (for outside pixels)
// --------------------------------------------------------------------
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) {
if (inside[i] == 1) distOutside[i] = 0
}
// forward
// forward pass
for (y in 0 until h) {
for (x in 0 until w) {
val i = idx(x, y)
@ -154,7 +157,7 @@ object SilhouetteManager {
}
}
// backward
// backward pass
for (y in h - 1 downTo 0) {
for (x in w - 1 downTo 0) {
val i = idx(x, y)
@ -166,31 +169,30 @@ object SilhouetteManager {
}
// --------------------------------------------------------------------
// BUILD FINAL SIGNED MASK (-1 to +1)
// 4. Build signed weight map [-1, +1]
// --------------------------------------------------------------------
val result = Array(h) { FloatArray(w) }
val out = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888)
val outPixels = IntArray(w * h)
for (y in 0 until h) {
for (x in 0 until w) {
val i = idx(x, y)
for (i in outPixels.indices) {
val weight = if (inside[i] == 1) {
// Inside: +1 → 0
val d = distInside[i]
if (d >= fadeInside) 1f
else d.toFloat() / fadeInside
} else {
// Outside: 0 → -1
val d = distOutside[i]
val v = -(d.toFloat() / fadeOutside)
v.coerceAtLeast(-1f)
}
val weight: Float = if (inside[i] == 1) {
// Inside silhouette: +1 to 0
val d = distInside[i]
if (d >= fadeInside) 1f else d.toFloat() / fadeInside
} else {
// Outside: 0 to -1
val d = distOutside[i]
val neg = -(d.toFloat() / fadeOutside)
neg.coerceAtLeast(-1f)
result[y][x] = weight
}
// Convert -1..1 → grayscale for debugging
val gray = (((weight + 1f) / 2f) * 255).toInt().coerceIn(0, 255)
outPixels[i] = Color.argb(255, gray, gray, gray)
}
out.setPixels(outPixels, 0, w, 0, 0, w, h)
return out
return result
}
}