signed weighted mask
This commit is contained in:
parent
a392807855
commit
36c122d9a7
|
|
@ -2,6 +2,7 @@ package com.example.livingai.data.camera
|
|||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.Color
|
||||
import android.graphics.RectF
|
||||
import android.util.Log
|
||||
import com.example.livingai.R
|
||||
|
|
@ -58,7 +59,7 @@ class DefaultTiltChecker : TiltChecker {
|
|||
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||
Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}")
|
||||
|
||||
val tolerance = 10.0f
|
||||
val tolerance = 25.0f
|
||||
val isLevel: Boolean
|
||||
|
||||
if (input.requiredOrientation == CameraOrientation.PORTRAIT) {
|
||||
|
|
@ -87,6 +88,7 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
|
|||
private var labels: List<String> = emptyList()
|
||||
private var modelInputWidth: Int = 0
|
||||
private var modelInputHeight: Int = 0
|
||||
private var maxDetections: Int = 25
|
||||
|
||||
init {
|
||||
try {
|
||||
|
|
@ -102,6 +104,14 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
|
|||
} else {
|
||||
Log.e("TFLiteObjectDetector", "Invalid input tensor shape.")
|
||||
}
|
||||
|
||||
val outputTensor = interpreter?.getOutputTensor(0)
|
||||
val outputShape = outputTensor?.shape()
|
||||
if (outputShape != null && outputShape.size >= 2) {
|
||||
maxDetections = outputShape[1]
|
||||
Log.d("TFLiteObjectDetector", "Max detections from model: $maxDetections")
|
||||
}
|
||||
|
||||
Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.")
|
||||
} catch (e: IOException) {
|
||||
Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e)
|
||||
|
|
@ -121,7 +131,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
|
|||
val byteBuffer = convertBitmapToByteBuffer(resizedBitmap)
|
||||
|
||||
// Define model outputs with the correct size
|
||||
val maxDetections = 25
|
||||
val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } }
|
||||
val outputClasses = Array(1) { FloatArray(maxDetections) }
|
||||
val outputScores = Array(1) { FloatArray(maxDetections) }
|
||||
|
|
@ -223,122 +232,174 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
|||
}
|
||||
|
||||
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||
val detectionResult = input.previousDetectionResult ?: return Instruction("No detection result", isValid = false)
|
||||
|
||||
val detectionResult =
|
||||
input.previousDetectionResult
|
||||
?: return Instruction("No detection result", isValid = false)
|
||||
|
||||
if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) {
|
||||
return Instruction("Animal not detected", isValid = false)
|
||||
}
|
||||
|
||||
val image = input.image ?: return Instruction("No image", isValid = false)
|
||||
val animalMask = getAnimalSegmentation(image) // BooleanArray
|
||||
val image =
|
||||
input.image
|
||||
?: return Instruction("No image", isValid = false)
|
||||
|
||||
// --- Get weighted silhouette mask ---
|
||||
var weightedSilhouette = SilhouetteManager.getWeightedMask(input.orientation)
|
||||
?: return Instruction("Silhouette not found for ${input.orientation}", isValid = false)
|
||||
// --------------------------------------------------------------------
|
||||
// 1. Get reference signed-weight silhouette (Array<FloatArray>)
|
||||
// --------------------------------------------------------------------
|
||||
val weightedSilhouette: Array<FloatArray> =
|
||||
SilhouetteManager.getWeightedMask(input.orientation)
|
||||
?: return Instruction(
|
||||
"Silhouette not found for ${input.orientation}",
|
||||
isValid = false
|
||||
)
|
||||
|
||||
// Ensure silhouette mask matches camera frame size
|
||||
if (weightedSilhouette.width != image.width || weightedSilhouette.height != image.height) {
|
||||
weightedSilhouette =
|
||||
Bitmap.createScaledBitmap(weightedSilhouette, image.width, image.height, true)
|
||||
val refHeight = weightedSilhouette.size
|
||||
val refWidth = weightedSilhouette[0].size
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 2. Scale input image to match reference resolution
|
||||
// --------------------------------------------------------------------
|
||||
val scaledImage = Bitmap.createScaledBitmap(
|
||||
image,
|
||||
refWidth,
|
||||
refHeight,
|
||||
true
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 3. Get animal segmentation → BINARY mask (0 / 1)
|
||||
// --------------------------------------------------------------------
|
||||
val animalMask: Array<IntArray> =
|
||||
getAnimalSegmentationMask(scaledImage)
|
||||
?: return Instruction("Segmentation failed", isValid = false)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 4. Compute signed-weight alignment score
|
||||
// --------------------------------------------------------------------
|
||||
val similarity = calculateSignedSimilarity(
|
||||
animalMask,
|
||||
weightedSilhouette
|
||||
)
|
||||
|
||||
Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity")
|
||||
|
||||
val isValid = similarity >= 0.40f
|
||||
val msg = if (isValid) {
|
||||
"Pose Correct"
|
||||
} else {
|
||||
"Pose Incorrect (Score: %.2f)".format(similarity)
|
||||
}
|
||||
|
||||
val weightMap = convertWeightedBitmapToFloatArray(weightedSilhouette)
|
||||
|
||||
// --- Compute weighted Jaccard ---
|
||||
val jaccard = calculateWeightedJaccard(animalMask, weightMap)
|
||||
|
||||
Log.d("MockPoseAnalyzer", "Weighted Jaccard Similarity = $jaccard")
|
||||
|
||||
val isValid = jaccard >= 0.40f
|
||||
val msg = if (isValid) "Pose Correct" else "Pose Incorrect (Jaccard: %.2f)".format(jaccard)
|
||||
|
||||
return Instruction(message = msg, isValid = isValid, result = detectionResult)
|
||||
return Instruction(
|
||||
message = msg,
|
||||
isValid = isValid,
|
||||
result = detectionResult
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Converts segmentation output to a binary mask:
|
||||
* 1 = animal present
|
||||
* 0 = background
|
||||
*/
|
||||
private suspend fun getAnimalSegmentationMask(
|
||||
bitmap: Bitmap
|
||||
): Array<IntArray>? {
|
||||
|
||||
val maskBitmap = getAnimalSegmentation(bitmap) ?: return null
|
||||
|
||||
val w = maskBitmap.width
|
||||
val h = maskBitmap.height
|
||||
|
||||
val pixels = IntArray(w * h)
|
||||
maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h)
|
||||
|
||||
val result = Array(h) { IntArray(w) }
|
||||
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val alpha = pixels[y * w + x] ushr 24
|
||||
result[y][x] = if (alpha > 0) 1 else 0
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// REAL segmentation using ML Kit
|
||||
// ----------------------------------------------------------------------
|
||||
private suspend fun getAnimalSegmentation(bitmap: Bitmap): BooleanArray = suspendCancellableCoroutine { continuation ->
|
||||
private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap = suspendCancellableCoroutine { continuation ->
|
||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||
segmenter.process(inputImage)
|
||||
.addOnSuccessListener { result ->
|
||||
val mask = result.foregroundConfidenceMask
|
||||
if (mask != null) {
|
||||
val floatArray = FloatArray(mask.capacity())
|
||||
mask.get(floatArray)
|
||||
val booleanArray = BooleanArray(floatArray.size) { i ->
|
||||
floatArray[i] > 0.5f
|
||||
val maskBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)
|
||||
val buffer = mask
|
||||
buffer.rewind()
|
||||
for (y in 0 until bitmap.height) {
|
||||
for (x in 0 until bitmap.width) {
|
||||
val confidence = buffer.get()
|
||||
val alpha = (confidence * 255).toInt()
|
||||
maskBitmap.setPixel(x, y, Color.argb(alpha, 255, 255, 255))
|
||||
}
|
||||
continuation.resume(booleanArray)
|
||||
}
|
||||
continuation.resume(maskBitmap)
|
||||
} else {
|
||||
Log.e("MockPoseAnalyzer", "Segmentation result null")
|
||||
continuation.resume(BooleanArray(bitmap.width * bitmap.height) { false })
|
||||
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
|
||||
}
|
||||
}
|
||||
.addOnFailureListener { e ->
|
||||
Log.e("MockPoseAnalyzer", "Segmentation failed", e)
|
||||
continuation.resume(BooleanArray(bitmap.width * bitmap.height) { false })
|
||||
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Convert weighted mask bitmap → float[] values in range -1..1
|
||||
// ----------------------------------------------------------------------
|
||||
private fun convertWeightedBitmapToFloatArray(bitmap: Bitmap): FloatArray {
|
||||
val w = bitmap.width
|
||||
val h = bitmap.height
|
||||
|
||||
val pixels = IntArray(w * h)
|
||||
bitmap.getPixels(pixels, 0, w, 0, 0, w, h)
|
||||
|
||||
val out = FloatArray(w * h)
|
||||
|
||||
for (i in pixels.indices) {
|
||||
val color = pixels[i] and 0xFF
|
||||
val norm = (color / 255f) * 2f - 1f // Converts 0..255 → -1..1
|
||||
out[i] = norm
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Weighted Jaccard Similarity
|
||||
// mask = predicted (BooleanArray)
|
||||
// weightMap = ground truth silhouette weights (-1..1)
|
||||
// mask = predicted (Bitmap)
|
||||
// weightMap = ground truth silhouette weights (Bitmap)
|
||||
// ----------------------------------------------------------------------
|
||||
private fun calculateWeightedJaccard(predMask: BooleanArray, weight: FloatArray): Float {
|
||||
/**
|
||||
* Signed alignment similarity
|
||||
* Range: [-1, 1]
|
||||
*/
|
||||
private fun calculateSignedSimilarity(
|
||||
mask: Array<IntArray>, // 0 / 1
|
||||
reference: Array<FloatArray> // [-1, +1]
|
||||
): Float {
|
||||
|
||||
if (predMask.size != weight.size) return 0f
|
||||
var score = 0f
|
||||
var maxScore = 0f
|
||||
|
||||
var weightedIntersection = 0f
|
||||
var weightedUnion = 0f
|
||||
val h = reference.size
|
||||
val w = reference[0].size
|
||||
|
||||
for (i in predMask.indices) {
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val r = reference[y][x]
|
||||
val m = mask[y][x].toFloat()
|
||||
|
||||
val w = weight[i] // -1.0 .. 1.0
|
||||
|
||||
val pred = predMask[i] // true/false
|
||||
val silhouetteInside = w > 0f
|
||||
|
||||
val intersection = pred && silhouetteInside
|
||||
val union = pred || silhouetteInside
|
||||
|
||||
if (intersection) weightedIntersection += w.coerceAtLeast(0f)
|
||||
if (union) {
|
||||
// Penalize far outside with negative weight also
|
||||
weightedUnion += if (silhouetteInside) {
|
||||
w.coerceAtLeast(0f)
|
||||
} else {
|
||||
(-w) // penalty
|
||||
}
|
||||
}
|
||||
// accumulate positive reference mass only
|
||||
if (r > 0f) {
|
||||
maxScore += r
|
||||
}
|
||||
|
||||
if (weightedUnion == 0f) return 0f
|
||||
return weightedIntersection / weightedUnion
|
||||
score += m * r
|
||||
}
|
||||
}
|
||||
|
||||
return if (maxScore == 0f) 0f else score / maxScore
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
class DefaultCaptureHandler : CaptureHandler {
|
||||
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
|
||||
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")
|
||||
|
|
|
|||
|
|
@ -13,11 +13,11 @@ object SilhouetteManager {
|
|||
|
||||
private val originals = ConcurrentHashMap<String, Bitmap>()
|
||||
private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
|
||||
private val weightedMasks = ConcurrentHashMap<String, Bitmap>()
|
||||
private val weightedMasks = ConcurrentHashMap<String, Array<FloatArray>>()
|
||||
|
||||
fun getOriginal(name: String): Bitmap? = originals[name]
|
||||
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
|
||||
fun getWeightedMask(name: String): Bitmap? = weightedMasks[name]
|
||||
fun getWeightedMask(name: String): Array<FloatArray>? = weightedMasks[name]
|
||||
|
||||
fun initialize(context: Context, width: Int, height: Int) {
|
||||
val resources = context.resources
|
||||
|
|
@ -77,15 +77,18 @@ object SilhouetteManager {
|
|||
return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true)
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
// STEP 2: Create signed weighted mask (-1 to 1)
|
||||
// ------------------------------------------------------------------------
|
||||
|
||||
private fun createSignedWeightedMask(
|
||||
/**
|
||||
* Creates a signed weighted mask in range [-1, +1]
|
||||
*
|
||||
* +1 : deep inside object
|
||||
* 0 : object boundary
|
||||
* -1 : far outside object
|
||||
*/
|
||||
fun createSignedWeightedMask(
|
||||
bitmap: Bitmap,
|
||||
fadeInside: Int = 10,
|
||||
fadeOutside: Int = 20
|
||||
): Bitmap {
|
||||
): Array<FloatArray> {
|
||||
|
||||
val w = bitmap.width
|
||||
val h = bitmap.height
|
||||
|
|
@ -95,24 +98,25 @@ object SilhouetteManager {
|
|||
|
||||
fun idx(x: Int, y: Int) = y * w + x
|
||||
|
||||
// inside = 1 → silhouette purple
|
||||
// inside = 0 → outside
|
||||
// --------------------------------------------------------------------
|
||||
// 1. Binary mask: inside = 1, outside = 0
|
||||
// Assumption: NON-transparent pixels are inside the object
|
||||
// --------------------------------------------------------------------
|
||||
val inside = IntArray(w * h)
|
||||
for (i in pixels.indices) {
|
||||
val alpha = pixels[i] ushr 24
|
||||
inside[i] = if (alpha == 0) 1 else 0
|
||||
inside[i] = if (alpha > 0) 1 else 0
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// DISTANCES FOR INSIDE PIXELS (to nearest OUTSIDE pixel)
|
||||
// 2. Distance to nearest OUTSIDE pixel (for inside pixels)
|
||||
// --------------------------------------------------------------------
|
||||
val distInside = IntArray(w * h) { Int.MAX_VALUE }
|
||||
|
||||
for (i in inside.indices) {
|
||||
if (inside[i] == 0) distInside[i] = 0
|
||||
}
|
||||
|
||||
// forward
|
||||
// forward pass
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -123,7 +127,7 @@ object SilhouetteManager {
|
|||
}
|
||||
}
|
||||
|
||||
// backward
|
||||
// backward pass
|
||||
for (y in h - 1 downTo 0) {
|
||||
for (x in w - 1 downTo 0) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -135,15 +139,14 @@ object SilhouetteManager {
|
|||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// DISTANCES FOR OUTSIDE PIXELS (to nearest INSIDE pixel)
|
||||
// 3. Distance to nearest INSIDE pixel (for outside pixels)
|
||||
// --------------------------------------------------------------------
|
||||
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
|
||||
|
||||
for (i in inside.indices) {
|
||||
if (inside[i] == 1) distOutside[i] = 0
|
||||
}
|
||||
|
||||
// forward
|
||||
// forward pass
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -154,7 +157,7 @@ object SilhouetteManager {
|
|||
}
|
||||
}
|
||||
|
||||
// backward
|
||||
// backward pass
|
||||
for (y in h - 1 downTo 0) {
|
||||
for (x in w - 1 downTo 0) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -166,31 +169,30 @@ object SilhouetteManager {
|
|||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// BUILD FINAL SIGNED MASK (-1 to +1)
|
||||
// 4. Build signed weight map [-1, +1]
|
||||
// --------------------------------------------------------------------
|
||||
val result = Array(h) { FloatArray(w) }
|
||||
|
||||
val out = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888)
|
||||
val outPixels = IntArray(w * h)
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
|
||||
for (i in outPixels.indices) {
|
||||
|
||||
val weight: Float = if (inside[i] == 1) {
|
||||
// Inside silhouette: +1 to 0
|
||||
val weight = if (inside[i] == 1) {
|
||||
// Inside: +1 → 0
|
||||
val d = distInside[i]
|
||||
if (d >= fadeInside) 1f else d.toFloat() / fadeInside
|
||||
if (d >= fadeInside) 1f
|
||||
else d.toFloat() / fadeInside
|
||||
} else {
|
||||
// Outside: 0 to -1
|
||||
// Outside: 0 → -1
|
||||
val d = distOutside[i]
|
||||
val neg = -(d.toFloat() / fadeOutside)
|
||||
neg.coerceAtLeast(-1f)
|
||||
val v = -(d.toFloat() / fadeOutside)
|
||||
v.coerceAtLeast(-1f)
|
||||
}
|
||||
|
||||
// Convert -1..1 → grayscale for debugging
|
||||
val gray = (((weight + 1f) / 2f) * 255).toInt().coerceIn(0, 255)
|
||||
outPixels[i] = Color.argb(255, gray, gray, gray)
|
||||
result[y][x] = weight
|
||||
}
|
||||
}
|
||||
|
||||
out.setPixels(outPixels, 0, w, 0, 0, w, h)
|
||||
return out
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue