Merge sort a linked list

Hi,

This is my code for merge sorting a linked list. My code works for the case of a linked list of two elements, but in the case of four elements, I got a never-ending loop. Could anyone point out what my problem is? This is from LxxxCode Problem 148.

# Definition for singly-linked list.
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next
class Solution:
    #def sortList(self, head: Optional[ListNode]) -> Optional[ListNode]:
    def sortList(self, head):
        def midpoint(head): # Find the midpoint of a linked list given the head; the end of the first half
            fast = slow = head
            while(fast and fast.next):
                slow = slow.next
                fast = fast.next.next
            return slow  # Note that slow.next is the head of the second half
        
        def twolist(head): # break the linked list into the first half and the second half; return heads of the first half and the second half
            mid1 = midpoint(head)
            head2 = mid1.next
            mid1.next = None # separte the first half from the second half
            return head, head2
        
        def merge_sorted(head1, head2): # merge two sorted L.L.
            pt = dummy = ListNode(-1) # pt: pointer
            if head1.val < head2.val:
                dummy.next = head1
            else:
                dummy.next = head2
            pt = dummy.next    
            
            while head1 and head2:
                if head1.val < head2.val:
                    pt.next = head1
                    head1 = head1.next
                else:
                    pt.next = head2
                    head2 = head2.next
                pt = pt.next
                
            if head1:
                pt.next = head1
            else:
                pt.next = head2
            return dummy.next
        
        # start the merge sort function
        if not head:  # 0 element
            return head
        if not head.next: # 1 element
            return head
        if not head.next.next: # 2 elements
            if head.val < head.next.val: # right order
                return head
            else:   # wrong order
                dummy = ListNode(-1)
                dummy.next = head.next 
                head.next.next = head
                head.next = None
            return dummy.next
        
        # 3 elements or more
        first_half_head, second_half_head = twolist(head)
        first_half_sorted_head = self.sortList(first_half_head)
        second_half_sorted_head = self.sortList(second_half_head)
        return merge_sorted(first_half_sorted_head, second_half_sorted_head)
# Try this
sol1 = Solution()
e1 = ListNode(4)
e2 = ListNode(2)
e3 = ListNode(1)
e4 = ListNode(3)
e1.next = e2
e2.next = e3
e3.next = e4
es = sol1.sortList(e1)
print(es.val)
print(es.next.val)
print(es.next.next.val)
print(es.next.next.next.val)
            

Thank you. My code eventually worked after two minor modifications:

# Definition for singly-linked list.
# class ListNode:
#     def __init__(self, val=0, next=None):
#         self.val = val
#         self.next = next
class Solution:
    #def sortList(self, head: Optional[ListNode]) -> Optional[ListNode]:
    def sortList(self, head):
        def midpoint(head): # Find the midpoint of a linked list given the head; the end of the first half
            slow = head
            fast = head.next  # Without this line, I will get the head of the second half, not the tail of the first half
            while(fast and fast.next):
                slow = slow.next
                fast = fast.next.next
            return slow  # Note that slow.next is the head of the second half
        
        def twolist(head): # break the linked list into the first half and the second half; return heads of the first half and the second half
            mid1 = midpoint(head)
            head2 = mid1.next
            mid1.next = None # separte the first half from the second half
            return head, head2
        
        def merge_sorted(head1, head2): # merge two sorted L.L.
            if not head1:
                return head2
            if not head2:
                return head1
            pt = dummy = ListNode(-1) # pt: pointer
            if head1.val < head2.val:
                dummy.next = head1
                head1 = head1.next  # my answer did not run because I lack this line!!
            else:
                dummy.next = head2
                head2 = head2.next # my answer did not run because I lack this line!!
            pt = dummy.next    
            
            while head1 and head2:
                if head1.val < head2.val:
                    pt.next = head1
                    head1 = head1.next
                else:
                    pt.next = head2
                    head2 = head2.next
                pt = pt.next
                
            if head1:
                pt.next = head1
            else:
                pt.next = head2
            return dummy.next
        
        # start the merge sort function
        if not head:  # 0 element
            return head
        if not head.next: # 1 element
            return head
        if not head.next.next: # 2 elements
            if head.val < head.next.val: # right order
                return head
            else:   # wrong order
                dummy = ListNode(-1)
                dummy.next = head.next 
                head.next.next = head
                head.next = None
            return dummy.next
        # 3 elements or more
        first_half_head, second_half_head = twolist(head)
        first_half_sorted_head = self.sortList(first_half_head)
        second_half_sorted_head = self.sortList(second_half_head)
        return merge_sorted(first_half_sorted_head, second_half_sorted_head)