optimizations
This commit is contained in:
parent
36c122d9a7
commit
5d58695eb4
|
|
@ -221,7 +221,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
|
||||||
data class Detection(val label: String, val confidence: Float, val bounds: RectF)
|
data class Detection(val label: String, val confidence: Float, val bounds: RectF)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class MockPoseAnalyzer : PoseAnalyzer {
|
class MockPoseAnalyzer : PoseAnalyzer {
|
||||||
|
|
||||||
private val segmenter by lazy {
|
private val segmenter by lazy {
|
||||||
|
|
@ -233,173 +232,114 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
||||||
|
|
||||||
override suspend fun analyze(input: PipelineInput): Instruction {
|
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||||
|
|
||||||
val detectionResult =
|
val detection = input.previousDetectionResult
|
||||||
input.previousDetectionResult
|
?: return Instruction("No detection", isValid = false)
|
||||||
?: return Instruction("No detection result", isValid = false)
|
|
||||||
|
|
||||||
if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) {
|
val bounds = detection.animalBounds
|
||||||
return Instruction("Animal not detected", isValid = false)
|
?: return Instruction("Animal not detected", isValid = false)
|
||||||
}
|
|
||||||
|
|
||||||
val image =
|
val image = input.image
|
||||||
input.image
|
|
||||||
?: return Instruction("No image", isValid = false)
|
?: return Instruction("No image", isValid = false)
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 1. Get reference signed-weight silhouette (Array<FloatArray>)
|
// 1. Reference silhouette (Array<FloatArray>, one FloatArray per row)
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val weightedSilhouette: Array<FloatArray> =
|
val reference = SilhouetteManager.getWeightedMask(input.orientation)
|
||||||
SilhouetteManager.getWeightedMask(input.orientation)
|
?: return Instruction("Silhouette missing", isValid = false)
|
||||||
?: return Instruction(
|
|
||||||
"Silhouette not found for ${input.orientation}",
|
|
||||||
isValid = false
|
|
||||||
)
|
|
||||||
|
|
||||||
val refHeight = weightedSilhouette.size
|
val refH = reference.size
|
||||||
val refWidth = weightedSilhouette[0].size
|
val refW = reference[0].size
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 2. Scale input image to match reference resolution
|
// 2. Crop only animal region (BIG WIN)
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val scaledImage = Bitmap.createScaledBitmap(
|
val cropped = Bitmap.createBitmap(
|
||||||
image,
|
image,
|
||||||
refWidth,
|
bounds.left.toInt(),
|
||||||
refHeight,
|
bounds.top.toInt(),
|
||||||
true
|
bounds.width().toInt(),
|
||||||
|
bounds.height().toInt()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val scaled = Bitmap.createScaledBitmap(cropped, refW, refH, true)
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 3. Get animal segmentation → BINARY mask (0 / 1)
|
// 3. Get binary segmentation mask (ByteArray)
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val animalMask: Array<IntArray> =
|
val mask = getAnimalMaskFast(scaled)
|
||||||
getAnimalSegmentationMask(scaledImage)
|
|
||||||
?: return Instruction("Segmentation failed", isValid = false)
|
?: return Instruction("Segmentation failed", isValid = false)
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 4. Compute signed-weight alignment score
|
// 4. Fast signed similarity
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val similarity = calculateSignedSimilarity(
|
val score = calculateSignedSimilarityFast(mask, reference)
|
||||||
animalMask,
|
|
||||||
weightedSilhouette
|
|
||||||
)
|
|
||||||
|
|
||||||
Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity")
|
|
||||||
|
|
||||||
val isValid = similarity >= 0.40f
|
|
||||||
val msg = if (isValid) {
|
|
||||||
"Pose Correct"
|
|
||||||
} else {
|
|
||||||
"Pose Incorrect (Score: %.2f)".format(similarity)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
val valid = score >= 0.40f
|
||||||
return Instruction(
|
return Instruction(
|
||||||
message = msg,
|
message = if (valid) "Pose Correct" else "Pose Incorrect (%.2f)".format(score),
|
||||||
isValid = isValid,
|
isValid = valid,
|
||||||
result = detectionResult
|
result = detection
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Converts segmentation output to a binary mask:
|
* ML Kit → Binary mask (0/1) using ByteArray
|
||||||
* 1 = animal present
|
|
||||||
* 0 = background
|
|
||||||
*/
|
*/
|
||||||
private suspend fun getAnimalSegmentationMask(
|
private suspend fun getAnimalMaskFast(bitmap: Bitmap): ByteArray? =
|
||||||
bitmap: Bitmap
|
suspendCancellableCoroutine { cont ->
|
||||||
): Array<IntArray>? {
|
|
||||||
|
|
||||||
val maskBitmap = getAnimalSegmentation(bitmap) ?: return null
|
val image = InputImage.fromBitmap(bitmap, 0)
|
||||||
|
segmenter.process(image)
|
||||||
val w = maskBitmap.width
|
|
||||||
val h = maskBitmap.height
|
|
||||||
|
|
||||||
val pixels = IntArray(w * h)
|
|
||||||
maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h)
|
|
||||||
|
|
||||||
val result = Array(h) { IntArray(w) }
|
|
||||||
|
|
||||||
for (y in 0 until h) {
|
|
||||||
for (x in 0 until w) {
|
|
||||||
val alpha = pixels[y * w + x] ushr 24
|
|
||||||
result[y][x] = if (alpha > 0) 1 else 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------
|
|
||||||
// REAL segmentation using ML Kit
|
|
||||||
// ----------------------------------------------------------------------
|
|
||||||
private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap = suspendCancellableCoroutine { continuation ->
|
|
||||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
|
||||||
segmenter.process(inputImage)
|
|
||||||
.addOnSuccessListener { result ->
|
.addOnSuccessListener { result ->
|
||||||
val mask = result.foregroundConfidenceMask
|
val mask = result.foregroundConfidenceMask
|
||||||
if (mask != null) {
|
?: run {
|
||||||
val maskBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)
|
cont.resume(null)
|
||||||
|
return@addOnSuccessListener
|
||||||
|
}
|
||||||
|
|
||||||
val buffer = mask
|
val buffer = mask
|
||||||
buffer.rewind()
|
buffer.rewind()
|
||||||
for (y in 0 until bitmap.height) {
|
|
||||||
for (x in 0 until bitmap.width) {
|
val out = ByteArray(bitmap.width * bitmap.height)
|
||||||
val confidence = buffer.get()
|
|
||||||
val alpha = (confidence * 255).toInt()
|
for (i in out.indices) {
|
||||||
maskBitmap.setPixel(x, y, Color.argb(alpha, 255, 255, 255))
|
out[i] = if (buffer.get() > 0.5f) 1 else 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cont.resume(out)
|
||||||
}
|
}
|
||||||
continuation.resume(maskBitmap)
|
.addOnFailureListener {
|
||||||
} else {
|
cont.resume(null)
|
||||||
Log.e("MockPoseAnalyzer", "Segmentation result null")
|
|
||||||
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.addOnFailureListener { e ->
|
|
||||||
Log.e("MockPoseAnalyzer", "Segmentation failed", e)
|
|
||||||
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------
|
|
||||||
// Weighted Jaccard Similarity
|
|
||||||
// mask = predicted (Bitmap)
|
|
||||||
// weightMap = ground truth silhouette weights (Bitmap)
|
|
||||||
// ----------------------------------------------------------------------
|
|
||||||
/**
|
/**
|
||||||
* Signed alignment similarity
|
* Signed similarity (flat row-major mask vs. per-row reference weights, cache-friendly)
|
||||||
* Range: [-1, 1]
|
* Range [-1, 1]
|
||||||
*/
|
*/
|
||||||
private fun calculateSignedSimilarity(
|
private fun calculateSignedSimilarityFast(
|
||||||
mask: Array<IntArray>, // 0 / 1
|
mask: ByteArray,
|
||||||
reference: Array<FloatArray> // [-1, +1]
|
reference: Array<FloatArray>
|
||||||
): Float {
|
): Float {
|
||||||
|
|
||||||
var score = 0f
|
var score = 0f
|
||||||
var maxScore = 0f
|
var maxScore = 0f
|
||||||
|
var idx = 0
|
||||||
|
|
||||||
val h = reference.size
|
for (y in reference.indices) {
|
||||||
val w = reference[0].size
|
val row = reference[y]
|
||||||
|
for (x in row.indices) {
|
||||||
for (y in 0 until h) {
|
val r = row[x]
|
||||||
for (x in 0 until w) {
|
if (r > 0f) maxScore += r
|
||||||
val r = reference[y][x]
|
score += mask[idx++] * r
|
||||||
val m = mask[y][x].toFloat()
|
|
||||||
|
|
||||||
// accumulate positive reference mass only
|
|
||||||
if (r > 0f) {
|
|
||||||
maxScore += r
|
|
||||||
}
|
|
||||||
|
|
||||||
score += m * r
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return if (maxScore == 0f) 0f else score / maxScore
|
return if (maxScore == 0f) 0f else score / maxScore
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DefaultCaptureHandler : CaptureHandler {
|
class DefaultCaptureHandler : CaptureHandler {
|
||||||
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
|
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
|
||||||
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")
|
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,9 @@ class CameraViewModel(
|
||||||
val uiState: StateFlow<CameraUiState> = _uiState.asStateFlow()
|
val uiState: StateFlow<CameraUiState> = _uiState.asStateFlow()
|
||||||
private val tilt = tiltSensorManager.tilt
|
private val tilt = tiltSensorManager.tilt
|
||||||
|
|
||||||
|
private var frameCounter = 0
|
||||||
|
private val frameSkipInterval = 5
|
||||||
|
|
||||||
init {
|
init {
|
||||||
tiltSensorManager.start()
|
tiltSensorManager.start()
|
||||||
}
|
}
|
||||||
|
|
@ -49,6 +52,9 @@ class CameraViewModel(
|
||||||
image: Bitmap,
|
image: Bitmap,
|
||||||
deviceOrientation: Int
|
deviceOrientation: Int
|
||||||
) {
|
) {
|
||||||
|
frameCounter = (frameCounter + 1) % frameSkipInterval
|
||||||
|
if (frameCounter != 0) return
|
||||||
|
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
val currentTilt = tilt.value
|
val currentTilt = tilt.value
|
||||||
val input = PipelineInput(
|
val input = PipelineInput(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue