summaryrefslogtreecommitdiff
path: root/drivers/misc/habanalabs/common/command_submission.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/misc/habanalabs/common/command_submission.c')
-rw-r--r-- drivers/misc/habanalabs/common/command_submission.c | 104
1 files changed, 70 insertions, 34 deletions
diff --git a/drivers/misc/habanalabs/common/command_submission.c b/drivers/misc/habanalabs/common/command_submission.c
index 7b0516cf808b..6dafff375f1c 100644
--- a/drivers/misc/habanalabs/common/command_submission.c
+++ b/drivers/misc/habanalabs/common/command_submission.c
@@ -405,7 +405,7 @@ static void staged_cs_put(struct hl_device *hdev, struct hl_cs *cs)
static void cs_handle_tdr(struct hl_device *hdev, struct hl_cs *cs)
{
bool next_entry_found = false;
- struct hl_cs *next;
+ struct hl_cs *next, *first_cs;
if (!cs_needs_timeout(cs))
return;
@@ -415,9 +415,16 @@ static void cs_handle_tdr(struct hl_device *hdev, struct hl_cs *cs)
/* We need to handle tdr only once for the complete staged submission.
* Hence, we choose the CS that reaches this function first which is
* the CS marked as 'staged_last'.
+ * In case a single staged CS was submitted which has both the first and
+ * last indications, "hl_staged_cs_find_first" below will return NULL,
+ * since we removed the CS node from the list before getting here.
+ * In such cases just continue with the CS to cancel its TDR work.
*/
- if (cs->staged_cs && cs->staged_last)
- cs = hl_staged_cs_find_first(hdev, cs->staged_sequence);
+ if (cs->staged_cs && cs->staged_last) {
+ first_cs = hl_staged_cs_find_first(hdev, cs->staged_sequence);
+ if (first_cs)
+ cs = first_cs;
+ }
spin_unlock(&hdev->cs_mirror_lock);
@@ -1288,6 +1295,12 @@ static int cs_ioctl_default(struct hl_fpriv *hpriv, void __user *chunks,
if (rc)
goto free_cs_object;
+ /* If this is a staged submission we must return the staged sequence
+ * rather than the internal CS sequence
+ */
+ if (cs->staged_cs)
+ *cs_seq = cs->staged_sequence;
+
/* Validate ALL the CS chunks before submitting the CS */
for (i = 0 ; i < num_chunks ; i++) {
struct hl_cs_chunk *chunk = &cs_chunk_array[i];
@@ -1988,6 +2001,15 @@ static int cs_ioctl_signal_wait(struct hl_fpriv *hpriv, enum hl_cs_type cs_type,
goto free_cs_chunk_array;
}
+ if (!hdev->nic_ports_mask) {
+ atomic64_inc(&ctx->cs_counters.validation_drop_cnt);
+ atomic64_inc(&cntr->validation_drop_cnt);
+ dev_err(hdev->dev,
+ "Collective operations not supported when NIC ports are disabled");
+ rc = -EINVAL;
+ goto free_cs_chunk_array;
+ }
+
collective_engine_id = chunk->collective_engine_id;
}
@@ -2026,9 +2048,10 @@ static int cs_ioctl_signal_wait(struct hl_fpriv *hpriv, enum hl_cs_type cs_type,
spin_unlock(&ctx->sig_mgr.lock);
if (!handle_found) {
- dev_err(hdev->dev, "Cannot find encapsulated signals handle for seq 0x%llx\n",
+ /* treat as signal CS already finished */
+ dev_dbg(hdev->dev, "Cannot find encapsulated signals handle for seq 0x%llx\n",
signal_seq);
- rc = -EINVAL;
+ rc = 0;
goto free_cs_chunk_array;
}
@@ -2613,7 +2636,8 @@ static int hl_multi_cs_wait_ioctl(struct hl_fpriv *hpriv, void *data)
* completed after the poll function.
*/
if (!mcs_data.completion_bitmap) {
- dev_err(hdev->dev, "Multi-CS got completion on wait but no CS completed\n");
+ dev_warn_ratelimited(hdev->dev,
+ "Multi-CS got completion on wait but no CS completed\n");
rc = -EFAULT;
}
}
@@ -2625,11 +2649,18 @@ put_ctx:
free_seq_arr:
kfree(cs_seq_arr);
- /* update output args */
- memset(args, 0, sizeof(*args));
if (rc)
return rc;
+ if (mcs_data.wait_status == -ERESTARTSYS) {
+ dev_err_ratelimited(hdev->dev,
+ "user process got signal while waiting for Multi-CS\n");
+ return -EINTR;
+ }
+
+ /* update output args */
+ memset(args, 0, sizeof(*args));
+
if (mcs_data.completion_bitmap) {
args->out.status = HL_WAIT_CS_STATUS_COMPLETED;
args->out.cs_completion_map = mcs_data.completion_bitmap;
@@ -2643,8 +2674,6 @@ free_seq_arr:
/* update if some CS was gone */
if (mcs_data.timestamp)
args->out.flags |= HL_WAIT_CS_STATUS_FLAG_GONE;
- } else if (mcs_data.wait_status == -ERESTARTSYS) {
- args->out.status = HL_WAIT_CS_STATUS_INTERRUPTED;
} else {
args->out.status = HL_WAIT_CS_STATUS_BUSY;
}
@@ -2664,16 +2693,17 @@ static int hl_cs_wait_ioctl(struct hl_fpriv *hpriv, void *data)
rc = _hl_cs_wait_ioctl(hdev, hpriv->ctx, args->in.timeout_us, seq,
&status, &timestamp);
+ if (rc == -ERESTARTSYS) {
+ dev_err_ratelimited(hdev->dev,
+ "user process got signal while waiting for CS handle %llu\n",
+ seq);
+ return -EINTR;
+ }
+
memset(args, 0, sizeof(*args));
if (rc) {
- if (rc == -ERESTARTSYS) {
- dev_err_ratelimited(hdev->dev,
- "user process got signal while waiting for CS handle %llu\n",
- seq);
- args->out.status = HL_WAIT_CS_STATUS_INTERRUPTED;
- rc = -EINTR;
- } else if (rc == -ETIMEDOUT) {
+ if (rc == -ETIMEDOUT) {
dev_err_ratelimited(hdev->dev,
"CS %llu has timed-out while user process is waiting for it\n",
seq);
@@ -2740,10 +2770,20 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
else
interrupt = &hdev->user_interrupt[interrupt_offset];
+ /* Add pending user interrupt to relevant list for the interrupt
+ * handler to monitor
+ */
+ spin_lock_irqsave(&interrupt->wait_list_lock, flags);
+ list_add_tail(&pend->wait_list_node, &interrupt->wait_list_head);
+ spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
+
+ /* We check for the completion value as an interrupt could have been
+ * received before we added the node to the wait list
+ */
if (copy_from_user(&completion_value, u64_to_user_ptr(user_address), 4)) {
dev_err(hdev->dev, "Failed to copy completion value from user\n");
rc = -EFAULT;
- goto free_fence;
+ goto remove_pending_user_interrupt;
}
if (completion_value >= target_value)
@@ -2752,14 +2792,7 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
*status = CS_WAIT_STATUS_BUSY;
if (!timeout_us || (*status == CS_WAIT_STATUS_COMPLETED))
- goto free_fence;
-
- /* Add pending user interrupt to relevant list for the interrupt
- * handler to monitor
- */
- spin_lock_irqsave(&interrupt->wait_list_lock, flags);
- list_add_tail(&pend->wait_list_node, &interrupt->wait_list_head);
- spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
+ goto remove_pending_user_interrupt;
wait_again:
/* Wait for interrupt handler to signal completion */
@@ -2770,6 +2803,15 @@ wait_again:
* If comparison fails, keep waiting until timeout expires
*/
if (completion_rc > 0) {
+ spin_lock_irqsave(&interrupt->wait_list_lock, flags);
+ /* reinit_completion must be called before we check for user
+ * completion value, otherwise, if interrupt is received after
+ * the comparison and before the next wait_for_completion,
+ * we will reach timeout and fail
+ */
+ reinit_completion(&pend->fence.completion);
+ spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
+
if (copy_from_user(&completion_value, u64_to_user_ptr(user_address), 4)) {
dev_err(hdev->dev, "Failed to copy completion value from user\n");
rc = -EFAULT;
@@ -2780,18 +2822,13 @@ wait_again:
if (completion_value >= target_value) {
*status = CS_WAIT_STATUS_COMPLETED;
} else {
- spin_lock_irqsave(&interrupt->wait_list_lock, flags);
- reinit_completion(&pend->fence.completion);
timeout = completion_rc;
-
- spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
goto wait_again;
}
} else if (completion_rc == -ERESTARTSYS) {
dev_err_ratelimited(hdev->dev,
"user process got signal while waiting for interrupt ID %d\n",
interrupt->interrupt_id);
- *status = HL_WAIT_CS_STATUS_INTERRUPTED;
rc = -EINTR;
} else {
*status = CS_WAIT_STATUS_BUSY;
@@ -2802,7 +2839,6 @@ remove_pending_user_interrupt:
list_del(&pend->wait_list_node);
spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
-free_fence:
kfree(pend);
hl_ctx_put(ctx);
@@ -2847,8 +2883,6 @@ static int hl_interrupt_wait_ioctl(struct hl_fpriv *hpriv, void *data)
args->in.interrupt_timeout_us, args->in.addr,
args->in.target, interrupt_offset, &status);
- memset(args, 0, sizeof(*args));
-
if (rc) {
if (rc != -EINTR)
dev_err_ratelimited(hdev->dev,
@@ -2857,6 +2891,8 @@ static int hl_interrupt_wait_ioctl(struct hl_fpriv *hpriv, void *data)
return rc;
}
+ memset(args, 0, sizeof(*args));
+
switch (status) {
case CS_WAIT_STATUS_COMPLETED:
args->out.status = HL_WAIT_CS_STATUS_COMPLETED;