optimizations

This commit is contained in:
SaiD 2025-12-13 12:15:03 +05:30
parent 36c122d9a7
commit 5d58695eb4
2 changed files with 72 additions and 126 deletions

View File

@ -221,7 +221,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
data class Detection(val label: String, val confidence: Float, val bounds: RectF) data class Detection(val label: String, val confidence: Float, val bounds: RectF)
} }
class MockPoseAnalyzer : PoseAnalyzer { class MockPoseAnalyzer : PoseAnalyzer {
private val segmenter by lazy { private val segmenter by lazy {
@ -233,173 +232,114 @@ class MockPoseAnalyzer : PoseAnalyzer {
override suspend fun analyze(input: PipelineInput): Instruction { override suspend fun analyze(input: PipelineInput): Instruction {
val detectionResult = val detection = input.previousDetectionResult
input.previousDetectionResult ?: return Instruction("No detection", isValid = false)
?: return Instruction("No detection result", isValid = false)
if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) { val bounds = detection.animalBounds
return Instruction("Animal not detected", isValid = false) ?: return Instruction("Animal not detected", isValid = false)
}
val image = val image = input.image
input.image ?: return Instruction("No image", isValid = false)
?: return Instruction("No image", isValid = false)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 1. Get reference signed-weight silhouette (Array<FloatArray>) // 1. Reference silhouette (FloatArray)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val weightedSilhouette: Array<FloatArray> = val reference = SilhouetteManager.getWeightedMask(input.orientation)
SilhouetteManager.getWeightedMask(input.orientation) ?: return Instruction("Silhouette missing", isValid = false)
?: return Instruction(
"Silhouette not found for ${input.orientation}",
isValid = false
)
val refHeight = weightedSilhouette.size val refH = reference.size
val refWidth = weightedSilhouette[0].size val refW = reference[0].size
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 2. Scale input image to match reference resolution // 2. Crop only animal region (BIG WIN)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val scaledImage = Bitmap.createScaledBitmap( val cropped = Bitmap.createBitmap(
image, image,
refWidth, bounds.left.toInt(),
refHeight, bounds.top.toInt(),
true bounds.width().toInt(),
bounds.height().toInt()
) )
// -------------------------------------------------------------------- val scaled = Bitmap.createScaledBitmap(cropped, refW, refH, true)
// 3. Get animal segmentation → BINARY mask (0 / 1)
// --------------------------------------------------------------------
val animalMask: Array<IntArray> =
getAnimalSegmentationMask(scaledImage)
?: return Instruction("Segmentation failed", isValid = false)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 4. Compute signed-weight alignment score // 3. Get binary segmentation mask (ByteArray)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val similarity = calculateSignedSimilarity( val mask = getAnimalMaskFast(scaled)
animalMask, ?: return Instruction("Segmentation failed", isValid = false)
weightedSilhouette
)
Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity") // --------------------------------------------------------------------
// 4. Fast signed similarity
val isValid = similarity >= 0.40f // --------------------------------------------------------------------
val msg = if (isValid) { val score = calculateSignedSimilarityFast(mask, reference)
"Pose Correct"
} else {
"Pose Incorrect (Score: %.2f)".format(similarity)
}
val valid = score >= 0.40f
return Instruction( return Instruction(
message = msg, message = if (valid) "Pose Correct" else "Pose Incorrect (%.2f)".format(score),
isValid = isValid, isValid = valid,
result = detectionResult result = detection
) )
} }
/** /**
* Converts segmentation output to a binary mask: * ML Kit Binary mask (0/1) using ByteArray
* 1 = animal present
* 0 = background
*/ */
private suspend fun getAnimalSegmentationMask( private suspend fun getAnimalMaskFast(bitmap: Bitmap): ByteArray? =
bitmap: Bitmap suspendCancellableCoroutine { cont ->
): Array<IntArray>? {
val maskBitmap = getAnimalSegmentation(bitmap) ?: return null val image = InputImage.fromBitmap(bitmap, 0)
segmenter.process(image)
.addOnSuccessListener { result ->
val mask = result.foregroundConfidenceMask
?: run {
cont.resume(null)
return@addOnSuccessListener
}
val w = maskBitmap.width
val h = maskBitmap.height
val pixels = IntArray(w * h)
maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h)
val result = Array(h) { IntArray(w) }
for (y in 0 until h) {
for (x in 0 until w) {
val alpha = pixels[y * w + x] ushr 24
result[y][x] = if (alpha > 0) 1 else 0
}
}
return result
}
// ----------------------------------------------------------------------
// REAL segmentation using ML Kit
// ----------------------------------------------------------------------
private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap = suspendCancellableCoroutine { continuation ->
val inputImage = InputImage.fromBitmap(bitmap, 0)
segmenter.process(inputImage)
.addOnSuccessListener { result ->
val mask = result.foregroundConfidenceMask
if (mask != null) {
val maskBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)
val buffer = mask val buffer = mask
buffer.rewind() buffer.rewind()
for (y in 0 until bitmap.height) {
for (x in 0 until bitmap.width) {
val confidence = buffer.get()
val alpha = (confidence * 255).toInt()
maskBitmap.setPixel(x, y, Color.argb(alpha, 255, 255, 255))
}
}
continuation.resume(maskBitmap)
} else {
Log.e("MockPoseAnalyzer", "Segmentation result null")
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
}
}
.addOnFailureListener { e ->
Log.e("MockPoseAnalyzer", "Segmentation failed", e)
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
}
}
// ---------------------------------------------------------------------- val out = ByteArray(bitmap.width * bitmap.height)
// Weighted Jaccard Similarity
// mask = predicted (Bitmap) for (i in out.indices) {
// weightMap = ground truth silhouette weights (Bitmap) out[i] = if (buffer.get() > 0.5f) 1 else 0
// ---------------------------------------------------------------------- }
cont.resume(out)
}
.addOnFailureListener {
cont.resume(null)
}
}
/** /**
* Signed alignment similarity * Signed similarity (flat arrays, cache-friendly)
* Range: [-1, 1] * Range [-1, 1]
*/ */
private fun calculateSignedSimilarity( private fun calculateSignedSimilarityFast(
mask: Array<IntArray>, // 0 / 1 mask: ByteArray,
reference: Array<FloatArray> // [-1, +1] reference: Array<FloatArray>
): Float { ): Float {
var score = 0f var score = 0f
var maxScore = 0f var maxScore = 0f
var idx = 0
val h = reference.size for (y in reference.indices) {
val w = reference[0].size val row = reference[y]
for (x in row.indices) {
for (y in 0 until h) { val r = row[x]
for (x in 0 until w) { if (r > 0f) maxScore += r
val r = reference[y][x] score += mask[idx++] * r
val m = mask[y][x].toFloat()
// accumulate positive reference mass only
if (r > 0f) {
maxScore += r
}
score += m * r
} }
} }
return if (maxScore == 0f) 0f else score / maxScore return if (maxScore == 0f) 0f else score / maxScore
} }
} }
class DefaultCaptureHandler : CaptureHandler { class DefaultCaptureHandler : CaptureHandler {
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData { override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture") val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")

View File

@ -25,6 +25,9 @@ class CameraViewModel(
val uiState: StateFlow<CameraUiState> = _uiState.asStateFlow() val uiState: StateFlow<CameraUiState> = _uiState.asStateFlow()
private val tilt = tiltSensorManager.tilt private val tilt = tiltSensorManager.tilt
private var frameCounter = 0
private val frameSkipInterval = 5
init { init {
tiltSensorManager.start() tiltSensorManager.start()
} }
@ -49,6 +52,9 @@ class CameraViewModel(
image: Bitmap, image: Bitmap,
deviceOrientation: Int deviceOrientation: Int
) { ) {
frameCounter = (frameCounter + 1) % frameSkipInterval
if (frameCounter != 0) return
viewModelScope.launch { viewModelScope.launch {
val currentTilt = tilt.value val currentTilt = tilt.value
val input = PipelineInput( val input = PipelineInput(