Add additional checks to the pipeline to reduce segmentation load

This commit is contained in:
SaiD 2025-12-13 13:38:09 +05:30
parent 9750472027
commit 986c0da505
4 changed files with 332 additions and 369 deletions

View File

@@ -2,7 +2,6 @@ package com.example.livingai.data.camera
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Color
import android.graphics.RectF
import android.util.Log
import com.example.livingai.R
@@ -10,225 +9,200 @@ import com.example.livingai.domain.camera.*
import com.example.livingai.domain.model.camera.*
import com.example.livingai.utils.SignedMask
import com.example.livingai.utils.SilhouetteManager
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions
import kotlinx.coroutines.suspendCancellableCoroutine
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.math.abs
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlin.coroutines.resume
import kotlin.math.abs
import kotlin.math.min
/* ============================================================= */
/* ORIENTATION CHECKER */
/* ============================================================= */
class DefaultOrientationChecker : OrientationChecker {
override suspend fun analyze(input: PipelineInput): Instruction {
val orientationLower = input.orientation.lowercase()
val isPortraitRequired = orientationLower == "front" || orientationLower == "back"
// Corrected Logic:
// 90 or 270 degrees means the device is held in PORTRAIT
val isDevicePortrait = input.deviceOrientation == 90 || input.deviceOrientation == 270
// 0 or 180 degrees means the device is held in LANDSCAPE
val isDeviceLandscape = input.deviceOrientation == 0 || input.deviceOrientation == 180
val isPortraitRequired =
input.orientation.lowercase() == "front" ||
input.orientation.lowercase() == "back"
var isValid = true
var message = "Orientation Correct"
val isPortrait = input.deviceOrientation == 90 || input.deviceOrientation == 270
val isLandscape = input.deviceOrientation == 0 || input.deviceOrientation == 180
if (isPortraitRequired && !isDevicePortrait) {
isValid = false
message = "Turn to portrait mode"
} else if (!isPortraitRequired && !isDeviceLandscape) {
isValid = false
message = "Turn to landscape mode"
}
val animRes = if (!isValid) R.drawable.ic_launcher_foreground else null
val valid = if (isPortraitRequired) isPortrait else isLandscape
return Instruction(
message = message,
animationResId = animRes,
isValid = isValid,
result = OrientationResult(input.deviceOrientation, if (isPortraitRequired) CameraOrientation.PORTRAIT else CameraOrientation.LANDSCAPE)
message = if (valid) "Orientation Correct"
else if (isPortraitRequired) "Turn to portrait mode"
else "Turn to landscape mode",
animationResId = if (valid) null else R.drawable.ic_launcher_foreground,
isValid = valid,
result = OrientationResult(
input.deviceOrientation,
if (isPortraitRequired) CameraOrientation.PORTRAIT else CameraOrientation.LANDSCAPE
)
)
}
}
/* ============================================================= */
/* TILT CHECKER */
/* ============================================================= */
class DefaultTiltChecker : TiltChecker {
override suspend fun analyze(input: PipelineInput): Instruction {
Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}")
val tolerance = 25.0f
val isLevel: Boolean
val tolerance = 25f
if (input.requiredOrientation == CameraOrientation.PORTRAIT) {
// Ideal for portrait: pitch around -90, roll around 0
val idealPitch = -90.0f
isLevel = abs(input.devicePitch - idealPitch) <= tolerance
} else { // LANDSCAPE
// Ideal for landscape: pitch around 0, roll around +/-90
val idealPitch = 0.0f
isLevel = abs(input.devicePitch - idealPitch) <= tolerance
val isLevel = when (input.requiredOrientation) {
CameraOrientation.PORTRAIT ->
abs(input.devicePitch + 90f) <= tolerance
CameraOrientation.LANDSCAPE ->
abs(input.devicePitch) <= tolerance
}
val message = if (isLevel) "Device is level" else "Keep the phone straight"
return Instruction(
message = message,
message = if (isLevel) "Device is level" else "Keep the phone straight",
isValid = isLevel,
result = TiltResult(input.deviceRoll, input.devicePitch, isLevel)
)
}
}
/* ============================================================= */
/* TFLITE OBJECT DETECTOR (PRIMARY + REFERENCE OBJECTS) */
/* ============================================================= */
class TFLiteObjectDetector(context: Context) : ObjectDetector {
private var interpreter: Interpreter? = null
private var labels: List<String> = emptyList()
private var modelInputWidth: Int = 0
private var modelInputHeight: Int = 0
private var maxDetections: Int = 25
private var inputW = 0
private var inputH = 0
private var maxDetections = 25
init {
try {
val modelBuffer = FileUtil.loadMappedFile(context, "efficientdet-lite0.tflite")
interpreter = Interpreter(modelBuffer)
interpreter = Interpreter(
FileUtil.loadMappedFile(context, "efficientdet-lite0.tflite")
)
labels = FileUtil.loadLabels(context, "labels.txt")
val inputTensor = interpreter?.getInputTensor(0)
val inputShape = inputTensor?.shape()
if (inputShape != null && inputShape.size >= 3) {
modelInputWidth = inputShape[1]
modelInputHeight = inputShape[2]
} else {
Log.e("TFLiteObjectDetector", "Invalid input tensor shape.")
}
val inputShape = interpreter!!.getInputTensor(0).shape()
inputW = inputShape[1]
inputH = inputShape[2]
val outputTensor = interpreter?.getOutputTensor(0)
val outputShape = outputTensor?.shape()
if (outputShape != null && outputShape.size >= 2) {
maxDetections = outputShape[1]
Log.d("TFLiteObjectDetector", "Max detections from model: $maxDetections")
}
maxDetections = interpreter!!.getOutputTensor(0).shape()[1]
Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.")
} catch (e: IOException) {
Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e)
Log.e("TFLiteObjectDetector", "Please ensure 'efficientdet-lite0.tflite' and 'labelmap.txt' are in the 'app/src/main/assets' directory.")
Log.e("Detector", "Failed to load model", e)
interpreter = null
}
}
override suspend fun analyze(input: PipelineInput): Instruction {
if (interpreter == null) {
return Instruction("Object detector not initialized. Check asset files.", isValid = false)
val image = input.image
?: return Instruction("Waiting for camera", isValid = false)
val resized = Bitmap.createScaledBitmap(image, inputW, inputH, true)
val buffer = bitmapToBuffer(resized)
val locations = Array(1) { Array(maxDetections) { FloatArray(4) } }
val classes = Array(1) { FloatArray(maxDetections) }
val scores = Array(1) { FloatArray(maxDetections) }
val count = FloatArray(1)
interpreter?.runForMultipleInputsOutputs(
arrayOf(buffer),
mapOf(0 to locations, 1 to classes, 2 to scores, 3 to count)
)
val detections = mutableListOf<Detection>()
for (i in 0 until count[0].toInt()) {
if (scores[0][i] < 0.5f) continue
val label = labels.getOrElse(classes[0][i].toInt()) { "Unknown" }
val b = locations[0][i]
detections += Detection(
label,
scores[0][i],
RectF(
b[1] * image.width,
b[0] * image.height,
b[3] * image.width,
b[2] * image.height
)
)
}
val image = input.image ?: return Instruction("Waiting for camera...", isValid = false)
val primary = detections
.filter { it.label.equals(input.targetAnimal, true) }
.maxByOrNull { it.confidence }
val resizedBitmap = Bitmap.createScaledBitmap(image, modelInputWidth, modelInputHeight, true)
val byteBuffer = convertBitmapToByteBuffer(resizedBitmap)
// Define model outputs with the correct size
val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } }
val outputClasses = Array(1) { FloatArray(maxDetections) }
val outputScores = Array(1) { FloatArray(maxDetections) }
val numDetections = FloatArray(1)
val outputs: MutableMap<Int, Any> = HashMap()
outputs[0] = outputLocations
outputs[1] = outputClasses
outputs[2] = outputScores
outputs[3] = numDetections
interpreter?.runForMultipleInputsOutputs(arrayOf(byteBuffer), outputs)
val detectedObjects = mutableListOf<Detection>()
val detectionCount = numDetections[0].toInt()
for (i in 0 until detectionCount) {
val score = outputScores[0][i]
if (score > 0.5f) { // Confidence threshold
val classIndex = outputClasses[0][i].toInt()
val label = labels.getOrElse(classIndex) { "Unknown" }
val location = outputLocations[0][i]
// TF Lite model returns ymin, xmin, ymax, xmax in normalized coordinates
val ymin = location[0] * image.height
val xmin = location[1] * image.width
val ymax = location[2] * image.height
val xmax = location[3] * image.width
val boundingBox = RectF(xmin, ymin, xmax, ymax)
detectedObjects.add(Detection(label, score, boundingBox))
}
}
val targetAnimalDetected = detectedObjects.find { it.label.equals(input.targetAnimal, ignoreCase = true) }
val isValid = targetAnimalDetected != null
val message = if (isValid) {
"${input.targetAnimal} Detected"
} else {
if (detectedObjects.isEmpty()) "No objects detected" else "Animal not detected, move closer or point camera to the animal"
}
val refObjects = detectedObjects
.filter { it !== targetAnimalDetected }
.mapIndexed { index, detection ->
val refs = detections
.filter { it !== primary }
.mapIndexed { i, d ->
ReferenceObject(
id = "ref_$index",
label = detection.label,
bounds = detection.bounds,
relativeHeight = detection.bounds.height() / image.height,
relativeWidth = detection.bounds.width() / image.width,
distance = 1.0f // Placeholder
id = "ref_$i",
label = d.label,
bounds = d.bounds,
relativeHeight = d.bounds.height() / image.height,
relativeWidth = d.bounds.width() / image.width,
distance = null
)
}
return Instruction(
message = message,
isValid = isValid,
message = if (primary != null) "Cow detected" else "Cow not detected",
isValid = primary != null,
result = DetectionResult(
isAnimalDetected = isValid,
animalBounds = targetAnimalDetected?.bounds,
referenceObjects = refObjects,
label = targetAnimalDetected?.label,
confidence = targetAnimalDetected?.confidence ?: 0f
isAnimalDetected = primary != null,
animalBounds = primary?.bounds,
referenceObjects = refs,
label = primary?.label,
confidence = primary?.confidence ?: 0f
)
)
}
private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
val byteBuffer = ByteBuffer.allocateDirect(1 * modelInputWidth * modelInputHeight * 3)
byteBuffer.order(ByteOrder.nativeOrder())
val intValues = IntArray(modelInputWidth * modelInputHeight)
bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
var pixel = 0
for (i in 0 until modelInputWidth) {
for (j in 0 until modelInputHeight) {
val `val` = intValues[pixel++]
// Assuming model expects UINT8 [0, 255]
byteBuffer.put(((`val` shr 16) and 0xFF).toByte())
byteBuffer.put(((`val` shr 8) and 0xFF).toByte())
byteBuffer.put((`val` and 0xFF).toByte())
private fun bitmapToBuffer(bitmap: Bitmap): ByteBuffer {
val buffer = ByteBuffer.allocateDirect(inputW * inputH * 3)
buffer.order(ByteOrder.nativeOrder())
val pixels = IntArray(inputW * inputH)
bitmap.getPixels(pixels, 0, inputW, 0, 0, inputW, inputH)
for (p in pixels) {
buffer.put(((p shr 16) and 0xFF).toByte())
buffer.put(((p shr 8) and 0xFF).toByte())
buffer.put((p and 0xFF).toByte())
}
}
return byteBuffer
return buffer
}
data class Detection(val label: String, val confidence: Float, val bounds: RectF)
}
/* ============================================================= */
/* POSE ANALYZER (ALIGNMENT → CROP → SEGMENT) */
/* ============================================================= */
class MockPoseAnalyzer : PoseAnalyzer {
private val segmenter by lazy {
val options = SubjectSegmenterOptions.Builder()
SubjectSegmentation.getClient(
SubjectSegmenterOptions.Builder()
.enableForegroundConfidenceMask()
.build()
SubjectSegmentation.getClient(options)
)
}
override suspend fun analyze(input: PipelineInput): Instruction {
@@ -236,130 +210,124 @@ class MockPoseAnalyzer : PoseAnalyzer {
val detection = input.previousDetectionResult
?: return Instruction("No detection", isValid = false)
val bounds = detection.animalBounds
?: return Instruction("Animal not detected", isValid = false)
val cowBox = detection.animalBounds
?: return Instruction("Cow not detected", isValid = false)
val image = input.image
?: return Instruction("No image", isValid = false)
// --------------------------------------------------------------------
// 1. Reference silhouette (FloatArray)
// --------------------------------------------------------------------
val reference = SilhouetteManager.getWeightedMask(input.orientation)
val silhouette = SilhouetteManager.getSilhouette(input.orientation)
?: return Instruction("Silhouette missing", isValid = false)
val refH = reference.mask.size
val refW = reference.mask[0].size
val align = checkAlignment(cowBox, silhouette.boundingBox, 0.15f)
if (align.issue != AlignmentIssue.OK) {
return alignmentToInstruction(align)
}
// --------------------------------------------------------------------
// 2. Crop only animal region (BIG WIN)
// --------------------------------------------------------------------
val cropped = Bitmap.createBitmap(
image,
bounds.left.toInt(),
bounds.top.toInt(),
bounds.width().toInt(),
bounds.height().toInt()
cowBox.left.toInt(),
cowBox.top.toInt(),
cowBox.width().toInt(),
cowBox.height().toInt()
)
val scaled = Bitmap.createScaledBitmap(cropped, refW, refH, true)
val resized = Bitmap.createScaledBitmap(
cropped,
silhouette.croppedBitmap.width,
silhouette.croppedBitmap.height,
true
)
// --------------------------------------------------------------------
// 3. Get binary segmentation mask (ByteArray)
// --------------------------------------------------------------------
val mask = getAnimalMaskFast(scaled)
val mask = segment(resized)
?: return Instruction("Segmentation failed", isValid = false)
// --------------------------------------------------------------------
// 4. Fast signed similarity
// --------------------------------------------------------------------
val score = calculateSignedSimilarityFast(mask, reference)
val score = similarity(mask, silhouette.signedMask)
val valid = score >= 0.40f
return Instruction(
message = if (valid) "Pose Correct" else "Pose Incorrect (%.2f)".format(score),
message = if (valid) "Pose Correct" else "Adjust Position",
isValid = valid,
result = detection
)
}
/**
* ML Kit Binary mask (0/1) using ByteArray
*/
private suspend fun getAnimalMaskFast(bitmap: Bitmap): ByteArray? =
private suspend fun segment(bitmap: Bitmap): ByteArray? =
suspendCancellableCoroutine { cont ->
val image = InputImage.fromBitmap(bitmap, 0)
segmenter.process(image)
.addOnSuccessListener { result ->
val mask = result.foregroundConfidenceMask
?: run {
cont.resume(null)
return@addOnSuccessListener
}
val buffer = mask
buffer.rewind()
segmenter.process(InputImage.fromBitmap(bitmap, 0))
.addOnSuccessListener { r ->
val buf = r.foregroundConfidenceMask
?: return@addOnSuccessListener cont.resume(null)
buf.rewind()
val out = ByteArray(bitmap.width * bitmap.height)
for (i in out.indices) {
out[i] = if (buffer.get() > 0.5f) 1 else 0
}
for (i in out.indices) out[i] = if (buf.get() > 0.5f) 1 else 0
cont.resume(out)
}
.addOnFailureListener {
cont.resume(null)
.addOnFailureListener { cont.resume(null) }
}
private fun similarity(mask: ByteArray, ref: SignedMask): Float {
var s = 0f
var i = 0
for (row in ref.mask)
for (v in row)
s += mask[i++] * v
return if (ref.maxValue == 0f) 0f else s / ref.maxValue
}
}
/**
* Signed similarity (flat arrays, cache-friendly)
* Range [-1, 1]
*/
private fun calculateSignedSimilarityFast(
mask: ByteArray,
reference: SignedMask
): Float {
/* ============================================================= */
/* ALIGNMENT HELPERS */
/* ============================================================= */
var score = 0f
val maxScore = reference.maxValue
var idx = 0
enum class AlignmentIssue { TOO_SMALL, TOO_LARGE, MOVE_LEFT, MOVE_RIGHT, MOVE_UP, MOVE_DOWN, OK }
for (y in reference.mask.indices) {
val row = reference.mask[y]
for (x in row.indices) {
val r = row[x]
score += mask[idx++] * r
data class AlignmentResult(val issue: AlignmentIssue, val scale: Float, val dx: Float, val dy: Float)
fun checkAlignment(d: RectF, s: RectF, tol: Float): AlignmentResult {
val scale = min(d.width() / s.width(), d.height() / s.height())
val dx = d.centerX() - s.centerX()
val dy = d.centerY() - s.centerY()
if (scale < 1f - tol) return AlignmentResult(AlignmentIssue.TOO_SMALL, scale, dx, dy)
if (scale > 1f + tol) return AlignmentResult(AlignmentIssue.TOO_LARGE, scale, dx, dy)
val tx = s.width() * tol
val ty = s.height() * tol
return when {
dx > tx -> AlignmentResult(AlignmentIssue.MOVE_LEFT, scale, dx, dy)
dx < -tx -> AlignmentResult(AlignmentIssue.MOVE_RIGHT, scale, dx, dy)
dy > ty -> AlignmentResult(AlignmentIssue.MOVE_UP, scale, dx, dy)
dy < -ty -> AlignmentResult(AlignmentIssue.MOVE_DOWN, scale, dx, dy)
else -> AlignmentResult(AlignmentIssue.OK, scale, dx, dy)
}
}
return if (maxScore == 0f) 0f else score / maxScore
}
fun alignmentToInstruction(a: AlignmentResult) = when (a.issue) {
AlignmentIssue.TOO_SMALL -> Instruction("Move closer", isValid = false)
AlignmentIssue.TOO_LARGE -> Instruction("Move backward", isValid = false)
AlignmentIssue.MOVE_LEFT -> Instruction("Move right", isValid = false)
AlignmentIssue.MOVE_RIGHT -> Instruction("Move left", isValid = false)
AlignmentIssue.MOVE_UP -> Instruction("Move down", isValid = false)
AlignmentIssue.MOVE_DOWN -> Instruction("Move up", isValid = false)
AlignmentIssue.OK -> Instruction("Hold steady", isValid = true)
}
/* ============================================================= */
/* CAPTURE + MEASUREMENT (UNCHANGED) */
/* ============================================================= */
class DefaultCaptureHandler : CaptureHandler {
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")
val segmentationMask = BooleanArray(100) { true }
val animalMetrics = ObjectMetrics(
relativeHeight = 0.5f,
relativeWidth = 0.3f,
distance = 1.2f
)
return CaptureData(
image = image,
segmentationMask = segmentationMask,
animalMetrics = animalMetrics,
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData =
CaptureData(
image = input.image!!,
segmentationMask = BooleanArray(0),
animalMetrics = ObjectMetrics(0f, 0f, 1f),
referenceObjects = detectionResult.referenceObjects
)
}
}
class DefaultMeasurementCalculator : MeasurementCalculator {
override fun calculateRealMetrics(
@@ -367,18 +335,16 @@ class DefaultMeasurementCalculator : MeasurementCalculator {
referenceObject: ReferenceObject,
currentMetrics: ObjectMetrics
): RealWorldMetrics {
if (referenceObject.relativeHeight == 0f) return RealWorldMetrics(0f, 0f, 0f)
if (referenceObject.relativeHeight == 0f)
return RealWorldMetrics(0f, 0f, 0f)
val scale = targetHeight / referenceObject.relativeHeight
val realHeight = currentMetrics.relativeHeight * scale
val realWidth = currentMetrics.relativeWidth * scale
val realDistance = currentMetrics.distance
return RealWorldMetrics(
height = realHeight,
width = realWidth,
distance = realDistance
height = currentMetrics.relativeHeight * scale,
width = currentMetrics.relativeWidth * scale,
distance = currentMetrics.distance
)
}
}

View File

@@ -55,7 +55,7 @@ data class ReferenceObject(
val bounds: RectF,
val relativeHeight: Float,
val relativeWidth: Float,
val distance: Float
val distance: Float? = null
)
enum class CameraOrientation {

View File

@@ -1,9 +1,7 @@
package com.example.livingai.utils
import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Color
import android.graphics.*
import android.util.Log
import com.example.livingai.R
import java.util.concurrent.ConcurrentHashMap
@@ -14,20 +12,25 @@ data class SignedMask(
val maxValue: Float
)
data class SilhouetteData(
val croppedBitmap: Bitmap,
val boundingBox: RectF,
val signedMask: SignedMask
)
object SilhouetteManager {
private val originals = ConcurrentHashMap<String, Bitmap>()
private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
private val weightedMasks = ConcurrentHashMap<String, SignedMask>()
private val silhouettes = ConcurrentHashMap<String, SilhouetteData>()
fun getOriginal(name: String): Bitmap? = originals[name]
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
fun getWeightedMask(name: String): SignedMask? = weightedMasks[name]
fun getSilhouette(name: String): SilhouetteData? = silhouettes[name]
fun initialize(context: Context, width: Int, height: Int) {
val resources = context.resources
fun initialize(context: Context, screenW: Int, screenH: Int) {
val silhouetteList = mapOf(
val res = context.resources
val map = mapOf(
"front" to R.drawable.front_silhouette,
"back" to R.drawable.back_silhouette,
"left" to R.drawable.left_silhouette,
@@ -37,58 +40,94 @@ object SilhouetteManager {
"angleview" to R.drawable.angleview_silhouette
)
silhouetteList.forEach { (name, resId) ->
val bmp = BitmapFactory.decodeResource(resources, resId)
originals[name] = bmp
map.forEach { (name, resId) ->
// Fit image appropriately (front/back = W/H, others rotated)
val fitted = if (name == "front" || name == "back")
createInvertedPurpleBitmap(bmp, width, height)
else
createInvertedPurpleBitmap(bmp, height, width)
val src = BitmapFactory.decodeResource(res, resId)
originals[name] = src
invertedPurple[name] = fitted
val fitted = Bitmap.createScaledBitmap(
invertToPurple(src),
screenW,
screenH,
true
)
weightedMasks[name] = createSignedWeightedMask(fitted, fadeInside = 10, fadeOutside = 20)
val bbox = computeBoundingBox(fitted)
Log.d("Silhouette", "Loaded mask: $name (${fitted.width} x ${fitted.height})")
val cropped = Bitmap.createBitmap(
fitted,
bbox.left.toInt(),
bbox.top.toInt(),
bbox.width().toInt(),
bbox.height().toInt()
)
val signedMask = createSignedWeightedMask(cropped)
silhouettes[name] = SilhouetteData(
croppedBitmap = cropped,
boundingBox = bbox,
signedMask = signedMask
)
Log.d("Silhouette", "Loaded $name (${bbox.width()} x ${bbox.height()})")
}
}
// ------------------------------------------------------------------------
// STEP 1: Create "inverted purple" mask (transparent object becomes purple)
// ------------------------------------------------------------------------
private fun createInvertedPurpleBitmap(
src: Bitmap,
targetWidth: Int,
targetHeight: Int
): Bitmap {
/* ---------------------------------------------------------- */
private fun invertToPurple(src: Bitmap): Bitmap {
val w = src.width
val h = src.height
val pixels = IntArray(w * h)
src.getPixels(pixels, 0, w, 0, 0, w, h)
val purple = Color.argb(255, 128, 0, 128)
for (i in pixels.indices) {
val alpha = pixels[i] ushr 24
pixels[i] = if (alpha == 0) purple else 0x00000000
pixels[i] =
if ((pixels[i] ushr 24) == 0) purple
else 0x00000000
}
val inverted = Bitmap.createBitmap(pixels, w, h, Bitmap.Config.ARGB_8888)
return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true)
return Bitmap.createBitmap(pixels, w, h, Bitmap.Config.ARGB_8888)
}
/**
* Creates a signed weighted mask in range [-1, +1]
*
* +1 : deep inside object
* 0 : object boundary
* -1 : far outside object
*/
private fun computeBoundingBox(bitmap: Bitmap): RectF {
val w = bitmap.width
val h = bitmap.height
val pixels = IntArray(w * h)
bitmap.getPixels(pixels, 0, w, 0, 0, w, h)
var minX = w
var minY = h
var maxX = 0
var maxY = 0
for (y in 0 until h) {
for (x in 0 until w) {
if ((pixels[y * w + x] ushr 24) > 0) {
minX = min(minX, x)
minY = min(minY, y)
maxX = maxOf(maxX, x)
maxY = maxOf(maxY, y)
}
}
}
return RectF(
minX.toFloat(),
minY.toFloat(),
maxX.toFloat(),
maxY.toFloat()
)
}
/* ---------------------------------------------------------- */
/* SIGNED WEIGHTED MASK */
/* ---------------------------------------------------------- */
fun createSignedWeightedMask(
bitmap: Bitmap,
fadeInside: Int = 10,
@@ -101,96 +140,54 @@ object SilhouetteManager {
val pixels = IntArray(w * h)
bitmap.getPixels(pixels, 0, w, 0, 0, w, h)
val inside = IntArray(w * h)
for (i in pixels.indices)
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
fun idx(x: Int, y: Int) = y * w + x
// --------------------------------------------------------------------
// 1. Binary mask
// --------------------------------------------------------------------
val inside = IntArray(w * h)
for (i in pixels.indices) {
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
val distIn = IntArray(w * h) { Int.MAX_VALUE }
val distOut = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) {
if (inside[i] == 0) distIn[i] = 0
else distOut[i] = 0
}
// --------------------------------------------------------------------
// 2. Distance transform (inside → outside)
// --------------------------------------------------------------------
val distInside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) if (inside[i] == 0) distInside[i] = 0
for (y in 0 until h) {
for (y in 0 until h)
for (x in 0 until w) {
val i = idx(x, y)
var best = distInside[i]
if (x > 0) best = min(best, distInside[idx(x - 1, y)] + 1)
if (y > 0) best = min(best, distInside[idx(x, y - 1)] + 1)
distInside[i] = best
}
if (x > 0) distIn[i] = min(distIn[i], distIn[idx(x - 1, y)] + 1)
if (y > 0) distIn[i] = min(distIn[i], distIn[idx(x, y - 1)] + 1)
if (x > 0) distOut[i] = min(distOut[i], distOut[idx(x - 1, y)] + 1)
if (y > 0) distOut[i] = min(distOut[i], distOut[idx(x, y - 1)] + 1)
}
for (y in h - 1 downTo 0) {
for (y in h - 1 downTo 0)
for (x in w - 1 downTo 0) {
val i = idx(x, y)
var best = distInside[i]
if (x < w - 1) best = min(best, distInside[idx(x + 1, y)] + 1)
if (y < h - 1) best = min(best, distInside[idx(x, y + 1)] + 1)
distInside[i] = best
}
if (x < w - 1) distIn[i] = min(distIn[i], distIn[idx(x + 1, y)] + 1)
if (y < h - 1) distIn[i] = min(distIn[i], distIn[idx(x, y + 1)] + 1)
if (x < w - 1) distOut[i] = min(distOut[i], distOut[idx(x + 1, y)] + 1)
if (y < h - 1) distOut[i] = min(distOut[i], distOut[idx(x, y + 1)] + 1)
}
// --------------------------------------------------------------------
// 3. Distance transform (outside → inside)
// --------------------------------------------------------------------
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) if (inside[i] == 1) distOutside[i] = 0
val mask = Array(h) { FloatArray(w) }
var maxVal = Float.NEGATIVE_INFINITY
for (y in 0 until h) {
for (y in 0 until h)
for (x in 0 until w) {
val i = idx(x, y)
var best = distOutside[i]
if (x > 0) best = min(best, distOutside[idx(x - 1, y)] + 1)
if (y > 0) best = min(best, distOutside[idx(x, y - 1)] + 1)
distOutside[i] = best
}
val v =
if (inside[i] == 1)
min(1f, distIn[i].toFloat() / fadeInside)
else
maxOf(-1f, -distOut[i].toFloat() / fadeOutside)
mask[y][x] = v
if (v > maxVal) maxVal = v
}
for (y in h - 1 downTo 0) {
for (x in w - 1 downTo 0) {
val i = idx(x, y)
var best = distOutside[i]
if (x < w - 1) best = min(best, distOutside[idx(x + 1, y)] + 1)
if (y < h - 1) best = min(best, distOutside[idx(x, y + 1)] + 1)
distOutside[i] = best
return SignedMask(mask, maxVal)
}
}
// --------------------------------------------------------------------
// 4. Build signed mask + track max value
// --------------------------------------------------------------------
val result = Array(h) { FloatArray(w) }
var maxValue = Float.NEGATIVE_INFINITY
for (y in 0 until h) {
for (x in 0 until w) {
val i = idx(x, y)
val weight = if (inside[i] == 1) {
val d = distInside[i]
if (d >= fadeInside) 1f
else d.toFloat() / fadeInside
} else {
val d = distOutside[i]
(-d.toFloat() / fadeOutside).coerceAtLeast(-1f)
}
result[y][x] = weight
if (weight > maxValue) maxValue = weight
}
}
return SignedMask(
mask = result,
maxValue = maxValue
)
}
}