adding additional checks in pipeline, to reduce segmentation load
This commit is contained in:
parent
9750472027
commit
986c0da505
|
|
@ -2,7 +2,6 @@ package com.example.livingai.data.camera
|
|||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.Color
|
||||
import android.graphics.RectF
|
||||
import android.util.Log
|
||||
import com.example.livingai.R
|
||||
|
|
@ -10,225 +9,200 @@ import com.example.livingai.domain.camera.*
|
|||
import com.example.livingai.domain.model.camera.*
|
||||
import com.example.livingai.utils.SignedMask
|
||||
import com.example.livingai.utils.SilhouetteManager
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation
|
||||
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions
|
||||
import kotlinx.coroutines.suspendCancellableCoroutine
|
||||
import org.tensorflow.lite.Interpreter
|
||||
import org.tensorflow.lite.support.common.FileUtil
|
||||
import java.io.IOException
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import kotlin.math.abs
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation
|
||||
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions
|
||||
import kotlinx.coroutines.suspendCancellableCoroutine
|
||||
import kotlin.coroutines.resume
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.min
|
||||
|
||||
/* ============================================================= */
|
||||
/* ORIENTATION CHECKER */
|
||||
/* ============================================================= */
|
||||
|
||||
class DefaultOrientationChecker : OrientationChecker {
|
||||
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||
val orientationLower = input.orientation.lowercase()
|
||||
val isPortraitRequired = orientationLower == "front" || orientationLower == "back"
|
||||
|
||||
// Corrected Logic:
|
||||
// 90 or 270 degrees means the device is held in PORTRAIT
|
||||
val isDevicePortrait = input.deviceOrientation == 90 || input.deviceOrientation == 270
|
||||
// 0 or 180 degrees means the device is held in LANDSCAPE
|
||||
val isDeviceLandscape = input.deviceOrientation == 0 || input.deviceOrientation == 180
|
||||
val isPortraitRequired =
|
||||
input.orientation.lowercase() == "front" ||
|
||||
input.orientation.lowercase() == "back"
|
||||
|
||||
var isValid = true
|
||||
var message = "Orientation Correct"
|
||||
val isPortrait = input.deviceOrientation == 90 || input.deviceOrientation == 270
|
||||
val isLandscape = input.deviceOrientation == 0 || input.deviceOrientation == 180
|
||||
|
||||
if (isPortraitRequired && !isDevicePortrait) {
|
||||
isValid = false
|
||||
message = "Turn to portrait mode"
|
||||
} else if (!isPortraitRequired && !isDeviceLandscape) {
|
||||
isValid = false
|
||||
message = "Turn to landscape mode"
|
||||
}
|
||||
|
||||
val animRes = if (!isValid) R.drawable.ic_launcher_foreground else null
|
||||
val valid = if (isPortraitRequired) isPortrait else isLandscape
|
||||
|
||||
return Instruction(
|
||||
message = message,
|
||||
animationResId = animRes,
|
||||
isValid = isValid,
|
||||
result = OrientationResult(input.deviceOrientation, if (isPortraitRequired) CameraOrientation.PORTRAIT else CameraOrientation.LANDSCAPE)
|
||||
message = if (valid) "Orientation Correct"
|
||||
else if (isPortraitRequired) "Turn to portrait mode"
|
||||
else "Turn to landscape mode",
|
||||
animationResId = if (valid) null else R.drawable.ic_launcher_foreground,
|
||||
isValid = valid,
|
||||
result = OrientationResult(
|
||||
input.deviceOrientation,
|
||||
if (isPortraitRequired) CameraOrientation.PORTRAIT else CameraOrientation.LANDSCAPE
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/* ============================================================= */
|
||||
/* TILT CHECKER */
|
||||
/* ============================================================= */
|
||||
|
||||
class DefaultTiltChecker : TiltChecker {
|
||||
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||
Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}")
|
||||
|
||||
val tolerance = 25.0f
|
||||
val isLevel: Boolean
|
||||
val tolerance = 25f
|
||||
|
||||
if (input.requiredOrientation == CameraOrientation.PORTRAIT) {
|
||||
// Ideal for portrait: pitch around -90, roll around 0
|
||||
val idealPitch = -90.0f
|
||||
isLevel = abs(input.devicePitch - idealPitch) <= tolerance
|
||||
} else { // LANDSCAPE
|
||||
// Ideal for landscape: pitch around 0, roll around +/-90
|
||||
val idealPitch = 0.0f
|
||||
isLevel = abs(input.devicePitch - idealPitch) <= tolerance
|
||||
val isLevel = when (input.requiredOrientation) {
|
||||
CameraOrientation.PORTRAIT ->
|
||||
abs(input.devicePitch + 90f) <= tolerance
|
||||
CameraOrientation.LANDSCAPE ->
|
||||
abs(input.devicePitch) <= tolerance
|
||||
}
|
||||
|
||||
val message = if (isLevel) "Device is level" else "Keep the phone straight"
|
||||
|
||||
return Instruction(
|
||||
message = message,
|
||||
message = if (isLevel) "Device is level" else "Keep the phone straight",
|
||||
isValid = isLevel,
|
||||
result = TiltResult(input.deviceRoll, input.devicePitch, isLevel)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/* ============================================================= */
|
||||
/* TFLITE OBJECT DETECTOR (PRIMARY + REFERENCE OBJECTS) */
|
||||
/* ============================================================= */
|
||||
|
||||
class TFLiteObjectDetector(context: Context) : ObjectDetector {
|
||||
|
||||
private var interpreter: Interpreter? = null
|
||||
private var labels: List<String> = emptyList()
|
||||
private var modelInputWidth: Int = 0
|
||||
private var modelInputHeight: Int = 0
|
||||
private var maxDetections: Int = 25
|
||||
private var inputW = 0
|
||||
private var inputH = 0
|
||||
private var maxDetections = 25
|
||||
|
||||
init {
|
||||
try {
|
||||
val modelBuffer = FileUtil.loadMappedFile(context, "efficientdet-lite0.tflite")
|
||||
interpreter = Interpreter(modelBuffer)
|
||||
interpreter = Interpreter(
|
||||
FileUtil.loadMappedFile(context, "efficientdet-lite0.tflite")
|
||||
)
|
||||
labels = FileUtil.loadLabels(context, "labels.txt")
|
||||
|
||||
val inputTensor = interpreter?.getInputTensor(0)
|
||||
val inputShape = inputTensor?.shape()
|
||||
if (inputShape != null && inputShape.size >= 3) {
|
||||
modelInputWidth = inputShape[1]
|
||||
modelInputHeight = inputShape[2]
|
||||
} else {
|
||||
Log.e("TFLiteObjectDetector", "Invalid input tensor shape.")
|
||||
}
|
||||
val inputShape = interpreter!!.getInputTensor(0).shape()
|
||||
inputW = inputShape[1]
|
||||
inputH = inputShape[2]
|
||||
|
||||
val outputTensor = interpreter?.getOutputTensor(0)
|
||||
val outputShape = outputTensor?.shape()
|
||||
if (outputShape != null && outputShape.size >= 2) {
|
||||
maxDetections = outputShape[1]
|
||||
Log.d("TFLiteObjectDetector", "Max detections from model: $maxDetections")
|
||||
}
|
||||
maxDetections = interpreter!!.getOutputTensor(0).shape()[1]
|
||||
|
||||
Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.")
|
||||
} catch (e: IOException) {
|
||||
Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e)
|
||||
Log.e("TFLiteObjectDetector", "Please ensure 'efficientdet-lite0.tflite' and 'labelmap.txt' are in the 'app/src/main/assets' directory.")
|
||||
Log.e("Detector", "Failed to load model", e)
|
||||
interpreter = null
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||
if (interpreter == null) {
|
||||
return Instruction("Object detector not initialized. Check asset files.", isValid = false)
|
||||
|
||||
val image = input.image
|
||||
?: return Instruction("Waiting for camera", isValid = false)
|
||||
|
||||
val resized = Bitmap.createScaledBitmap(image, inputW, inputH, true)
|
||||
val buffer = bitmapToBuffer(resized)
|
||||
|
||||
val locations = Array(1) { Array(maxDetections) { FloatArray(4) } }
|
||||
val classes = Array(1) { FloatArray(maxDetections) }
|
||||
val scores = Array(1) { FloatArray(maxDetections) }
|
||||
val count = FloatArray(1)
|
||||
|
||||
interpreter?.runForMultipleInputsOutputs(
|
||||
arrayOf(buffer),
|
||||
mapOf(0 to locations, 1 to classes, 2 to scores, 3 to count)
|
||||
)
|
||||
|
||||
val detections = mutableListOf<Detection>()
|
||||
|
||||
for (i in 0 until count[0].toInt()) {
|
||||
if (scores[0][i] < 0.5f) continue
|
||||
|
||||
val label = labels.getOrElse(classes[0][i].toInt()) { "Unknown" }
|
||||
val b = locations[0][i]
|
||||
|
||||
detections += Detection(
|
||||
label,
|
||||
scores[0][i],
|
||||
RectF(
|
||||
b[1] * image.width,
|
||||
b[0] * image.height,
|
||||
b[3] * image.width,
|
||||
b[2] * image.height
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
val image = input.image ?: return Instruction("Waiting for camera...", isValid = false)
|
||||
val primary = detections
|
||||
.filter { it.label.equals(input.targetAnimal, true) }
|
||||
.maxByOrNull { it.confidence }
|
||||
|
||||
val resizedBitmap = Bitmap.createScaledBitmap(image, modelInputWidth, modelInputHeight, true)
|
||||
val byteBuffer = convertBitmapToByteBuffer(resizedBitmap)
|
||||
|
||||
// Define model outputs with the correct size
|
||||
val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } }
|
||||
val outputClasses = Array(1) { FloatArray(maxDetections) }
|
||||
val outputScores = Array(1) { FloatArray(maxDetections) }
|
||||
val numDetections = FloatArray(1)
|
||||
|
||||
val outputs: MutableMap<Int, Any> = HashMap()
|
||||
outputs[0] = outputLocations
|
||||
outputs[1] = outputClasses
|
||||
outputs[2] = outputScores
|
||||
outputs[3] = numDetections
|
||||
|
||||
interpreter?.runForMultipleInputsOutputs(arrayOf(byteBuffer), outputs)
|
||||
|
||||
val detectedObjects = mutableListOf<Detection>()
|
||||
val detectionCount = numDetections[0].toInt()
|
||||
|
||||
for (i in 0 until detectionCount) {
|
||||
val score = outputScores[0][i]
|
||||
if (score > 0.5f) { // Confidence threshold
|
||||
val classIndex = outputClasses[0][i].toInt()
|
||||
val label = labels.getOrElse(classIndex) { "Unknown" }
|
||||
|
||||
val location = outputLocations[0][i]
|
||||
// TF Lite model returns ymin, xmin, ymax, xmax in normalized coordinates
|
||||
val ymin = location[0] * image.height
|
||||
val xmin = location[1] * image.width
|
||||
val ymax = location[2] * image.height
|
||||
val xmax = location[3] * image.width
|
||||
|
||||
val boundingBox = RectF(xmin, ymin, xmax, ymax)
|
||||
detectedObjects.add(Detection(label, score, boundingBox))
|
||||
}
|
||||
}
|
||||
|
||||
val targetAnimalDetected = detectedObjects.find { it.label.equals(input.targetAnimal, ignoreCase = true) }
|
||||
val isValid = targetAnimalDetected != null
|
||||
|
||||
val message = if (isValid) {
|
||||
"${input.targetAnimal} Detected"
|
||||
} else {
|
||||
if (detectedObjects.isEmpty()) "No objects detected" else "Animal not detected, move closer or point camera to the animal"
|
||||
}
|
||||
|
||||
val refObjects = detectedObjects
|
||||
.filter { it !== targetAnimalDetected }
|
||||
.mapIndexed { index, detection ->
|
||||
val refs = detections
|
||||
.filter { it !== primary }
|
||||
.mapIndexed { i, d ->
|
||||
ReferenceObject(
|
||||
id = "ref_$index",
|
||||
label = detection.label,
|
||||
bounds = detection.bounds,
|
||||
relativeHeight = detection.bounds.height() / image.height,
|
||||
relativeWidth = detection.bounds.width() / image.width,
|
||||
distance = 1.0f // Placeholder
|
||||
id = "ref_$i",
|
||||
label = d.label,
|
||||
bounds = d.bounds,
|
||||
relativeHeight = d.bounds.height() / image.height,
|
||||
relativeWidth = d.bounds.width() / image.width,
|
||||
distance = null
|
||||
)
|
||||
}
|
||||
|
||||
return Instruction(
|
||||
message = message,
|
||||
isValid = isValid,
|
||||
message = if (primary != null) "Cow detected" else "Cow not detected",
|
||||
isValid = primary != null,
|
||||
result = DetectionResult(
|
||||
isAnimalDetected = isValid,
|
||||
animalBounds = targetAnimalDetected?.bounds,
|
||||
referenceObjects = refObjects,
|
||||
label = targetAnimalDetected?.label,
|
||||
confidence = targetAnimalDetected?.confidence ?: 0f
|
||||
isAnimalDetected = primary != null,
|
||||
animalBounds = primary?.bounds,
|
||||
referenceObjects = refs,
|
||||
label = primary?.label,
|
||||
confidence = primary?.confidence ?: 0f
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
|
||||
val byteBuffer = ByteBuffer.allocateDirect(1 * modelInputWidth * modelInputHeight * 3)
|
||||
byteBuffer.order(ByteOrder.nativeOrder())
|
||||
val intValues = IntArray(modelInputWidth * modelInputHeight)
|
||||
bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
|
||||
var pixel = 0
|
||||
for (i in 0 until modelInputWidth) {
|
||||
for (j in 0 until modelInputHeight) {
|
||||
val `val` = intValues[pixel++]
|
||||
// Assuming model expects UINT8 [0, 255]
|
||||
byteBuffer.put(((`val` shr 16) and 0xFF).toByte())
|
||||
byteBuffer.put(((`val` shr 8) and 0xFF).toByte())
|
||||
byteBuffer.put((`val` and 0xFF).toByte())
|
||||
private fun bitmapToBuffer(bitmap: Bitmap): ByteBuffer {
|
||||
val buffer = ByteBuffer.allocateDirect(inputW * inputH * 3)
|
||||
buffer.order(ByteOrder.nativeOrder())
|
||||
val pixels = IntArray(inputW * inputH)
|
||||
bitmap.getPixels(pixels, 0, inputW, 0, 0, inputW, inputH)
|
||||
for (p in pixels) {
|
||||
buffer.put(((p shr 16) and 0xFF).toByte())
|
||||
buffer.put(((p shr 8) and 0xFF).toByte())
|
||||
buffer.put((p and 0xFF).toByte())
|
||||
}
|
||||
}
|
||||
return byteBuffer
|
||||
return buffer
|
||||
}
|
||||
|
||||
data class Detection(val label: String, val confidence: Float, val bounds: RectF)
|
||||
}
|
||||
|
||||
/* ============================================================= */
|
||||
/* POSE ANALYZER (ALIGNMENT → CROP → SEGMENT) */
|
||||
/* ============================================================= */
|
||||
|
||||
class MockPoseAnalyzer : PoseAnalyzer {
|
||||
|
||||
private val segmenter by lazy {
|
||||
val options = SubjectSegmenterOptions.Builder()
|
||||
SubjectSegmentation.getClient(
|
||||
SubjectSegmenterOptions.Builder()
|
||||
.enableForegroundConfidenceMask()
|
||||
.build()
|
||||
SubjectSegmentation.getClient(options)
|
||||
)
|
||||
}
|
||||
|
||||
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||
|
|
@ -236,130 +210,124 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
|||
val detection = input.previousDetectionResult
|
||||
?: return Instruction("No detection", isValid = false)
|
||||
|
||||
val bounds = detection.animalBounds
|
||||
?: return Instruction("Animal not detected", isValid = false)
|
||||
val cowBox = detection.animalBounds
|
||||
?: return Instruction("Cow not detected", isValid = false)
|
||||
|
||||
val image = input.image
|
||||
?: return Instruction("No image", isValid = false)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 1. Reference silhouette (FloatArray)
|
||||
// --------------------------------------------------------------------
|
||||
val reference = SilhouetteManager.getWeightedMask(input.orientation)
|
||||
val silhouette = SilhouetteManager.getSilhouette(input.orientation)
|
||||
?: return Instruction("Silhouette missing", isValid = false)
|
||||
|
||||
val refH = reference.mask.size
|
||||
val refW = reference.mask[0].size
|
||||
val align = checkAlignment(cowBox, silhouette.boundingBox, 0.15f)
|
||||
if (align.issue != AlignmentIssue.OK) {
|
||||
return alignmentToInstruction(align)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 2. Crop only animal region (BIG WIN)
|
||||
// --------------------------------------------------------------------
|
||||
val cropped = Bitmap.createBitmap(
|
||||
image,
|
||||
bounds.left.toInt(),
|
||||
bounds.top.toInt(),
|
||||
bounds.width().toInt(),
|
||||
bounds.height().toInt()
|
||||
cowBox.left.toInt(),
|
||||
cowBox.top.toInt(),
|
||||
cowBox.width().toInt(),
|
||||
cowBox.height().toInt()
|
||||
)
|
||||
|
||||
val scaled = Bitmap.createScaledBitmap(cropped, refW, refH, true)
|
||||
val resized = Bitmap.createScaledBitmap(
|
||||
cropped,
|
||||
silhouette.croppedBitmap.width,
|
||||
silhouette.croppedBitmap.height,
|
||||
true
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 3. Get binary segmentation mask (ByteArray)
|
||||
// --------------------------------------------------------------------
|
||||
val mask = getAnimalMaskFast(scaled)
|
||||
val mask = segment(resized)
|
||||
?: return Instruction("Segmentation failed", isValid = false)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 4. Fast signed similarity
|
||||
// --------------------------------------------------------------------
|
||||
val score = calculateSignedSimilarityFast(mask, reference)
|
||||
|
||||
val score = similarity(mask, silhouette.signedMask)
|
||||
val valid = score >= 0.40f
|
||||
|
||||
return Instruction(
|
||||
message = if (valid) "Pose Correct" else "Pose Incorrect (%.2f)".format(score),
|
||||
message = if (valid) "Pose Correct" else "Adjust Position",
|
||||
isValid = valid,
|
||||
result = detection
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* ML Kit → Binary mask (0/1) using ByteArray
|
||||
*/
|
||||
private suspend fun getAnimalMaskFast(bitmap: Bitmap): ByteArray? =
|
||||
private suspend fun segment(bitmap: Bitmap): ByteArray? =
|
||||
suspendCancellableCoroutine { cont ->
|
||||
|
||||
val image = InputImage.fromBitmap(bitmap, 0)
|
||||
segmenter.process(image)
|
||||
.addOnSuccessListener { result ->
|
||||
val mask = result.foregroundConfidenceMask
|
||||
?: run {
|
||||
cont.resume(null)
|
||||
return@addOnSuccessListener
|
||||
}
|
||||
|
||||
val buffer = mask
|
||||
buffer.rewind()
|
||||
|
||||
segmenter.process(InputImage.fromBitmap(bitmap, 0))
|
||||
.addOnSuccessListener { r ->
|
||||
val buf = r.foregroundConfidenceMask
|
||||
?: return@addOnSuccessListener cont.resume(null)
|
||||
buf.rewind()
|
||||
val out = ByteArray(bitmap.width * bitmap.height)
|
||||
|
||||
for (i in out.indices) {
|
||||
out[i] = if (buffer.get() > 0.5f) 1 else 0
|
||||
}
|
||||
|
||||
for (i in out.indices) out[i] = if (buf.get() > 0.5f) 1 else 0
|
||||
cont.resume(out)
|
||||
}
|
||||
.addOnFailureListener {
|
||||
cont.resume(null)
|
||||
.addOnFailureListener { cont.resume(null) }
|
||||
}
|
||||
|
||||
/**
 * Scores a binary segmentation mask against a signed reference mask.
 *
 * Every cell of [ref]'s 2-D weight grid is multiplied by the corresponding entry of
 * [mask] and the products are summed; the sum is then divided by `ref.maxValue`.
 * [mask] is read sequentially, one entry per grid cell, row after row, so it must
 * contain at least as many entries as the grid has cells.
 *
 * @param mask flat per-pixel mask values (the caller passes 0/1 bytes)
 * @param ref signed weights plus the normalisation value
 * @return the normalised weighted sum, or `0f` when `ref.maxValue` is exactly zero
 */
private fun similarity(mask: ByteArray, ref: SignedMask): Float {
    var weightedSum = 0f
    var flatIndex = 0

    ref.mask.forEach { weights ->
        for (column in weights.indices) {
            weightedSum += mask[flatIndex] * weights[column]
            flatIndex++
        }
    }

    val normaliser = ref.maxValue
    // Avoid dividing by zero when the reference carries no positive weight.
    if (normaliser == 0f) return 0f
    return weightedSum / normaliser
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Signed similarity (flat arrays, cache-friendly)
|
||||
* Range [-1, 1]
|
||||
*/
|
||||
private fun calculateSignedSimilarityFast(
|
||||
mask: ByteArray,
|
||||
reference: SignedMask
|
||||
): Float {
|
||||
/* ============================================================= */
|
||||
/* ALIGNMENT HELPERS */
|
||||
/* ============================================================= */
|
||||
|
||||
var score = 0f
|
||||
val maxScore = reference.maxValue
|
||||
var idx = 0
|
||||
/**
 * Outcome categories produced by [checkAlignment] when a detected box is compared with a
 * silhouette box: two scale failures, four centre-offset failures, and [OK] when neither
 * the scale nor the centre offset exceeds the tolerance.
 */
enum class AlignmentIssue { TOO_SMALL, TOO_LARGE, MOVE_LEFT, MOVE_RIGHT, MOVE_UP, MOVE_DOWN, OK }
|
||||
|
||||
for (y in reference.mask.indices) {
|
||||
val row = reference.mask[y]
|
||||
for (x in row.indices) {
|
||||
val r = row[x]
|
||||
score += mask[idx++] * r
|
||||
/**
 * Value returned by [checkAlignment].
 *
 * As populated by [checkAlignment]:
 * @property issue the first check that failed, or [AlignmentIssue.OK]
 * @property scale the smaller of the detected/silhouette width ratio and height ratio
 * @property dx detected centre X minus silhouette centre X
 * @property dy detected centre Y minus silhouette centre Y
 */
data class AlignmentResult(val issue: AlignmentIssue, val scale: Float, val dx: Float, val dy: Float)
|
||||
|
||||
/**
 * Compares a detected box with a silhouette box and reports the first misalignment found.
 *
 * Checks run in a fixed order: scale too small, scale too large, horizontal offset,
 * vertical offset. The scale is the smaller of the width ratio and the height ratio
 * (detected / silhouette). Offsets are measured between the two box centres and are
 * allowed to deviate by `tol` times the silhouette's width (horizontal) or height
 * (vertical).
 *
 * NOTE(review): both boxes are compared directly, so they need to be expressed in the
 * same coordinate space — confirm at the call site. A zero-sized [s] makes the ratios
 * non-finite.
 *
 * @param d detected bounding box
 * @param s silhouette bounding box used as the target
 * @param tol relative tolerance applied to both the scale and the centre offsets
 * @return the classification together with the measured scale and centre offsets
 */
fun checkAlignment(d: RectF, s: RectF, tol: Float): AlignmentResult {
    val widthRatio = d.width() / s.width()
    val heightRatio = d.height() / s.height()
    val scale = min(widthRatio, heightRatio)

    val dx = d.centerX() - s.centerX()
    val dy = d.centerY() - s.centerY()

    val maxShiftX = s.width() * tol
    val maxShiftY = s.height() * tol

    // Branch order matters: scale problems take precedence over offsets,
    // and horizontal offsets take precedence over vertical ones.
    val issue = when {
        scale < 1f - tol -> AlignmentIssue.TOO_SMALL
        scale > 1f + tol -> AlignmentIssue.TOO_LARGE
        dx > maxShiftX -> AlignmentIssue.MOVE_LEFT
        dx < -maxShiftX -> AlignmentIssue.MOVE_RIGHT
        dy > maxShiftY -> AlignmentIssue.MOVE_UP
        dy < -maxShiftY -> AlignmentIssue.MOVE_DOWN
        else -> AlignmentIssue.OK
    }

    return AlignmentResult(issue, scale, dx, dy)
}
|
||||
|
||||
return if (maxScore == 0f) 0f else score / maxScore
|
||||
}
|
||||
/**
 * Converts an alignment classification into the [Instruction] shown to the user.
 *
 * Each non-OK issue maps to a fixed prompt with `isValid = false`;
 * [AlignmentIssue.OK] maps to "Hold steady" with `isValid = true`.
 *
 * @param a result of an alignment check; only its `issue` is consulted
 * @return the instruction carrying the prompt text and validity flag
 */
fun alignmentToInstruction(a: AlignmentResult): Instruction {
    val prompt = when (a.issue) {
        AlignmentIssue.TOO_SMALL -> "Move closer"
        AlignmentIssue.TOO_LARGE -> "Move backward"
        AlignmentIssue.MOVE_LEFT -> "Move right"
        AlignmentIssue.MOVE_RIGHT -> "Move left"
        AlignmentIssue.MOVE_UP -> "Move down"
        AlignmentIssue.MOVE_DOWN -> "Move up"
        AlignmentIssue.OK -> "Hold steady"
    }
    val aligned = a.issue == AlignmentIssue.OK
    return Instruction(prompt, isValid = aligned)
}
|
||||
|
||||
/* ============================================================= */
|
||||
/* CAPTURE + MEASUREMENT (UNCHANGED) */
|
||||
/* ============================================================= */
|
||||
|
||||
class DefaultCaptureHandler : CaptureHandler {
|
||||
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
|
||||
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")
|
||||
|
||||
val segmentationMask = BooleanArray(100) { true }
|
||||
|
||||
val animalMetrics = ObjectMetrics(
|
||||
relativeHeight = 0.5f,
|
||||
relativeWidth = 0.3f,
|
||||
distance = 1.2f
|
||||
)
|
||||
|
||||
return CaptureData(
|
||||
image = image,
|
||||
segmentationMask = segmentationMask,
|
||||
animalMetrics = animalMetrics,
|
||||
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData =
|
||||
CaptureData(
|
||||
image = input.image!!,
|
||||
segmentationMask = BooleanArray(0),
|
||||
animalMetrics = ObjectMetrics(0f, 0f, 1f),
|
||||
referenceObjects = detectionResult.referenceObjects
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class DefaultMeasurementCalculator : MeasurementCalculator {
|
||||
override fun calculateRealMetrics(
|
||||
|
|
@ -367,18 +335,16 @@ class DefaultMeasurementCalculator : MeasurementCalculator {
|
|||
referenceObject: ReferenceObject,
|
||||
currentMetrics: ObjectMetrics
|
||||
): RealWorldMetrics {
|
||||
if (referenceObject.relativeHeight == 0f) return RealWorldMetrics(0f, 0f, 0f)
|
||||
|
||||
if (referenceObject.relativeHeight == 0f)
|
||||
return RealWorldMetrics(0f, 0f, 0f)
|
||||
|
||||
val scale = targetHeight / referenceObject.relativeHeight
|
||||
|
||||
val realHeight = currentMetrics.relativeHeight * scale
|
||||
val realWidth = currentMetrics.relativeWidth * scale
|
||||
val realDistance = currentMetrics.distance
|
||||
|
||||
return RealWorldMetrics(
|
||||
height = realHeight,
|
||||
width = realWidth,
|
||||
distance = realDistance
|
||||
height = currentMetrics.relativeHeight * scale,
|
||||
width = currentMetrics.relativeWidth * scale,
|
||||
distance = currentMetrics.distance
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ data class ReferenceObject(
|
|||
val bounds: RectF,
|
||||
val relativeHeight: Float,
|
||||
val relativeWidth: Float,
|
||||
val distance: Float
|
||||
val distance: Float? = null
|
||||
)
|
||||
|
||||
enum class CameraOrientation {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
package com.example.livingai.utils
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.BitmapFactory
|
||||
import android.graphics.Color
|
||||
import android.graphics.*
|
||||
import android.util.Log
|
||||
import com.example.livingai.R
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
|
@ -14,20 +12,25 @@ data class SignedMask(
|
|||
val maxValue: Float
|
||||
)
|
||||
|
||||
/**
 * Pre-computed artefacts for one named silhouette, as stored by `SilhouetteManager`.
 *
 * @property croppedBitmap the fitted silhouette bitmap cropped to [boundingBox]
 * @property boundingBox bounding box of the silhouette within the fitted bitmap
 * @property signedMask signed weighted mask built from [croppedBitmap]
 */
data class SilhouetteData(
    val croppedBitmap: Bitmap,
    val boundingBox: RectF,
    val signedMask: SignedMask
)
|
||||
|
||||
object SilhouetteManager {
|
||||
|
||||
private val originals = ConcurrentHashMap<String, Bitmap>()
|
||||
private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
|
||||
private val weightedMasks = ConcurrentHashMap<String, SignedMask>()
|
||||
private val silhouettes = ConcurrentHashMap<String, SilhouetteData>()
|
||||
|
||||
fun getOriginal(name: String): Bitmap? = originals[name]
|
||||
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
|
||||
fun getWeightedMask(name: String): SignedMask? = weightedMasks[name]
|
||||
fun getSilhouette(name: String): SilhouetteData? = silhouettes[name]
|
||||
|
||||
fun initialize(context: Context, width: Int, height: Int) {
|
||||
val resources = context.resources
|
||||
fun initialize(context: Context, screenW: Int, screenH: Int) {
|
||||
|
||||
val silhouetteList = mapOf(
|
||||
val res = context.resources
|
||||
|
||||
val map = mapOf(
|
||||
"front" to R.drawable.front_silhouette,
|
||||
"back" to R.drawable.back_silhouette,
|
||||
"left" to R.drawable.left_silhouette,
|
||||
|
|
@ -37,58 +40,94 @@ object SilhouetteManager {
|
|||
"angleview" to R.drawable.angleview_silhouette
|
||||
)
|
||||
|
||||
silhouetteList.forEach { (name, resId) ->
|
||||
val bmp = BitmapFactory.decodeResource(resources, resId)
|
||||
originals[name] = bmp
|
||||
map.forEach { (name, resId) ->
|
||||
|
||||
// Fit image appropriately (front/back = W/H, others rotated)
|
||||
val fitted = if (name == "front" || name == "back")
|
||||
createInvertedPurpleBitmap(bmp, width, height)
|
||||
else
|
||||
createInvertedPurpleBitmap(bmp, height, width)
|
||||
val src = BitmapFactory.decodeResource(res, resId)
|
||||
originals[name] = src
|
||||
|
||||
invertedPurple[name] = fitted
|
||||
val fitted = Bitmap.createScaledBitmap(
|
||||
invertToPurple(src),
|
||||
screenW,
|
||||
screenH,
|
||||
true
|
||||
)
|
||||
|
||||
weightedMasks[name] = createSignedWeightedMask(fitted, fadeInside = 10, fadeOutside = 20)
|
||||
val bbox = computeBoundingBox(fitted)
|
||||
|
||||
Log.d("Silhouette", "Loaded mask: $name (${fitted.width} x ${fitted.height})")
|
||||
val cropped = Bitmap.createBitmap(
|
||||
fitted,
|
||||
bbox.left.toInt(),
|
||||
bbox.top.toInt(),
|
||||
bbox.width().toInt(),
|
||||
bbox.height().toInt()
|
||||
)
|
||||
|
||||
val signedMask = createSignedWeightedMask(cropped)
|
||||
|
||||
silhouettes[name] = SilhouetteData(
|
||||
croppedBitmap = cropped,
|
||||
boundingBox = bbox,
|
||||
signedMask = signedMask
|
||||
)
|
||||
|
||||
Log.d("Silhouette", "Loaded $name (${bbox.width()} x ${bbox.height()})")
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
// STEP 1: Create "inverted purple" mask (transparent object becomes purple)
|
||||
// ------------------------------------------------------------------------
|
||||
|
||||
private fun createInvertedPurpleBitmap(
|
||||
src: Bitmap,
|
||||
targetWidth: Int,
|
||||
targetHeight: Int
|
||||
): Bitmap {
|
||||
/* ---------------------------------------------------------- */
|
||||
|
||||
private fun invertToPurple(src: Bitmap): Bitmap {
|
||||
val w = src.width
|
||||
val h = src.height
|
||||
|
||||
val pixels = IntArray(w * h)
|
||||
src.getPixels(pixels, 0, w, 0, 0, w, h)
|
||||
|
||||
val purple = Color.argb(255, 128, 0, 128)
|
||||
|
||||
for (i in pixels.indices) {
|
||||
val alpha = pixels[i] ushr 24
|
||||
pixels[i] = if (alpha == 0) purple else 0x00000000
|
||||
pixels[i] =
|
||||
if ((pixels[i] ushr 24) == 0) purple
|
||||
else 0x00000000
|
||||
}
|
||||
|
||||
val inverted = Bitmap.createBitmap(pixels, w, h, Bitmap.Config.ARGB_8888)
|
||||
return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true)
|
||||
return Bitmap.createBitmap(pixels, w, h, Bitmap.Config.ARGB_8888)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a signed weighted mask in range [-1, +1]
|
||||
*
|
||||
* +1 : deep inside object
|
||||
* 0 : object boundary
|
||||
* -1 : far outside object
|
||||
*/
|
||||
/**
 * Computes the tight bounding box of all pixels in [bitmap] whose alpha is non-zero.
 *
 * The returned rectangle uses exclusive right/bottom edges, so `width()` and `height()`
 * equal the number of covered columns and rows and the box can be passed straight to a
 * crop. (Previously the last opaque column/row index was used as the edge, which made the
 * box one pixel short and produced a zero-sized box for a one-pixel-wide silhouette.)
 *
 * If the bitmap contains no pixel with non-zero alpha, the full bitmap bounds are
 * returned instead of an inverted rectangle, so callers never receive a negative size.
 *
 * @param bitmap image whose alpha channel marks the silhouette
 * @return bounding box in the bitmap's pixel coordinates
 */
private fun computeBoundingBox(bitmap: Bitmap): RectF {
    val w = bitmap.width
    val h = bitmap.height
    val pixels = IntArray(w * h)
    bitmap.getPixels(pixels, 0, w, 0, 0, w, h)

    var minX = w
    var minY = h
    var maxX = -1 // -1 doubles as the "nothing found" marker
    var maxY = -1

    for (y in 0 until h) {
        val rowStart = y * w
        for (x in 0 until w) {
            // The top byte of an ARGB pixel is its alpha.
            if ((pixels[rowStart + x] ushr 24) != 0) {
                if (x < minX) minX = x
                if (x > maxX) maxX = x
                if (y < minY) minY = y
                if (y > maxY) maxY = y
            }
        }
    }

    // Fully transparent input: fall back to the whole bitmap rather than an inverted box.
    if (maxX < 0) {
        return RectF(0f, 0f, w.toFloat(), h.toFloat())
    }

    // +1 turns the last opaque index into an exclusive edge.
    return RectF(
        minX.toFloat(),
        minY.toFloat(),
        (maxX + 1).toFloat(),
        (maxY + 1).toFloat()
    )
}
|
||||
|
||||
/* ---------------------------------------------------------- */
|
||||
/* SIGNED WEIGHTED MASK */
|
||||
/* ---------------------------------------------------------- */
|
||||
|
||||
fun createSignedWeightedMask(
|
||||
bitmap: Bitmap,
|
||||
fadeInside: Int = 10,
|
||||
|
|
@ -101,96 +140,54 @@ object SilhouetteManager {
|
|||
val pixels = IntArray(w * h)
|
||||
bitmap.getPixels(pixels, 0, w, 0, 0, w, h)
|
||||
|
||||
val inside = IntArray(w * h)
|
||||
for (i in pixels.indices)
|
||||
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
|
||||
|
||||
fun idx(x: Int, y: Int) = y * w + x
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 1. Binary mask
|
||||
// --------------------------------------------------------------------
|
||||
val inside = IntArray(w * h)
|
||||
for (i in pixels.indices) {
|
||||
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
|
||||
val distIn = IntArray(w * h) { Int.MAX_VALUE }
|
||||
val distOut = IntArray(w * h) { Int.MAX_VALUE }
|
||||
|
||||
for (i in inside.indices) {
|
||||
if (inside[i] == 0) distIn[i] = 0
|
||||
else distOut[i] = 0
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 2. Distance transform (inside → outside)
|
||||
// --------------------------------------------------------------------
|
||||
val distInside = IntArray(w * h) { Int.MAX_VALUE }
|
||||
for (i in inside.indices) if (inside[i] == 0) distInside[i] = 0
|
||||
|
||||
for (y in 0 until h) {
|
||||
for (y in 0 until h)
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
var best = distInside[i]
|
||||
if (x > 0) best = min(best, distInside[idx(x - 1, y)] + 1)
|
||||
if (y > 0) best = min(best, distInside[idx(x, y - 1)] + 1)
|
||||
distInside[i] = best
|
||||
}
|
||||
if (x > 0) distIn[i] = min(distIn[i], distIn[idx(x - 1, y)] + 1)
|
||||
if (y > 0) distIn[i] = min(distIn[i], distIn[idx(x, y - 1)] + 1)
|
||||
if (x > 0) distOut[i] = min(distOut[i], distOut[idx(x - 1, y)] + 1)
|
||||
if (y > 0) distOut[i] = min(distOut[i], distOut[idx(x, y - 1)] + 1)
|
||||
}
|
||||
|
||||
for (y in h - 1 downTo 0) {
|
||||
for (y in h - 1 downTo 0)
|
||||
for (x in w - 1 downTo 0) {
|
||||
val i = idx(x, y)
|
||||
var best = distInside[i]
|
||||
if (x < w - 1) best = min(best, distInside[idx(x + 1, y)] + 1)
|
||||
if (y < h - 1) best = min(best, distInside[idx(x, y + 1)] + 1)
|
||||
distInside[i] = best
|
||||
}
|
||||
if (x < w - 1) distIn[i] = min(distIn[i], distIn[idx(x + 1, y)] + 1)
|
||||
if (y < h - 1) distIn[i] = min(distIn[i], distIn[idx(x, y + 1)] + 1)
|
||||
if (x < w - 1) distOut[i] = min(distOut[i], distOut[idx(x + 1, y)] + 1)
|
||||
if (y < h - 1) distOut[i] = min(distOut[i], distOut[idx(x, y + 1)] + 1)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 3. Distance transform (outside → inside)
|
||||
// --------------------------------------------------------------------
|
||||
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
|
||||
for (i in inside.indices) if (inside[i] == 1) distOutside[i] = 0
|
||||
val mask = Array(h) { FloatArray(w) }
|
||||
var maxVal = Float.NEGATIVE_INFINITY
|
||||
|
||||
for (y in 0 until h) {
|
||||
for (y in 0 until h)
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
var best = distOutside[i]
|
||||
if (x > 0) best = min(best, distOutside[idx(x - 1, y)] + 1)
|
||||
if (y > 0) best = min(best, distOutside[idx(x, y - 1)] + 1)
|
||||
distOutside[i] = best
|
||||
}
|
||||
val v =
|
||||
if (inside[i] == 1)
|
||||
min(1f, distIn[i].toFloat() / fadeInside)
|
||||
else
|
||||
maxOf(-1f, -distOut[i].toFloat() / fadeOutside)
|
||||
|
||||
mask[y][x] = v
|
||||
if (v > maxVal) maxVal = v
|
||||
}
|
||||
|
||||
for (y in h - 1 downTo 0) {
|
||||
for (x in w - 1 downTo 0) {
|
||||
val i = idx(x, y)
|
||||
var best = distOutside[i]
|
||||
if (x < w - 1) best = min(best, distOutside[idx(x + 1, y)] + 1)
|
||||
if (y < h - 1) best = min(best, distOutside[idx(x, y + 1)] + 1)
|
||||
distOutside[i] = best
|
||||
return SignedMask(mask, maxVal)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 4. Build signed mask + track max value
|
||||
// --------------------------------------------------------------------
|
||||
val result = Array(h) { FloatArray(w) }
|
||||
var maxValue = Float.NEGATIVE_INFINITY
|
||||
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
|
||||
val weight = if (inside[i] == 1) {
|
||||
val d = distInside[i]
|
||||
if (d >= fadeInside) 1f
|
||||
else d.toFloat() / fadeInside
|
||||
} else {
|
||||
val d = distOutside[i]
|
||||
(-d.toFloat() / fadeOutside).coerceAtLeast(-1f)
|
||||
}
|
||||
|
||||
result[y][x] = weight
|
||||
if (weight > maxValue) maxValue = weight
|
||||
}
|
||||
}
|
||||
|
||||
return SignedMask(
|
||||
mask = result,
|
||||
maxValue = maxValue
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue